diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..27e8fb94966d415ef0aea4a886d481933ca392b8 --- /dev/null +++ b/.clang-format @@ -0,0 +1,40 @@ +BasedOnStyle: Chromium +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^<.*>' + Priority: 1 + - Regex: '^".*"' + Priority: 2 +SortIncludes: true +Language: Cpp +AccessModifierOffset: 2 +AlignAfterOpenBracket: true +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Right +AlignOperands: true +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: None +AllowShortIfStatementsOnASingleLine: true +AllowShortLoopsOnASingleLine: true +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: false +BinPackArguments: false +BinPackParameters: false +BreakBeforeBraces: Attach +BreakBeforeInheritanceComma: false +BreakBeforeTernaryOperators: true +BreakStringLiterals: false +ColumnLimit: 88 +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +IndentCaseLabels: true +IndentWidth: 4 +TabWidth: 4 +UseTab: Never diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..dc41813948358b1c4961863e937f642b2c5bd1c3 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,41 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.JPG filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.bmp filter=lfs diff=lfs merge=lfs -text +*.pdf filter=lfs diff=lfs merge=lfs -text diff --git a/.github/.stale.yml b/.github/.stale.yml new file mode 100644 index 0000000000000000000000000000000000000000..dc90e5a1c3aad4818a813606b52fdecd2fdf6782 --- /dev/null +++ 
b/.github/.stale.yml @@ -0,0 +1,17 @@ +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 60 +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 7 +# Issues with these labels will never be considered stale +exemptLabels: + - pinned + - security +# Label to use when marking an issue as stale +staleLabel: wontfix +# Comment to post when marking an issue as stale. Set to `false` to disable +markComment: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. Thank you + for your contributions. +# Comment to post when closing a stale issue. Set to `false` to disable +closeComment: false diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000000000000000000000000000000000000..036bffc7cbe7f88ef6c4657752a24a73e4bf9a76 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,30 @@ +--- +name: 🐛 Bug report +about: If something isn't working 🔧 +title: "" +labels: bug +assignees: +--- + +## 🐛 Bug Report + + + +## 🔬 How To Reproduce + +Steps to reproduce the behavior: + +1. ... + +### Environment + +- OS: [e.g. Linux / Windows / macOS] +- Python version, get it with: + +```bash +python --version +``` + +## 📎 Additional context + + diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..8f2da5489e290e6e55426eaeac2234c61f21f638 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,3 @@ +# Configuration: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository + +blank_issues_enabled: false diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000000000000000000000000000000000000..7ce8c1277148183d41908db966ee2f0978c7a02a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,15 @@ +--- +name: 🚀 Feature request +about: Suggest an idea for this project 🏖 +title: "" +labels: enhancement +assignees: +--- + +## 🚀 Feature Request + + + +## 📎 Additional context + + diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 0000000000000000000000000000000000000000..0b624eefe6041c776f1ab4a6aba44d3d1c8cdd83 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,25 @@ +--- +name: ❓ Question +about: Ask a question about this project 🎓 +title: "" +labels: question +assignees: +--- + +## Checklist + + + +- [ ] I've searched the project's [`issues`] + +## ❓ Question + + + +How can I [...]? + +Is it possible to [...]? 
+ +## 📎 Additional context + + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000000000000000000000000000000000..4dab74cab6b6b173c9a0a15e1d7652f75bc29818 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,7 @@ +## Description + + + +## Related Issue + + diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml new file mode 100644 index 0000000000000000000000000000000000000000..fc4b3c3f511d20f0dd255f76f3fea1fe0dbd1ce5 --- /dev/null +++ b/.github/release-drafter.yml @@ -0,0 +1,24 @@ +# Release drafter configuration https://github.com/release-drafter/release-drafter#configuration +# Emojis were chosen to match the https://gitmoji.carloscuesta.me/ + +name-template: "v$RESOLVED_VERSION" +tag-template: "v$RESOLVED_VERSION" + +categories: + - title: ":rocket: Features" + labels: [enhancement, feature] + - title: ":wrench: Fixes" + labels: [bug, bugfix, fix] + - title: ":toolbox: Maintenance & Refactor" + labels: [refactor, refactoring, chore] + - title: ":package: Build System & CI/CD & Test" + labels: [build, ci, testing, test] + - title: ":pencil: Documentation" + labels: [documentation] + - title: ":arrow_up: Dependencies updates" + labels: [dependencies] + +template: | + ## What’s Changed + + $CHANGES diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..8d25bdb208d7554ace8acae236a943a311aed12c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,37 @@ +name: CI CPU + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + # runs-on: self-hosted + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + pip install -r requirements.txt + sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y + + - name: Build and install + run: pip install . + + - name: Run tests + # run: python -m pytest + run: python tests/test_basic.py diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 0000000000000000000000000000000000000000..39eca7f82193545f5afe00a3ff841ee93db4967e --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,23 @@ +# This is a format job. 
Pre-commit has a first-party GitHub action, so we use +# that: https://github.com/pre-commit/action + +name: Format + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +jobs: + pre-commit: + name: Format + runs-on: ubuntu-latest + # runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/pip.yml b/.github/workflows/pip.yml new file mode 100644 index 0000000000000000000000000000000000000000..87fec4fef633d9f13732c4a762eb6f835c40447b --- /dev/null +++ b/.github/workflows/pip.yml @@ -0,0 +1,62 @@ +name: Pip +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +jobs: + build: + strategy: + fail-fast: false + matrix: + platform: [ubuntu-latest] + python-version: ["3.9", "3.10"] + + runs-on: ${{ matrix.platform }} + # runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Upgrade setuptools and wheel + run: | + pip install --upgrade setuptools wheel + + - name: Install dependencies on Ubuntu + if: runner.os == 'Linux' + run: | + sudo apt-get update + sudo apt-get install libopencv-dev -y + + - name: Install dependencies on macOS + if: runner.os == 'macOS' + run: | + brew update + brew install opencv + + - name: Install dependencies on Windows + if: runner.os == 'Windows' + run: | + choco install opencv -y + + - name: Add requirements + run: python -m pip install --upgrade wheel setuptools + + - name: Install Python dependencies + run: | + pip install pytest + pip install -r requirements.txt + sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y + + - name: Build and install + run: pip install . 
+ + - name: Test + run: python -m pytest diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000000000000000000000000000000000000..c272ab19fd460c0d2dca3f0d52567a003e86212c --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,95 @@ +name: PyPI Release +on: + release: + types: [published] + +jobs: + build: + strategy: + fail-fast: false + matrix: + platform: [ubuntu-latest] + python-version: ["3.9", "3.10", "3.11"] + + runs-on: ${{ matrix.platform }} + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Upgrade setuptools and wheel + run: | + pip install --upgrade setuptools wheel + + - name: Install dependencies on Ubuntu + if: runner.os == 'Linux' + run: | + sudo apt-get update + sudo apt-get install libopencv-dev -y + + - name: Install dependencies on macOS + if: runner.os == 'macOS' + run: | + brew update + brew install opencv + + - name: Install dependencies on Windows + if: runner.os == 'Windows' + run: | + choco install opencv -y + + - name: Add requirements + run: python -m pip install --upgrade setuptools wheel build + + - name: Install Python dependencies + run: | + pip install pytest + pip install -r requirements.txt + sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y + + - name: Build source distribution + run: | + python -m build --outdir dist/ + ls -lh dist/ + + - name: Upload to GitHub Release + if: matrix.python-version == '3.10' && github.event_name == 'release' + uses: softprops/action-gh-release@v2 + with: + files: dist/*.whl + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Archive wheels + if: matrix.python-version == '3.10' && github.event_name == 'release' + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/*.whl + + + pypi-publish: + name: upload release to PyPI + needs: build + runs-on: ubuntu-latest + environment: pypi + permissions: + # IMPORTANT: this permission is mandatory for Trusted Publishing + id-token: write + steps: + # retrieve your distributions here + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + name: dist + path: dist + + - name: List dist directory + run: ls -lh dist/ + + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6615959c6a79433c573a0855f2f8f1b5fb620199 --- /dev/null +++ b/.gitignore @@ -0,0 +1,36 @@ +build/ + +lib/ +bin/ + +cmake_modules/ +cmake-build-debug/ +.idea/ +.vscode/ +*.pyc +flagged +.ipynb_checkpoints +__pycache__ +Untitled* +experiments +third_party/REKD +hloc/matchers/dedode.py +gradio_cached_examples +*.mp4 +hloc/matchers/quadtree.py +third_party/QuadTreeAttention +desktop.ini +*.egg-info +output.pkl +log.txt +experiments* +gen_example.py +datasets/lines/terrace0.JPG +datasets/lines/terrace1.JPG +datasets/South-Building* +*.pkl +oryx-build-commands.txt +.ruff_cache* +dist +tmp +backup* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a4445ab2065b11dc18e3c820422008c4287b1d1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,88 @@ +# To use: +# +# pre-commit run -a +# +# Or: +# +# pre-commit run --all-files +# +# Or: +# +# pre-commit install # (runs every time you commit in git) +# +# To update this file: +# +# pre-commit autoupdate +# +# See 
https://github.com/pre-commit/pre-commit + +ci: + autoupdate_commit_msg: "chore: update pre-commit hooks" + autofix_commit_msg: "style: pre-commit fixes" + +repos: +# Standard hooks +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-added-large-files + exclude: ^imcui/third_party/ + - id: check-case-conflict + exclude: ^imcui/third_party/ + - id: check-merge-conflict + exclude: ^imcui/third_party/ + - id: check-symlinks + exclude: ^imcui/third_party/ + - id: check-yaml + exclude: ^imcui/third_party/ + - id: debug-statements + exclude: ^imcui/third_party/ + - id: end-of-file-fixer + exclude: ^imcui/third_party/ + - id: mixed-line-ending + exclude: ^imcui/third_party/ + - id: requirements-txt-fixer + exclude: ^imcui/third_party/ + - id: trailing-whitespace + exclude: ^imcui/third_party/ + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.8.4" + hooks: + - id: ruff + args: ["--fix", "--show-fixes", "--extend-ignore=E402"] + - id: ruff-format + exclude: ^(docs|imcui/third_party/) + +# Checking static types +- repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.14.0" + hooks: + - id: mypy + files: "setup.py" + args: [] + additional_dependencies: [types-setuptools] + exclude: ^imcui/third_party/ +# Changes tabs to spaces +- repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.5.5 + hooks: + - id: remove-tabs + exclude: ^(docs|imcui/third_party/) + +# CMake formatting +- repo: https://github.com/cheshirekow/cmake-format-precommit + rev: v0.6.13 + hooks: + - id: cmake-format + additional_dependencies: [pyyaml] + types: [file] + files: (\.cmake|CMakeLists.txt)(.in)?$ + exclude: ^imcui/third_party/ + +# Suggested hook if you add a .clang-format file +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v13.0.0 + hooks: + - id: clang-format + exclude: ^imcui/third_party/ diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..6c419c020eb769944237d2b27260de40ac8a7626 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +alpharealcat@gmail.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. 
+ +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..09fd374039435f2c7a313d252fea4148e413d3f6 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +# Use an official conda-based Python image as a parent image +FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime +LABEL maintainer vincentqyw +ARG PYTHON_VERSION=3.10.10 + +# Set the working directory to /code +WORKDIR /code + +# Install Git and Git LFS +RUN apt-get update && apt-get install -y git-lfs +RUN git lfs install + +# Clone the Git repository +RUN git clone --recursive https://github.com/Vincentqyw/image-matching-webui.git /code + +RUN conda create -n imw python=${PYTHON_VERSION} +RUN echo "source activate imw" > ~/.bashrc +ENV PATH /opt/conda/envs/imw/bin:$PATH + +# Make RUN commands use the new environment +SHELL ["conda", "run", "-n", "imw", "/bin/bash", "-c"] +RUN pip install --upgrade pip +RUN pip install -r requirements.txt +RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y + +# Export port +EXPOSE 7860 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..2f2db59983f1aca800b0a43c2ab260cfc68fa311 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,12 @@ +# logo +include imcui/assets/logo.webp + +recursive-include imcui/ui *.yaml +recursive-include imcui/api *.yaml +recursive-include imcui/third_party *.yaml *.cfg *.yml + +# ui examples +# recursive-include imcui/datasets *.JPG *.jpg *.png + +# model +recursive-include imcui/third_party/SuperGluePretrainedNetwork *.pth diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7abaf502e4022eb38d09a3530c7a7b6c2b525de2 --- /dev/null +++ b/README.md @@ -0,0 +1,194 @@ +[![Contributors][contributors-shield]][contributors-url] +[![Forks][forks-shield]][forks-url] +[![Stargazers][stars-shield]][stars-url] +[![Issues][issues-shield]][issues-url] + +
+# $\color{red}{\textnormal{Image\ Matching\ WebUI}}$
+
+Matching Keypoints between two images
+

+ +## Description + +This simple tool efficiently matches image pairs using multiple famous image matching algorithms. The tool features a Graphical User Interface (GUI) designed using [gradio](https://gradio.app/). You can effortlessly select two images and a matching algorithm and obtain a precise matching result. +**Note**: the images source can be either local images or webcam images. + +Try it on + + Open In Studio + + +Here is a demo of the tool: + +https://github.com/Vincentqyw/image-matching-webui/assets/18531182/263534692-c3484d1b-cc00-4fdc-9b31-e5b7af07ecd9 + +The tool currently supports various popular image matching algorithms, namely: +- [x] [MINIMA](https://github.com/LSXI7/MINIMA), ARXIV 2024 +- [x] [XoFTR](https://github.com/OnderT/XoFTR), CVPR 2024 +- [x] [EfficientLoFTR](https://github.com/zju3dv/EfficientLoFTR), CVPR 2024 +- [x] [MASt3R](https://github.com/naver/mast3r), CVPR 2024 +- [x] [DUSt3R](https://github.com/naver/dust3r), CVPR 2024 +- [x] [OmniGlue](https://github.com/Vincentqyw/omniglue-onnx), CVPR 2024 +- [x] [XFeat](https://github.com/verlab/accelerated_features), CVPR 2024 +- [x] [RoMa](https://github.com/Vincentqyw/RoMa), CVPR 2024 +- [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), 3DV 2024 +- [ ] [Mickey](https://github.com/nianticlabs/mickey), CVPR 2024 +- [x] [GIM](https://github.com/xuelunshen/gim), ICLR 2024 +- [x] [ALIKED](https://github.com/Shiaoming/ALIKED), ICCV 2023 +- [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023 +- [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023 +- [x] [SFD2](https://github.com/feixue94/sfd2), CVPR 2023 +- [x] [IMP](https://github.com/feixue94/imp-release), CVPR 2023 +- [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023 +- [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023 +- [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023 +- [x] [GlueStick](https://github.com/cvg/GlueStick), ICCV 2023 +- [ ] [ConvMatch](https://github.com/SuhZhang/ConvMatch), AAAI 2023 +- [x] [LoFTR](https://github.com/zju3dv/LoFTR), CVPR 2021 +- [x] [SOLD2](https://github.com/cvg/SOLD2), CVPR 2021 +- [ ] [LineTR](https://github.com/yosungho/LineTR), RA-L 2021 +- [x] [DKM](https://github.com/Parskatt/DKM), CVPR 2023 +- [ ] [NCMNet](https://github.com/xinliu29/NCMNet), CVPR 2023 +- [x] [TopicFM](https://github.com/Vincentqyw/TopicFM), AAAI 2023 +- [x] [AspanFormer](https://github.com/Vincentqyw/ml-aspanformer), ECCV 2022 +- [x] [LANet](https://github.com/wangch-g/lanet), ACCV 2022 +- [ ] [LISRD](https://github.com/rpautrat/LISRD), ECCV 2022 +- [ ] [REKD](https://github.com/bluedream1121/REKD), CVPR 2022 +- [x] [CoTR](https://github.com/ubc-vision/COTR), ICCV 2021 +- [x] [ALIKE](https://github.com/Shiaoming/ALIKE), TMM 2022 +- [x] [RoRD](https://github.com/UditSinghParihar/RoRD), IROS 2021 +- [x] [SGMNet](https://github.com/vdvchen/SGMNet), ICCV 2021 +- [x] [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork), CVPRW 2018 +- [x] [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork), CVPR 2020 +- [x] [D2Net](https://github.com/Vincentqyw/d2-net), CVPR 2019 +- [x] [R2D2](https://github.com/naver/r2d2), NeurIPS 2019 +- [x] [DISK](https://github.com/cvlab-epfl/disk), NeurIPS 2020 +- [ ] [Key.Net](https://github.com/axelBarroso/Key.Net), ICCV 2019 +- [ ] [OANet](https://github.com/zjhthu/OANet), ICCV 2019 +- [x] [SOSNet](https://github.com/scape-research/SOSNet), CVPR 2019 +- [x] [HardNet](https://github.com/DagnyT/hardnet), NeurIPS 2017 +- [x] 
[SIFT](https://docs.opencv.org/4.x/da/df5/tutorial_py_sift_intro.html), IJCV 2004 + +## How to use + +### HuggingFace / Lightning AI + +Just try it on + + Open In Studio + + +or deploy it locally following the instructions below. + +### Requirements + +- [Python 3.9+](https://www.python.org/downloads/) + +#### Install from pip [NEW] + +Update: now support install from [pip](https://pypi.org/project/imcui), just run: + +```bash +pip install imcui +``` + +#### Install from source + +``` bash +git clone --recursive https://github.com/Vincentqyw/image-matching-webui.git +cd image-matching-webui +conda env create -f environment.yaml +conda activate imw +pip install -e . +``` + +or using [docker](https://hub.docker.com/r/vincentqin/image-matching-webui): + +``` bash +docker pull vincentqin/image-matching-webui:latest +docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860 +``` + +### Deploy to Railway + +Deploy to [Railway](https://railway.app/), setting up a `Custom Start Command` in `Deploy` section: + +``` bash +python -m imcui.api.server +``` + +### Run demo +``` bash +python app.py --config ./config/config.yaml +``` +then open http://localhost:7860 in your browser. + +![](assets/gui.jpg) + +### Add your own feature / matcher + +I provide an example to add local feature in [imcui/hloc/extractors/example.py](imcui/hloc/extractors/example.py). Then add feature settings in `confs` in file [imcui/hloc/extract_features.py](imcui/hloc/extract_features.py). Last step is adding some settings to `matcher_zoo` in file [imcui/ui/config.yaml](imcui/ui/config.yaml). + +### Upload models + +IMCUI hosts all models on [Huggingface](https://huggingface.co/Realcat/imcui_checkpoints). You can upload your model to Huggingface and add it to the [Realcat/imcui_checkpoints](https://huggingface.co/Realcat/imcui_checkpoints) repository. + + +## Contributions welcome! + +External contributions are very much welcome. Please follow the [PEP8 style guidelines](https://www.python.org/dev/peps/pep-0008/) using a linter like flake8. This is a non-exhaustive list of features that might be valuable additions: + +- [x] support pip install command +- [x] add [CPU CI](.github/workflows/ci.yml) +- [x] add webcam support +- [x] add [line feature matching](https://github.com/Vincentqyw/LineSegmentsDetection) algorithms +- [x] example to add a new feature extractor / matcher +- [x] ransac to filter outliers +- [ ] add [rotation images](https://github.com/pidahbus/deep-image-orientation-angle-detection) options before matching +- [ ] support export matches to colmap ([#issue 6](https://github.com/Vincentqyw/image-matching-webui/issues/6)) +- [x] add config file to set default parameters +- [x] dynamically load models and reduce GPU overload + +Adding local features / matchers as submodules is very easy. For example, to add the [GlueStick](https://github.com/cvg/GlueStick): + +``` bash +git submodule add https://github.com/cvg/GlueStick.git imcui/third_party/GlueStick +``` + +If remote submodule repositories are updated, don't forget to pull submodules with: + +``` bash +git submodule update --init --recursive # init and download +git submodule update --remote # update +``` + +if you only want to update one submodule, use `git submodule update --remote imcui/third_party/GlueStick`. 
+ +To format code before committing, run: + +```bash +pre-commit run -a # Auto-checks and fixes +``` + +## Contributors + + + + + +## Resources +- [Image Matching: Local Features & Beyond](https://image-matching-workshop.github.io) +- [Long-term Visual Localization](https://www.visuallocalization.net) + +## Acknowledgement + +This code is built based on [Hierarchical-Localization](https://github.com/cvg/Hierarchical-Localization). We express our gratitude to the authors for their valuable source code. + +[contributors-shield]: https://img.shields.io/github/contributors/Vincentqyw/image-matching-webui.svg?style=for-the-badge +[contributors-url]: https://github.com/Vincentqyw/image-matching-webui/graphs/contributors +[forks-shield]: https://img.shields.io/github/forks/Vincentqyw/image-matching-webui.svg?style=for-the-badge +[forks-url]: https://github.com/Vincentqyw/image-matching-webui/network/members +[stars-shield]: https://img.shields.io/github/stars/Vincentqyw/image-matching-webui.svg?style=for-the-badge +[stars-url]: https://github.com/Vincentqyw/image-matching-webui/stargazers +[issues-shield]: https://img.shields.io/github/issues/Vincentqyw/image-matching-webui.svg?style=for-the-badge +[issues-url]: https://github.com/Vincentqyw/image-matching-webui/issues diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..9beb1e9bc714791d2bca91019443e72eb9cfd805 --- /dev/null +++ b/app.py @@ -0,0 +1,31 @@ +import argparse +from pathlib import Path +from imcui.ui.app_class import ImageMatchingApp + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--server_name", + type=str, + default="0.0.0.0", + help="server name", + ) + parser.add_argument( + "--server_port", + type=int, + default=7860, + help="server port", + ) + parser.add_argument( + "--config", + type=str, + default=Path(__file__).parent / "config/config.yaml", + help="config file", + ) + args = parser.parse_args() + ImageMatchingApp( + args.server_name, + args.server_port, + config=args.config, + example_data_root=Path("imcui/datasets"), + ).run() diff --git a/assets/demo.gif b/assets/demo.gif new file mode 100644 index 0000000000000000000000000000000000000000..9af81d1ecded321bbd99ac7b84191518d6daf17d --- /dev/null +++ b/assets/demo.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f163c0e2699181897c81c68e01c60fa4289e886a2a40932d53dd529262d3735 +size 8907062 diff --git a/assets/gui.jpg b/assets/gui.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cc91f306d1cbfb261ce4725891e4843618f5f4f6 --- /dev/null +++ b/assets/gui.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a783162639d05631f34e8e3e9a7df682197a76f675265ebbaa639927e08473f7 +size 1669098 diff --git a/assets/logo.webp b/assets/logo.webp new file mode 100644 index 0000000000000000000000000000000000000000..0a799debc1a06cd6e500a8bccd0ddcef7eca0508 Binary files /dev/null and b/assets/logo.webp differ diff --git a/build_docker.sh b/build_docker.sh new file mode 100644 index 0000000000000000000000000000000000000000..a5aea45e6ff5024b71818dea6f4e7cfb0d0ae6c0 --- /dev/null +++ b/build_docker.sh @@ -0,0 +1,3 @@ +docker build -t image-matching-webui:latest . 
--no-cache +docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest +docker push vincentqin/image-matching-webui:latest diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2e1449d99272db6bc5dcf3452fecae87789bdd4b --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,474 @@ +server: + name: "0.0.0.0" + port: 7861 + +defaults: + setting_threshold: 0.1 + max_keypoints: 2000 + keypoint_threshold: 0.05 + enable_ransac: true + ransac_method: CV2_USAC_MAGSAC + ransac_reproj_threshold: 8 + ransac_confidence: 0.999 + ransac_max_iter: 10000 + ransac_num_samples: 4 + match_threshold: 0.2 + setting_geometry: Homography + +matcher_zoo: + minima(loftr): + matcher: minima_loftr + dense: true + info: + name: MINIMA(LoFTR) #dispaly name + source: "ARXIV 2024" + paper: https://arxiv.org/abs/2412.19412 + display: false + minima(RoMa): + matcher: minima_roma + skip_ci: true + dense: true + info: + name: MINIMA(RoMa) #dispaly name + source: "ARXIV 2024" + paper: https://arxiv.org/abs/2412.19412 + display: false + omniglue: + enable: true + matcher: omniglue + dense: true + info: + name: OmniGlue + source: "CVPR 2024" + github: https://github.com/Vincentqyw/omniglue-onnx + paper: https://arxiv.org/abs/2405.12979 + project: https://hwjiang1510.github.io/OmniGlue + display: true + Mast3R: + enable: false + matcher: mast3r + dense: true + info: + name: Mast3R #dispaly name + source: "CVPR 2024" + github: https://github.com/naver/mast3r + paper: https://arxiv.org/abs/2406.09756 + project: https://dust3r.europe.naverlabs.com + display: true + DUSt3R: + # TODO: duster is under development + enable: true + # skip_ci: true + matcher: duster + dense: true + info: + name: DUSt3R #dispaly name + source: "CVPR 2024" + github: https://github.com/naver/dust3r + paper: https://arxiv.org/abs/2312.14132 + project: https://dust3r.europe.naverlabs.com + display: true + GIM(dkm): + enable: true + # skip_ci: true + matcher: gim(dkm) + dense: true + info: + name: GIM(DKM) #dispaly name + source: "ICLR 2024" + github: https://github.com/xuelunshen/gim + paper: https://arxiv.org/abs/2402.11095 + project: https://xuelunshen.com/gim + display: true + RoMa: + matcher: roma + skip_ci: true + dense: true + info: + name: RoMa #dispaly name + source: "CVPR 2024" + github: https://github.com/Parskatt/RoMa + paper: https://arxiv.org/abs/2305.15404 + project: https://parskatt.github.io/RoMa + display: true + dkm: + matcher: dkm + skip_ci: true + dense: true + info: + name: DKM #dispaly name + source: "CVPR 2023" + github: https://github.com/Parskatt/DKM + paper: https://arxiv.org/abs/2202.00667 + project: https://parskatt.github.io/DKM + display: true + loftr: + matcher: loftr + dense: true + info: + name: LoFTR #dispaly name + source: "CVPR 2021" + github: https://github.com/zju3dv/LoFTR + paper: https://arxiv.org/pdf/2104.00680 + project: https://zju3dv.github.io/loftr + display: true + eloftr: + matcher: eloftr + dense: true + info: + name: Efficient LoFTR #dispaly name + source: "CVPR 2024" + github: https://github.com/zju3dv/efficientloftr + paper: https://zju3dv.github.io/efficientloftr/files/EfficientLoFTR.pdf + project: https://zju3dv.github.io/efficientloftr + display: true + xoftr: + matcher: xoftr + dense: true + info: + name: XoFTR #dispaly name + source: "CVPR 2024" + github: https://github.com/OnderT/XoFTR + paper: https://arxiv.org/pdf/2404.09692 + project: null + display: true + cotr: + enable: false + skip_ci: true + matcher: cotr + dense: 
true + info: + name: CoTR #dispaly name + source: "ICCV 2021" + github: https://github.com/ubc-vision/COTR + paper: https://arxiv.org/abs/2103.14167 + project: null + display: true + topicfm: + matcher: topicfm + dense: true + info: + name: TopicFM #dispaly name + source: "AAAI 2023" + github: https://github.com/TruongKhang/TopicFM + paper: https://arxiv.org/abs/2307.00485 + project: null + display: true + aspanformer: + matcher: aspanformer + dense: true + info: + name: ASpanformer #dispaly name + source: "ECCV 2022" + github: https://github.com/Vincentqyw/ml-aspanformer + paper: https://arxiv.org/abs/2208.14201 + project: null + display: true + xfeat+lightglue: + enable: true + matcher: xfeat_lightglue + dense: true + info: + name: xfeat+lightglue + source: "CVPR 2024" + github: https://github.com/Vincentqyw/omniglue-onnx + paper: https://arxiv.org/abs/2405.12979 + project: https://hwjiang1510.github.io/OmniGlue + display: true + xfeat(sparse): + matcher: NN-mutual + feature: xfeat + dense: false + info: + name: XFeat #dispaly name + source: "CVPR 2024" + github: https://github.com/verlab/accelerated_features + paper: https://arxiv.org/abs/2404.19174 + project: null + display: true + xfeat(dense): + matcher: xfeat_dense + dense: true + info: + name: XFeat #dispaly name + source: "CVPR 2024" + github: https://github.com/verlab/accelerated_features + paper: https://arxiv.org/abs/2404.19174 + project: null + display: false + dedode: + matcher: Dual-Softmax + feature: dedode + dense: false + info: + name: DeDoDe #dispaly name + source: "3DV 2024" + github: https://github.com/Parskatt/DeDoDe + paper: https://arxiv.org/abs/2308.08479 + project: null + display: true + superpoint+superglue: + matcher: superglue + feature: superpoint_max + dense: false + info: + name: SuperGlue #dispaly name + source: "CVPR 2020" + github: https://github.com/magicleap/SuperGluePretrainedNetwork + paper: https://arxiv.org/abs/1911.11763 + project: null + display: true + superpoint+lightglue: + matcher: superpoint-lightglue + feature: superpoint_max + dense: false + info: + name: LightGlue #dispaly name + source: "ICCV 2023" + github: https://github.com/cvg/LightGlue + paper: https://arxiv.org/pdf/2306.13643 + project: null + display: true + disk: + matcher: NN-mutual + feature: disk + dense: false + info: + name: DISK + source: "NeurIPS 2020" + github: https://github.com/cvlab-epfl/disk + paper: https://arxiv.org/abs/2006.13566 + project: null + display: true + disk+dualsoftmax: + matcher: Dual-Softmax + feature: disk + dense: false + info: + name: DISK + source: "NeurIPS 2020" + github: https://github.com/cvlab-epfl/disk + paper: https://arxiv.org/abs/2006.13566 + project: null + display: false + superpoint+dualsoftmax: + matcher: Dual-Softmax + feature: superpoint_max + dense: false + info: + name: SuperPoint + source: "CVPRW 2018" + github: https://github.com/magicleap/SuperPointPretrainedNetwork + paper: https://arxiv.org/abs/1712.07629 + project: null + display: false + sift+lightglue: + matcher: sift-lightglue + feature: sift + dense: false + info: + name: LightGlue #dispaly name + source: "ICCV 2023" + github: https://github.com/cvg/LightGlue + paper: https://arxiv.org/pdf/2306.13643 + project: null + display: true + disk+lightglue: + matcher: disk-lightglue + feature: disk + dense: false + info: + name: LightGlue + source: "ICCV 2023" + github: https://github.com/cvg/LightGlue + paper: https://arxiv.org/pdf/2306.13643 + project: null + display: true + aliked+lightglue: + matcher: aliked-lightglue + 
feature: aliked-n16 + dense: false + info: + name: ALIKED + source: "ICCV 2023" + github: https://github.com/Shiaoming/ALIKED + paper: https://arxiv.org/pdf/2304.03608.pdf + project: null + display: true + superpoint+mnn: + matcher: NN-mutual + feature: superpoint_max + dense: false + info: + name: SuperPoint #dispaly name + source: "CVPRW 2018" + github: https://github.com/magicleap/SuperPointPretrainedNetwork + paper: https://arxiv.org/abs/1712.07629 + project: null + display: true + sift+sgmnet: + matcher: sgmnet + feature: sift + dense: false + info: + name: SGMNet #dispaly name + source: "ICCV 2021" + github: https://github.com/vdvchen/SGMNet + paper: https://arxiv.org/abs/2108.08771 + project: null + display: true + sosnet: + matcher: NN-mutual + feature: sosnet + dense: false + info: + name: SOSNet #dispaly name + source: "CVPR 2019" + github: https://github.com/scape-research/SOSNet + paper: https://arxiv.org/abs/1904.05019 + project: https://research.scape.io/sosnet + display: true + hardnet: + matcher: NN-mutual + feature: hardnet + dense: false + info: + name: HardNet #dispaly name + source: "NeurIPS 2017" + github: https://github.com/DagnyT/hardnet + paper: https://arxiv.org/abs/1705.10872 + project: null + display: true + d2net: + matcher: NN-mutual + feature: d2net-ss + dense: false + info: + name: D2Net #dispaly name + source: "CVPR 2019" + github: https://github.com/Vincentqyw/d2-net + paper: https://arxiv.org/abs/1905.03561 + project: https://dusmanu.com/publications/d2-net.html + display: true + rord: + matcher: NN-mutual + feature: rord + dense: false + info: + name: RoRD #dispaly name + source: "IROS 2021" + github: https://github.com/UditSinghParihar/RoRD + paper: https://arxiv.org/abs/2103.08573 + project: https://uditsinghparihar.github.io/RoRD + display: true + alike: + matcher: NN-mutual + feature: alike + dense: false + info: + name: ALIKE #dispaly name + source: "TMM 2022" + github: https://github.com/Shiaoming/ALIKE + paper: https://arxiv.org/abs/2112.02906 + project: null + display: true + lanet: + matcher: NN-mutual + feature: lanet + dense: false + info: + name: LANet #dispaly name + source: "ACCV 2022" + github: https://github.com/wangch-g/lanet + paper: https://openaccess.thecvf.com/content/ACCV2022/papers/Wang_Rethinking_Low-level_Features_for_Interest_Point_Detection_and_Description_ACCV_2022_paper.pdf + project: null + display: true + r2d2: + matcher: NN-mutual + feature: r2d2 + dense: false + info: + name: R2D2 #dispaly name + source: "NeurIPS 2019" + github: https://github.com/naver/r2d2 + paper: https://arxiv.org/abs/1906.06195 + project: null + display: true + darkfeat: + matcher: NN-mutual + feature: darkfeat + dense: false + info: + name: DarkFeat #dispaly name + source: "AAAI 2023" + github: https://github.com/THU-LYJ-Lab/DarkFeat + paper: null + project: null + display: true + sift: + matcher: NN-mutual + feature: sift + dense: false + info: + name: SIFT #dispaly name + source: "IJCV 2004" + github: null + paper: https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf + project: null + display: true + gluestick: + enable: false + matcher: gluestick + dense: true + info: + name: GlueStick #dispaly name + source: "ICCV 2023" + github: https://github.com/cvg/GlueStick + paper: https://arxiv.org/abs/2304.02008 + project: https://iago-suarez.com/gluestick + display: true + sold2: + enable: false + matcher: sold2 + dense: true + info: + name: SOLD2 #dispaly name + source: "CVPR 2021" + github: https://github.com/cvg/SOLD2 + paper: 
https://arxiv.org/abs/2104.03362 + project: null + display: true + + sfd2+imp: + enable: true + matcher: imp + feature: sfd2 + dense: false + info: + name: SFD2+IMP #dispaly name + source: "CVPR 2023" + github: https://github.com/feixue94/imp-release + paper: https://arxiv.org/pdf/2304.14837 + project: https://feixue94.github.io/ + display: true + + sfd2+mnn: + enable: true + matcher: NN-mutual + feature: sfd2 + dense: false + info: + name: SFD2+MNN #dispaly name + source: "CVPR 2023" + github: https://github.com/feixue94/sfd2 + paper: https://arxiv.org/abs/2304.14845 + project: https://feixue94.github.io/ + display: true + +retrieval_zoo: + netvlad: + enable: true + openibl: + enable: true + cosplace: + enable: true diff --git a/docker/build_docker.bat b/docker/build_docker.bat new file mode 100644 index 0000000000000000000000000000000000000000..9f3fc687e1185de2866a1dbe221599549abdbce8 --- /dev/null +++ b/docker/build_docker.bat @@ -0,0 +1,3 @@ +docker build -t image-matching-webui:latest . --no-cache +# docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest +# docker push vincentqin/image-matching-webui:latest diff --git a/docker/run_docker.bat b/docker/run_docker.bat new file mode 100644 index 0000000000000000000000000000000000000000..da7686293c14465f0899c4b022f89fcc03db93b3 --- /dev/null +++ b/docker/run_docker.bat @@ -0,0 +1 @@ +docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860 diff --git a/docker/run_docker.sh b/docker/run_docker.sh new file mode 100644 index 0000000000000000000000000000000000000000..da7686293c14465f0899c4b022f89fcc03db93b3 --- /dev/null +++ b/docker/run_docker.sh @@ -0,0 +1 @@ +docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860 diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aab94e3a4a1e8e4b5292e2a7767c7e916e3b8e2f --- /dev/null +++ b/environment.yaml @@ -0,0 +1,13 @@ +name: imw +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - python=3.10.10 + - pytorch-cuda=12.1 + - pytorch=2.4.0 + - pip + - pip: + - -r requirements.txt diff --git a/imcui/__init__.py b/imcui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/api/__init__.py b/imcui/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..251563d95841af0b77ef47f92f0c7d7c78a90bb8 --- /dev/null +++ b/imcui/api/__init__.py @@ -0,0 +1,47 @@ +import base64 +import io +from typing import List + +import numpy as np +from fastapi.exceptions import HTTPException +from PIL import Image +from pydantic import BaseModel + +from ..hloc import logger +from .core import ImageMatchingAPI + + +class ImagesInput(BaseModel): + data: List[str] = [] + max_keypoints: List[int] = [] + timestamps: List[str] = [] + grayscale: bool = False + image_hw: List[List[int]] = [[], []] + feature_type: int = 0 + rotates: List[float] = [] + scales: List[float] = [] + reference_points: List[List[float]] = [] + binarize: bool = False + + +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + try: + image = Image.open(io.BytesIO(base64.b64decode(encoding))) + return image + except Exception as e: + logger.warning(f"API cannot decode image: {e}") + raise HTTPException(status_code=500, 
detail="Invalid encoded image") from e + + +def to_base64_nparray(encoding: str) -> np.ndarray: + return np.array(decode_base64_to_image(encoding)).astype("uint8") + + +__all__ = [ + "ImageMatchingAPI", + "ImagesInput", + "decode_base64_to_image", + "to_base64_nparray", +] diff --git a/imcui/api/client.py b/imcui/api/client.py new file mode 100644 index 0000000000000000000000000000000000000000..2e5130bf8452e904d6c999de5dcc932fb216dffa --- /dev/null +++ b/imcui/api/client.py @@ -0,0 +1,232 @@ +import argparse +import base64 +import os +import pickle +import time +from typing import Dict, List + +import cv2 +import numpy as np +import requests + +ENDPOINT = "http://127.0.0.1:8001" +if "REMOTE_URL_RAILWAY" in os.environ: + ENDPOINT = os.environ["REMOTE_URL_RAILWAY"] + +print(f"API ENDPOINT: {ENDPOINT}") + +API_VERSION = f"{ENDPOINT}/version" +API_URL_MATCH = f"{ENDPOINT}/v1/match" +API_URL_EXTRACT = f"{ENDPOINT}/v1/extract" + + +def read_image(path: str) -> str: + """ + Read an image from a file, encode it as a JPEG and then as a base64 string. + + Args: + path (str): The path to the image to read. + + Returns: + str: The base64 encoded image. + """ + # Read the image from the file + img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + + # Encode the image as a png, NO COMPRESSION!!! + retval, buffer = cv2.imencode(".png", img) + + # Encode the JPEG as a base64 string + b64img = base64.b64encode(buffer).decode("utf-8") + + return b64img + + +def do_api_requests(url=API_URL_EXTRACT, **kwargs): + """ + Helper function to send an API request to the image matching service. + + Args: + url (str): The URL of the API endpoint to use. Defaults to the + feature extraction endpoint. + **kwargs: Additional keyword arguments to pass to the API. + + Returns: + List[Dict[str, np.ndarray]]: A list of dictionaries containing the + extracted features. The keys are "keypoints", "descriptors", and + "scores", and the values are ndarrays of shape (N, 2), (N, ?), + and (N,), respectively. + """ + # Set up the request body + reqbody = { + # List of image data base64 encoded + "data": [], + # List of maximum number of keypoints to extract from each image + "max_keypoints": [100, 100], + # List of timestamps for each image (not used?) + "timestamps": ["0", "1"], + # Whether to convert the images to grayscale + "grayscale": 0, + # List of image height and width + "image_hw": [[640, 480], [320, 240]], + # Type of feature to extract + "feature_type": 0, + # List of rotation angles for each image + "rotates": [0.0, 0.0], + # List of scale factors for each image + "scales": [1.0, 1.0], + # List of reference points for each image (not used) + "reference_points": [[640, 480], [320, 240]], + # Whether to binarize the descriptors + "binarize": True, + } + # Update the request body with the additional keyword arguments + reqbody.update(kwargs) + try: + # Send the request + r = requests.post(url, json=reqbody) + if r.status_code == 200: + # Return the response + return r.json() + else: + # Print an error message if the response code is not 200 + print(f"Error: Response code {r.status_code} - {r.text}") + except Exception as e: + # Print an error message if an exception occurs + print(f"An error occurred: {e}") + + +def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]: + """ + Send a request to the API to generate a match between two images. + + Args: + path0 (str): The path to the first image. + path1 (str): The path to the second image. 
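A short usage sketch for the helpers defined above (`read_image` plus `do_api_requests`); it assumes the extraction service from `imcui/api/server.py` is reachable at `API_URL_EXTRACT` and that a local `demo.jpg` exists.

```python
# Encode a local image and request feature extraction from the running service.
b64img = read_image("demo.jpg")            # grayscale PNG, base64-encoded
response = do_api_requests(
    url=API_URL_EXTRACT,
    data=[b64img],
    max_keypoints=[512],
    binarize=False,
)
if response is not None:
    keypoints = np.array(response[0]["keypoints"])
    print("keypoints:", keypoints.shape)
```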
+ + Returns: + Dict[str, np.ndarray]: A dictionary of match results as returned by the /v1/match endpoint, + e.g. "keypoints0_orig"/"keypoints1_orig", "mkeypoints0_orig"/"mkeypoints1_orig" (raw matches) + and "mmkeypoints0_orig"/"mmkeypoints1_orig" (RANSAC inliers), each converted to an ndarray. + """ + files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")} + try: + # TODO: replace files with post json + response = requests.post(API_URL_MATCH, files=files) + pred = {} + if response.status_code == 200: + pred = response.json() + for key in list(pred.keys()): + pred[key] = np.array(pred[key]) + else: + print(f"Error: Response code {response.status_code} - {response.text}") + finally: + files["image0"].close() + files["image1"].close() + return pred + + +def send_request_extract( + input_images: str, viz: bool = False +) -> List[Dict[str, np.ndarray]]: + """ + Send a request to the API to extract features from an image. + + Args: + input_images (str): The path to the image. + + Returns: + List[Dict[str, np.ndarray]]: A list of dictionaries containing the + extracted features. The keys are "keypoints", "descriptors", and + "scores", and the values are ndarrays of shape (N, 2), (N, D), + and (N,), respectively. + """ + image_data = read_image(input_images) + inputs = { + "data": [image_data], + } + response = do_api_requests( + url=API_URL_EXTRACT, + **inputs, + ) + # breakpoint() + # print("Keypoints detected: {}".format(len(response[0]["keypoints"]))) + + # draw matching, debug only + if viz: + from hloc.utils.viz import plot_keypoints + from ui.viz import fig2im, plot_images + + kpts = np.array(response[0]["keypoints_orig"]) + if "image_orig" in response[0].keys(): + img_orig = np.array(response[0]["image_orig"]) + + output_keypoints = plot_images([img_orig], titles="titles", dpi=300) + plot_keypoints([kpts]) + output_keypoints = fig2im(output_keypoints) + cv2.imwrite( + "demo_match.jpg", + output_keypoints[:, :, ::-1].copy(), # RGB -> BGR + ) + return response + + +def get_api_version(): + try: + response = requests.get(API_VERSION).json() + print("API VERSION: {}".format(response["version"])) + except Exception as e: + print(f"An error occurred: {e}") + + +if __name__ == "__main__": + from pathlib import Path + + parser = argparse.ArgumentParser( + description="Send images to the local image matching API and report extraction timings."
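As a complement to the `viz` branch above, here is a sketch of drawing the extracted keypoints with plain OpenCV; the image path is an assumption, it presumes the request succeeds, and `cv2`/`np` are already imported at the top of this file.

```python
image_path = "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"  # assumed path
preds = send_request_extract(image_path)
img = cv2.imread(image_path)
# keypoints_orig holds [x, y] pairs in original image coordinates.
kpts = [cv2.KeyPoint(float(x), float(y), 1.0) for x, y in preds[0]["keypoints_orig"]]
vis = cv2.drawKeypoints(img, kpts, None, color=(0, 0, 255))
cv2.imwrite("keypoints_vis.jpg", vis)
```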
) + parser.add_argument( + "--image0", + required=False, + help="Path to the first image", + default=str( + Path(__file__).parents[1] + / "datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg" + ), + ) + parser.add_argument( + "--image1", + required=False, + help="Path to the second image", + default=str( + Path(__file__).parents[1] + / "datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg" + ), + ) + args = parser.parse_args() + + # get api version + get_api_version() + + # request match + # for i in range(10): + # t1 = time.time() + # preds = send_request_match(args.image0, args.image1) + # t2 = time.time() + # print( + # "Time cost1: {} seconds, matched: {}".format( + # (t2 - t1), len(preds["mmkeypoints0_orig"]) + # ) + # ) + + # request extract + for i in range(1000): + t1 = time.time() + preds = send_request_extract(args.image0) + t2 = time.time() + print(f"Time cost2: {(t2 - t1)} seconds") + + # dump preds + with open("preds.pkl", "wb") as f: + pickle.dump(preds, f) diff --git a/imcui/api/config/api.yaml b/imcui/api/config/api.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4f4be347bc713ed03ff2252be5b95096bbe9e43 --- /dev/null +++ b/imcui/api/config/api.yaml @@ -0,0 +1,51 @@ +# This file was generated using the `serve build` command on Ray v2.38.0. + +proxy_location: EveryNode +http_options: + host: 0.0.0.0 + port: 8001 + +grpc_options: + port: 9000 + grpc_servicer_functions: [] + +logging_config: + encoding: TEXT + log_level: INFO + logs_dir: null + enable_access_log: true + +applications: +- name: app1 + route_prefix: / + import_path: api.server:service + runtime_env: {} + deployments: + - name: ImageMatchingService + num_replicas: 4 + ray_actor_options: + num_cpus: 2.0 + num_gpus: 1.0 + +api: + feature: + output: feats-superpoint-n4096-rmax1600 + model: + name: superpoint + nms_radius: 3 + max_keypoints: 4096 + keypoint_threshold: 0.005 + preprocessing: + grayscale: True + force_resize: True + resize_max: 1600 + width: 640 + height: 480 + dfactor: 8 + matcher: + output: matches-NN-mutual + model: + name: nearest_neighbor + do_mutual_check: True + match_threshold: 0.2 + dense: False diff --git a/imcui/api/core.py b/imcui/api/core.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d58d08662d853d5fdd29baee5f8a349b61d369 --- /dev/null +++ b/imcui/api/core.py @@ -0,0 +1,308 @@ +# api.py +import warnings +from pathlib import Path +from typing import Any, Dict, Optional + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch + +from ..hloc import extract_features, logger, match_dense, match_features +from ..hloc.utils.viz import add_text, plot_keypoints +from ..ui.utils import filter_matches, get_feature_model, get_model +from ..ui.viz import display_matches, fig2im, plot_images + +warnings.simplefilter("ignore") + + +class ImageMatchingAPI(torch.nn.Module): + default_conf = { + "ransac": { + "enable": True, + "estimator": "poselib", + "geometry": "homography", + "method": "RANSAC", + "reproj_threshold": 3, + "confidence": 0.9999, + "max_iter": 10000, + }, + } + + def __init__( + self, + conf: dict = {}, + device: str = "cpu", + detect_threshold: float = 0.015, + max_keypoints: int = 1024, + match_threshold: float = 0.2, + ) -> None: + """ + Initializes an instance of the ImageMatchingAPI class. + + Args: + conf (dict): A dictionary containing the configuration parameters. + device (str, optional): The device to use for computation. Defaults to "cpu".
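To connect the YAML config above with the Python API that follows, here is a minimal sketch of loading such a config and handing it to `ImageMatchingAPI`, mirroring what `imcui/api/server.py` does further below; the config path and device are assumptions.

```python
from pathlib import Path

import yaml

from imcui.api import ImageMatchingAPI

with open(Path("imcui/api/config/api.yaml")) as f:
    conf = yaml.safe_load(f)

# conf["api"] holds the feature/matcher/dense section shown above.
api = ImageMatchingAPI(conf=conf["api"], device="cpu", max_keypoints=2048)
```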
+ detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015. + max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024. + match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2. + + Returns: + None + """ + super().__init__() + self.device = device + self.conf = {**self.default_conf, **conf} + self._updata_config(detect_threshold, max_keypoints, match_threshold) + self._init_models() + if device == "cuda": + memory_allocated = torch.cuda.memory_allocated(device) + memory_reserved = torch.cuda.memory_reserved(device) + logger.info(f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB") + logger.info(f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB") + self.pred = None + + def parse_match_config(self, conf): + if conf["dense"]: + return { + **conf, + "matcher": match_dense.confs.get(conf["matcher"]["model"]["name"]), + "dense": True, + } + else: + return { + **conf, + "feature": extract_features.confs.get(conf["feature"]["model"]["name"]), + "matcher": match_features.confs.get(conf["matcher"]["model"]["name"]), + "dense": False, + } + + def _updata_config( + self, + detect_threshold: float = 0.015, + max_keypoints: int = 1024, + match_threshold: float = 0.2, + ): + self.dense = self.conf["dense"] + if self.conf["dense"]: + try: + self.conf["matcher"]["model"]["match_threshold"] = match_threshold + except TypeError as e: + logger.error(e) + else: + self.conf["feature"]["model"]["max_keypoints"] = max_keypoints + self.conf["feature"]["model"]["keypoint_threshold"] = detect_threshold + self.extract_conf = self.conf["feature"] + + self.match_conf = self.conf["matcher"] + + def _init_models(self): + # initialize matcher + self.matcher = get_model(self.match_conf) + # initialize extractor + if self.dense: + self.extractor = None + else: + self.extractor = get_feature_model(self.conf["feature"]) + + def _forward(self, img0, img1): + if self.dense: + pred = match_dense.match_images( + self.matcher, + img0, + img1, + self.match_conf["preprocessing"], + device=self.device, + ) + last_fixed = "{}".format( # noqa: F841 + self.match_conf["model"]["name"] + ) + else: + pred0 = extract_features.extract( + self.extractor, img0, self.extract_conf["preprocessing"] + ) + pred1 = extract_features.extract( + self.extractor, img1, self.extract_conf["preprocessing"] + ) + pred = match_features.match_images(self.matcher, pred0, pred1) + return pred + + def _convert_pred(self, pred): + ret = { + k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v + for k, v in pred.items() + } + ret = { + k: v[0].cpu().detach().numpy() if isinstance(v, list) else v + for k, v in ret.items() + } + return ret + + @torch.inference_mode() + def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]: + """Extract features from a single image. 
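A sketch of single-image extraction with this method, assuming an `ImageMatchingAPI` instance `api` built as in the earlier config sketch and an existing `demo.jpg`; the [0, 1] RGB convention follows the `forward` docstring below, and the exact preprocessing expectations may differ per extractor.

```python
import cv2

img = cv2.imread("demo.jpg")                           # assumed path
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0     # HxWx3, RGB, values in [0, 1]
feats = api.extract(img, max_keypoints=1024, binarize=False)
# keypoints_orig are rescaled back to the original image resolution.
print(feats["keypoints_orig"].shape, feats["descriptors"].shape)
```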
+ + Args: + img0 (np.ndarray): image + + Returns: + Dict[str, np.ndarray]: feature dict + """ + + # setting prams + self.extractor.conf["max_keypoints"] = kwargs.get("max_keypoints", 512) + self.extractor.conf["keypoint_threshold"] = kwargs.get( + "keypoint_threshold", 0.0 + ) + + pred = extract_features.extract( + self.extractor, img0, self.extract_conf["preprocessing"] + ) + pred = self._convert_pred(pred) + # back to origin scale + s0 = pred["original_size"] / pred["size"] + pred["keypoints_orig"] = ( + match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5 + ) + # TODO: rotate back + binarize = kwargs.get("binarize", False) + if binarize: + assert "descriptors" in pred + pred["descriptors"] = (pred["descriptors"] > 0).astype(np.uint8) + pred["descriptors"] = pred["descriptors"].T # N x DIM + return pred + + @torch.inference_mode() + def forward( + self, + img0: np.ndarray, + img1: np.ndarray, + ) -> Dict[str, np.ndarray]: + """ + Forward pass of the image matching API. + + Args: + img0: A 3D NumPy array of shape (H, W, C) representing the first image. + Values are in the range [0, 1] and are in RGB mode. + img1: A 3D NumPy array of shape (H, W, C) representing the second image. + Values are in the range [0, 1] and are in RGB mode. + + Returns: + A dictionary containing the following keys: + - image0_orig: The original image 0. + - image1_orig: The original image 1. + - keypoints0_orig: The keypoints detected in image 0. + - keypoints1_orig: The keypoints detected in image 1. + - mkeypoints0_orig: The raw matches between image 0 and image 1. + - mkeypoints1_orig: The raw matches between image 1 and image 0. + - mmkeypoints0_orig: The RANSAC inliers in image 0. + - mmkeypoints1_orig: The RANSAC inliers in image 1. + - mconf: The confidence scores for the raw matches. + - mmconf: The confidence scores for the RANSAC inliers. + """ + # Take as input a pair of images (not a batch) + assert isinstance(img0, np.ndarray) + assert isinstance(img1, np.ndarray) + self.pred = self._forward(img0, img1) + if self.conf["ransac"]["enable"]: + self.pred = self._geometry_check(self.pred) + return self.pred + + def _geometry_check( + self, + pred: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Filter matches using RANSAC. If keypoints are available, filter by keypoints. + If lines are available, filter by lines. If both keypoints and lines are + available, filter by keypoints. + + Args: + pred (Dict[str, Any]): dict of matches, including original keypoints. + See :func:`filter_matches` for the expected keys. + + Returns: + Dict[str, Any]: filtered matches + """ + pred = filter_matches( + pred, + ransac_method=self.conf["ransac"]["method"], + ransac_reproj_threshold=self.conf["ransac"]["reproj_threshold"], + ransac_confidence=self.conf["ransac"]["confidence"], + ransac_max_iter=self.conf["ransac"]["max_iter"], + ) + return pred + + def visualize( + self, + log_path: Optional[Path] = None, + ) -> None: + """ + Visualize the matches. + + Args: + log_path (Path, optional): The directory to save the images. Defaults to None. 
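And a sketch of the pair-matching path (`forward` plus `visualize`), again assuming an `api` instance as above and two demo images that exist on disk.

```python
from pathlib import Path

import cv2


def load_rgb(path: str):
    bgr = cv2.imread(path)
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255.0  # [0, 1] RGB as documented


img0 = load_rgb("demo0.jpg")   # assumed paths
img1 = load_rgb("demo1.jpg")

pred = api(img0, img1)         # nn.Module __call__ -> forward(): extract, match, RANSAC
print("RANSAC inliers:", len(pred["mmkeypoints0_orig"]))

out_dir = Path("outputs")
out_dir.mkdir(exist_ok=True)
api.visualize(log_path=out_dir)  # writes keypoint, raw-match and RANSAC-match images
```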
+ + Returns: + None + """ + if self.conf["dense"]: + postfix = str(self.conf["matcher"]["model"]["name"]) + else: + postfix = "{}_{}".format( + str(self.conf["feature"]["model"]["name"]), + str(self.conf["matcher"]["model"]["name"]), + ) + titles = [ + "Image 0 - Keypoints", + "Image 1 - Keypoints", + ] + pred: Dict[str, Any] = self.pred + image0: np.ndarray = pred["image0_orig"] + image1: np.ndarray = pred["image1_orig"] + output_keypoints: np.ndarray = plot_images( + [image0, image1], titles=titles, dpi=300 + ) + if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys(): + plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]]) + text: str = ( + f"# keypoints0: {len(pred['keypoints0_orig'])} \n" + + f"# keypoints1: {len(pred['keypoints1_orig'])}" + ) + add_text(0, text, fs=15) + output_keypoints = fig2im(output_keypoints) + # plot images with raw matches + titles = [ + "Image 0 - Raw matched keypoints", + "Image 1 - Raw matched keypoints", + ] + output_matches_raw, num_matches_raw = display_matches( + pred, titles=titles, tag="KPTS_RAW" + ) + # plot images with ransac matches + titles = [ + "Image 0 - Ransac matched keypoints", + "Image 1 - Ransac matched keypoints", + ] + output_matches_ransac, num_matches_ransac = display_matches( + pred, titles=titles, tag="KPTS_RANSAC" + ) + if log_path is not None: + img_keypoints_path: Path = log_path / f"img_keypoints_{postfix}.png" + img_matches_raw_path: Path = log_path / f"img_matches_raw_{postfix}.png" + img_matches_ransac_path: Path = ( + log_path / f"img_matches_ransac_{postfix}.png" + ) + cv2.imwrite( + str(img_keypoints_path), + output_keypoints[:, :, ::-1].copy(), # RGB -> BGR + ) + cv2.imwrite( + str(img_matches_raw_path), + output_matches_raw[:, :, ::-1].copy(), # RGB -> BGR + ) + cv2.imwrite( + str(img_matches_ransac_path), + output_matches_ransac[:, :, ::-1].copy(), # RGB -> BGR + ) + plt.close("all") diff --git a/imcui/api/server.py b/imcui/api/server.py new file mode 100644 index 0000000000000000000000000000000000000000..5d1932a639374fdf0533caa0e68a2c109ce1a07b --- /dev/null +++ b/imcui/api/server.py @@ -0,0 +1,170 @@ +# server.py +import warnings +from pathlib import Path +from typing import Union + +import numpy as np +import ray +import torch +import yaml +from fastapi import FastAPI, File, UploadFile +from fastapi.responses import JSONResponse +from PIL import Image +from ray import serve + +from . import ImagesInput, to_base64_nparray +from .core import ImageMatchingAPI +from ..hloc import DEVICE +from ..ui import get_version + +warnings.simplefilter("ignore") +app = FastAPI() +if ray.is_initialized(): + ray.shutdown() +ray.init( + dashboard_port=8265, + ignore_reinit_error=True, +) +serve.start( + http_options={"host": "0.0.0.0", "port": 8001}, +) + +num_gpus = 1 if torch.cuda.is_available() else 0 + + +@serve.deployment( + num_replicas=4, ray_actor_options={"num_cpus": 2, "num_gpus": num_gpus} +) +@serve.ingress(app) +class ImageMatchingService: + def __init__(self, conf: dict, device: str): + self.conf = conf + self.api = ImageMatchingAPI(conf=conf, device=device) + + @app.get("/") + def root(self): + return "Hello, world!" + + @app.get("/version") + async def version(self): + return {"version": get_version()} + + @app.post("/v1/match") + async def match( + self, image0: UploadFile = File(...), image1: UploadFile = File(...) + ): + """ + Handle the image matching request and return the processed result. + + Args: + image0 (UploadFile): The first image file for matching. 
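For completeness, a client-side sketch of calling the `/v1/match` route defined here with plain `requests`; the host, port and file names are assumptions (the port matches `api.yaml` above).

```python
import numpy as np
import requests

with open("img0.jpg", "rb") as f0, open("img1.jpg", "rb") as f1:
    r = requests.post(
        "http://127.0.0.1:8001/v1/match",
        files={"image0": f0, "image1": f1},
    )
r.raise_for_status()
# The endpoint returns nested lists; convert them back to arrays.
pred = {k: np.array(v) for k, v in r.json().items()}
print("RANSAC inliers:", len(pred.get("mmkeypoints0_orig", [])))
```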
+ image1 (UploadFile): The second image file for matching. + + Returns: + JSONResponse: A JSON response containing the filtered match results + or an error message in case of failure. + """ + try: + # Load the images from the uploaded files + image0_array = self.load_image(image0) + image1_array = self.load_image(image1) + + # Perform image matching using the API + output = self.api(image0_array, image1_array) + + # Keys to skip in the output + skip_keys = ["image0_orig", "image1_orig"] + + # Postprocess the output to filter unwanted data + pred = self.postprocess(output, skip_keys) + + # Return the filtered prediction as a JSON response + return JSONResponse(content=pred) + except Exception as e: + # Return an error message with status code 500 in case of exception + return JSONResponse(content={"error": str(e)}, status_code=500) + + @app.post("/v1/extract") + async def extract(self, input_info: ImagesInput): + """ + Extract keypoints and descriptors from images. + + Args: + input_info: An object containing the image data and options. + + Returns: + A list of dictionaries containing the keypoints and descriptors. + """ + try: + preds = [] + for i, input_image in enumerate(input_info.data): + # Load the image from the input data + image_array = to_base64_nparray(input_image) + # Extract keypoints and descriptors + output = self.api.extract( + image_array, + max_keypoints=input_info.max_keypoints[i], + binarize=input_info.binarize, + ) + # Do not return the original image and image_orig + # skip_keys = ["image", "image_orig"] + skip_keys = [] + + # Postprocess the output + pred = self.postprocess(output, skip_keys) + preds.append(pred) + # Return the list of extracted features + return JSONResponse(content=preds) + except Exception as e: + # Return an error message if an exception occurs + return JSONResponse(content={"error": str(e)}, status_code=500) + + def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray: + """ + Reads an image from a file path or an UploadFile object. + + Args: + file_path: A file path or an UploadFile object. + + Returns: + A numpy array representing the image. 
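The `postprocess` helper used above (defined just below) only has to make the model outputs JSON-serializable; a tiny sketch of that round trip, with purely illustrative keys and shapes:

```python
import numpy as np

output = {"keypoints": np.random.rand(4, 2), "scores": np.random.rand(4)}
# Server side: ndarray values become nested lists for JSONResponse.
payload = {k: v.tolist() for k, v in output.items() if isinstance(v, np.ndarray)}
# Client side: convert them back to arrays.
restored = {k: np.array(v) for k, v in payload.items()}
assert restored["keypoints"].shape == (4, 2)
```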
+ """ + if isinstance(file_path, str): + file_path = Path(file_path).resolve(strict=False) + else: + file_path = file_path.file + with Image.open(file_path) as img: + image_array = np.array(img) + return image_array + + def postprocess(self, output: dict, skip_keys: list, binarize: bool = True) -> dict: + pred = {} + for key, value in output.items(): + if key in skip_keys: + continue + if isinstance(value, np.ndarray): + pred[key] = value.tolist() + return pred + + def run(self, host: str = "0.0.0.0", port: int = 8001): + import uvicorn + + uvicorn.run(app, host=host, port=port) + + +def read_config(config_path: Path) -> dict: + with open(config_path, "r") as f: + conf = yaml.safe_load(f) + return conf + + +# api server +conf = read_config(Path(__file__).parent / "config/api.yaml") +service = ImageMatchingService.bind(conf=conf["api"], device=DEVICE) +handle = serve.run(service, route_prefix="/") + +# serve run api.server_ray:service + +# build to generate config file +# serve build api.server_ray:service -o api/config/ray.yaml +# serve run api/config/ray.yaml diff --git a/imcui/api/test/CMakeLists.txt b/imcui/api/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1da6c924042e615ebfa51e4e55de1dcaaddeff8b --- /dev/null +++ b/imcui/api/test/CMakeLists.txt @@ -0,0 +1,17 @@ +cmake_minimum_required(VERSION 3.10) +project(imatchui) + +set(OpenCV_DIR /usr/include/opencv4) +find_package(OpenCV REQUIRED) + +find_package(Boost REQUIRED COMPONENTS system) +if(Boost_FOUND) + include_directories(${Boost_INCLUDE_DIRS}) +endif() + +add_executable(client client.cpp) + +target_include_directories(client PRIVATE ${Boost_LIBRARIES} + ${OpenCV_INCLUDE_DIRS}) + +target_link_libraries(client PRIVATE curl jsoncpp b64 ${OpenCV_LIBS}) diff --git a/imcui/api/test/build_and_run.sh b/imcui/api/test/build_and_run.sh new file mode 100644 index 0000000000000000000000000000000000000000..40921bb9b925c67722247df7ab901668d713e888 --- /dev/null +++ b/imcui/api/test/build_and_run.sh @@ -0,0 +1,16 @@ +# g++ main.cpp -I/usr/include/opencv4 -lcurl -ljsoncpp -lb64 -lopencv_core -lopencv_imgcodecs -o main +# sudo apt-get update +# sudo apt-get install libboost-all-dev -y +# sudo apt-get install libcurl4-openssl-dev libjsoncpp-dev libb64-dev libopencv-dev -y + +cd build +cmake .. +make -j12 + +echo " ======== RUN DEMO ========" + +./client + +echo " ======== END DEMO ========" + +cd .. diff --git a/imcui/api/test/client.cpp b/imcui/api/test/client.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b970457fbbef8a81e16856d4e5eeadd1e70715c8 --- /dev/null +++ b/imcui/api/test/client.cpp @@ -0,0 +1,81 @@ +#include +#include +#include "helper.h" + +int main() { + std::string img_path = + "../../../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg"; + cv::Mat original_img = cv::imread(img_path, cv::IMREAD_GRAYSCALE); + + if (original_img.empty()) { + throw std::runtime_error("Failed to decode image"); + } + + // Convert the image to Base64 + std::string base64_img = image_to_base64(original_img); + + // Convert the Base64 back to an image + cv::Mat decoded_img = base64_to_image(base64_img); + cv::imwrite("decoded_image.jpg", decoded_img); + cv::imwrite("original_img.jpg", original_img); + + // The images should be identical + if (cv::countNonZero(original_img != decoded_img) != 0) { + std::cerr << "The images are not identical" << std::endl; + return -1; + } else { + std::cout << "The images are identical!" 
<< std::endl; + } + + // construct params + APIParams params{.data = {base64_img}, + .max_keypoints = {100, 100}, + .timestamps = {"0", "1"}, + .grayscale = {0}, + .image_hw = {{480, 640}, {240, 320}}, + .feature_type = 0, + .rotates = {0.0f, 0.0f}, + .scales = {1.0f, 1.0f}, + .reference_points = {{1.23e+2f, 1.2e+1f}, + {5.0e-1f, 3.0e-1f}, + {2.3e+2f, 2.2e+1f}, + {6.0e-1f, 4.0e-1f}}, + .binarize = {1}}; + + KeyPointResults kpts_results; + + // Convert the parameters to JSON + Json::Value jsonData = paramsToJson(params); + std::string url = "http://127.0.0.1:8001/v1/extract"; + Json::StreamWriterBuilder writer; + std::string output = Json::writeString(writer, jsonData); + + CURL* curl; + CURLcode res; + std::string readBuffer; + + curl_global_init(CURL_GLOBAL_DEFAULT); + curl = curl_easy_init(); + if (curl) { + struct curl_slist* hs = NULL; + hs = curl_slist_append(hs, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, hs); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, output.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + res = curl_easy_perform(curl); + + if (res != CURLE_OK) + fprintf( + stderr, "curl_easy_perform() failed: %s\n", curl_easy_strerror(res)); + else { + // std::cout << "Response from server: " << readBuffer << std::endl; + kpts_results = decode_response(readBuffer); + } + curl_easy_cleanup(curl); + } + curl_global_cleanup(); + + return 0; +} diff --git a/imcui/api/test/helper.h b/imcui/api/test/helper.h new file mode 100644 index 0000000000000000000000000000000000000000..9cad50bc0e8de2ee1beb2a2d78ff9d791b3d0f21 --- /dev/null +++ b/imcui/api/test/helper.h @@ -0,0 +1,405 @@ + +#include +#include +#include +#include +#include +#include + +// base64 to image +#include +#include +#include + +/// Parameters used in the API +struct APIParams { + /// A list of images, base64 encoded + std::vector data; + + /// The maximum number of keypoints to detect for each image + std::vector max_keypoints; + + /// The timestamps of the images + std::vector timestamps; + + /// Whether to convert the images to grayscale + bool grayscale; + + /// The height and width of each image + std::vector> image_hw; + + /// The type of feature detector to use + int feature_type; + + /// The rotations of the images + std::vector rotates; + + /// The scales of the images + std::vector scales; + + /// The reference points of the images + std::vector> reference_points; + + /// Whether to binarize the descriptors + bool binarize; +}; + +/** + * @brief Contains the results of a keypoint detector. + * + * @details Stores the keypoints and descriptors for each image. + */ +class KeyPointResults { + public: + KeyPointResults() { + } + + /** + * @brief Constructor. + * + * @param kp The keypoints for each image. + */ + KeyPointResults(const std::vector>& kp, + const std::vector& desc) + : keypoints(kp), descriptors(desc) { + } + + /** + * @brief Append keypoints to the result. + * + * @param kpts The keypoints to append. + */ + inline void append_keypoints(std::vector& kpts) { + keypoints.emplace_back(kpts); + } + + /** + * @brief Append descriptors to the result. + * + * @param desc The descriptors to append. + */ + inline void append_descriptors(cv::Mat& desc) { + descriptors.emplace_back(desc); + } + + /** + * @brief Get the keypoints. + * + * @return The keypoints. 
+ */ + inline std::vector> get_keypoints() { + return keypoints; + } + + /** + * @brief Get the descriptors. + * + * @return The descriptors. + */ + inline std::vector get_descriptors() { + return descriptors; + } + + private: + std::vector> keypoints; + std::vector descriptors; + std::vector> scores; +}; + +/** + * @brief Decodes a base64 encoded string. + * + * @param base64 The base64 encoded string to decode. + * @return The decoded string. + */ +std::string base64_decode(const std::string& base64) { + using namespace boost::archive::iterators; + using It = transform_width, 8, 6>; + + // Find the position of the last non-whitespace character + auto end = base64.find_last_not_of(" \t\n\r"); + if (end != std::string::npos) { + // Move one past the last non-whitespace character + end += 1; + } + + // Decode the base64 string and return the result + return std::string(It(base64.begin()), It(base64.begin() + end)); +} + +/** + * @brief Decodes a base64 string into an OpenCV image + * + * @param base64 The base64 encoded string + * @return The decoded OpenCV image + */ +cv::Mat base64_to_image(const std::string& base64) { + // Decode the base64 string + std::string decodedStr = base64_decode(base64); + + // Decode the image + std::vector data(decodedStr.begin(), decodedStr.end()); + cv::Mat img = cv::imdecode(data, cv::IMREAD_GRAYSCALE); + + // Check for errors + if (img.empty()) { + throw std::runtime_error("Failed to decode image"); + } + + return img; +} + +/** + * @brief Encodes an OpenCV image into a base64 string + * + * This function takes an OpenCV image and encodes it into a base64 string. + * The image is first encoded as a PNG image, and then the resulting + * bytes are encoded as a base64 string. + * + * @param img The OpenCV image + * @return The base64 encoded string + * + * @throws std::runtime_error if the image is empty or encoding fails + */ +std::string image_to_base64(cv::Mat& img) { + if (img.empty()) { + throw std::runtime_error("Failed to read image"); + } + + // Encode the image as a PNG + std::vector buf; + if (!cv::imencode(".png", img, buf)) { + throw std::runtime_error("Failed to encode image"); + } + + // Encode the bytes as a base64 string + using namespace boost::archive::iterators; + using It = + base64_from_binary::const_iterator, 6, 8>>; + std::string base64(It(buf.begin()), It(buf.end())); + + // Pad the string with '=' characters to a multiple of 4 bytes + base64.append((3 - buf.size() % 3) % 3, '='); + + return base64; +} + +/** + * @brief Callback function for libcurl to write data to a string + * + * This function is used as a callback for libcurl to write data to a string. + * It takes the contents, size, and nmemb as parameters, and writes the data to + * the string. 
+ * + * @param contents The data to write + * @param size The size of the data + * @param nmemb The number of members in the data + * @param s The string to write the data to + * @return The number of bytes written + */ +size_t WriteCallback(void* contents, size_t size, size_t nmemb, std::string* s) { + size_t newLength = size * nmemb; + try { + // Resize the string to fit the new data + s->resize(s->size() + newLength); + } catch (std::bad_alloc& e) { + // If there's an error allocating memory, return 0 + return 0; + } + + // Copy the data to the string + std::copy(static_cast(contents), + static_cast(contents) + newLength, + s->begin() + s->size() - newLength); + return newLength; +} + +// Helper functions + +/** + * @brief Helper function to convert a type to a Json::Value + * + * This function takes a value of type T and converts it to a Json::Value. + * It is used to simplify the process of converting a type to a Json::Value. + * + * @param val The value to convert + * @return The converted Json::Value + */ +template Json::Value toJson(const T& val) { + return Json::Value(val); +} + +/** + * @brief Converts a vector to a Json::Value + * + * This function takes a vector of type T and converts it to a Json::Value. + * Each element in the vector is appended to the Json::Value array. + * + * @param vec The vector to convert to Json::Value + * @return The Json::Value representing the vector + */ +template Json::Value vectorToJson(const std::vector& vec) { + Json::Value json(Json::arrayValue); + for (const auto& item : vec) { + json.append(item); + } + return json; +} + +/** + * @brief Converts a nested vector to a Json::Value + * + * This function takes a nested vector of type T and converts it to a + * Json::Value. Each sub-vector is converted to a Json::Value array and appended + * to the main Json::Value array. + * + * @param vec The nested vector to convert to Json::Value + * @return The Json::Value representing the nested vector + */ +template +Json::Value nestedVectorToJson(const std::vector>& vec) { + Json::Value json(Json::arrayValue); + for (const auto& subVec : vec) { + json.append(vectorToJson(subVec)); + } + return json; +} + +/** + * @brief Converts the APIParams struct to a Json::Value + * + * This function takes an APIParams struct and converts it to a Json::Value. 
+ * The Json::Value is a JSON object with the following fields: + * - data: a JSON array of base64 encoded images + * - max_keypoints: a JSON array of integers, max number of keypoints for each + * image + * - timestamps: a JSON array of timestamps, one for each image + * - grayscale: a JSON boolean, whether to convert images to grayscale + * - image_hw: a nested JSON array, each sub-array contains the height and width + * of an image + * - feature_type: a JSON integer, the type of feature detector to use + * - rotates: a JSON array of doubles, the rotation of each image + * - scales: a JSON array of doubles, the scale of each image + * - reference_points: a nested JSON array, each sub-array contains the + * reference points of an image + * - binarize: a JSON boolean, whether to binarize the descriptors + * + * @param params The APIParams struct to convert + * @return The Json::Value representing the APIParams struct + */ +Json::Value paramsToJson(const APIParams& params) { + Json::Value json; + json["data"] = vectorToJson(params.data); + json["max_keypoints"] = vectorToJson(params.max_keypoints); + json["timestamps"] = vectorToJson(params.timestamps); + json["grayscale"] = toJson(params.grayscale); + json["image_hw"] = nestedVectorToJson(params.image_hw); + json["feature_type"] = toJson(params.feature_type); + json["rotates"] = vectorToJson(params.rotates); + json["scales"] = vectorToJson(params.scales); + json["reference_points"] = nestedVectorToJson(params.reference_points); + json["binarize"] = toJson(params.binarize); + return json; +} + +template cv::Mat jsonToMat(Json::Value json) { + int rows = json.size(); + int cols = json[0].size(); + + // Create a single array to hold all the data. + std::vector data; + data.reserve(rows * cols); + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + data.push_back(static_cast(json[i][j].asInt())); + } + } + + // Create a cv::Mat object that points to the data. + cv::Mat mat(rows, cols, CV_8UC1, + data.data()); // Change the type if necessary. + // cv::Mat mat(cols, rows,CV_8UC1, data.data()); // Change the type if + // necessary. + + return mat; +} + +/** + * @brief Decodes the response of the server and prints the keypoints + * + * This function takes the response of the server, a JSON string, and decodes + * it. It then prints the keypoints and draws them on the original image. 
+ * + * @param response The response of the server + * @return The keypoints and descriptors + */ +KeyPointResults decode_response(const std::string& response, bool viz = true) { + Json::CharReaderBuilder builder; + Json::CharReader* reader = builder.newCharReader(); + + Json::Value jsonData; + std::string errors; + + // Parse the JSON response + bool parsingSuccessful = reader->parse( + response.c_str(), response.c_str() + response.size(), &jsonData, &errors); + delete reader; + + if (!parsingSuccessful) { + // Handle error + std::cout << "Failed to parse the JSON, errors:" << std::endl; + std::cout << errors << std::endl; + return KeyPointResults(); + } + + KeyPointResults kpts_results; + + // Iterate over the images + for (const auto& jsonItem : jsonData) { + auto jkeypoints = jsonItem["keypoints"]; + auto jkeypoints_orig = jsonItem["keypoints_orig"]; + auto jdescriptors = jsonItem["descriptors"]; + auto jscores = jsonItem["scores"]; + auto jimageSize = jsonItem["image_size"]; + auto joriginalSize = jsonItem["original_size"]; + auto jsize = jsonItem["size"]; + + std::vector vkeypoints; + std::vector vscores; + + // Iterate over the keypoints + int counter = 0; + for (const auto& keypoint : jkeypoints_orig) { + if (counter < 10) { + // Print the first 10 keypoints + std::cout << keypoint[0].asFloat() << ", " << keypoint[1].asFloat() + << std::endl; + } + counter++; + // Convert the Json::Value to a cv::KeyPoint + vkeypoints.emplace_back( + cv::KeyPoint(keypoint[0].asFloat(), keypoint[1].asFloat(), 0.0)); + } + + if (viz && jsonItem.isMember("image_orig")) { + auto jimg_orig = jsonItem["image_orig"]; + cv::Mat img = jsonToMat(jimg_orig); + cv::imwrite("viz_image_orig.jpg", img); + + // Draw keypoints on the image + cv::Mat imgWithKeypoints; + cv::drawKeypoints(img, vkeypoints, imgWithKeypoints, cv::Scalar(0, 0, 255)); + + // Write the image with keypoints + std::string filename = "viz_image_orig_keypoints.jpg"; + cv::imwrite(filename, imgWithKeypoints); + } + + // Iterate over the descriptors + cv::Mat descriptors = jsonToMat(jdescriptors); + kpts_results.append_keypoints(vkeypoints); + kpts_results.append_descriptors(descriptors); + } + return kpts_results; +} diff --git a/imcui/assets/logo.webp b/imcui/assets/logo.webp new file mode 100644 index 0000000000000000000000000000000000000000..0a799debc1a06cd6e500a8bccd0ddcef7eca0508 Binary files /dev/null and b/imcui/assets/logo.webp differ diff --git a/imcui/datasets/.gitignore b/imcui/datasets/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/datasets/lines/terrace0.JPG b/imcui/datasets/lines/terrace0.JPG new file mode 100644 index 0000000000000000000000000000000000000000..e3f688c4d14b490da30b57cd1312b144588efe32 --- /dev/null +++ b/imcui/datasets/lines/terrace0.JPG @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4198d3c47d8b397f3a40d58e32e516b8e4f9db4e989992dd069b374880412f5 +size 66986 diff --git a/imcui/datasets/lines/terrace1.JPG b/imcui/datasets/lines/terrace1.JPG new file mode 100644 index 0000000000000000000000000000000000000000..4605fcf9bec3ed31c92b0a0f067d5cc16411fc9d --- /dev/null +++ b/imcui/datasets/lines/terrace1.JPG @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d94851889de709b8c8a11b2057e93627a21f623534e6ba2b3a1442b233fd7f20 +size 67363 diff --git a/imcui/datasets/sacre_coeur/README.md b/imcui/datasets/sacre_coeur/README.md new file mode 100644 index 
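The C++ test client above mirrors `imcui/api/client.py`; for readers following along in Python, a rough counterpart of `decode_response()` from `helper.h` could look like the sketch below. The `response_text` argument is assumed to hold the raw `/v1/extract` reply, and the `uint8` dtype assumes binarized descriptors, as in the C++ demo.

```python
import json

import cv2
import numpy as np


def decode_extract_response(response_text: str):
    results = []
    for item in json.loads(response_text):
        # Rebuild OpenCV keypoints from the [x, y] pairs in original coordinates.
        kpts = [
            cv2.KeyPoint(float(x), float(y), 1.0)
            for x, y in item.get("keypoints_orig", [])
        ]
        desc = np.array(item.get("descriptors", []), dtype=np.uint8)
        results.append({"keypoints": kpts, "descriptors": desc})
    return results
```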
0000000000000000000000000000000000000000..d69115f7f262f6d97aa52bed9083bf3374249645 --- /dev/null +++ b/imcui/datasets/sacre_coeur/README.md @@ -0,0 +1,3 @@ +# Sacre Coeur demo + +We provide here a subset of images depicting the Sacre Coeur. These images were obtained from the [Image Matching Challenge 2021](https://www.cs.ubc.ca/research/image-matching-challenge/2021/data/) and were originally collected by the [Yahoo Flickr Creative Commons 100M (YFCC) dataset](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/). diff --git a/imcui/datasets/sacre_coeur/mapping/02928139_3448003521.jpg b/imcui/datasets/sacre_coeur/mapping/02928139_3448003521.jpg new file mode 100644 index 0000000000000000000000000000000000000000..102589fa1a501f365fef0051f5ae97c42eb560ff --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/02928139_3448003521.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f52d9dcdb3ba9d8cf025025fb1be3f8f8d1ba0e0d84ab7eeb271215589ca608 +size 518060 diff --git a/imcui/datasets/sacre_coeur/mapping/03903474_1471484089.jpg b/imcui/datasets/sacre_coeur/mapping/03903474_1471484089.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7b44afda94ce4044c6d13df6eb5b5b7218e7d38e --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/03903474_1471484089.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70969e7b24bbc8fe8de0a114d245ab9322df2bf10cd8533cc0d3ec6ec5f59c15 +size 348789 diff --git a/imcui/datasets/sacre_coeur/mapping/10265353_3838484249.jpg b/imcui/datasets/sacre_coeur/mapping/10265353_3838484249.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0c91c5a55b830f456163fd09092a3887a1c37ef2 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/10265353_3838484249.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9c8ad3c8694703c08da1a4ec3980e1a35011690368f05e167092b33be8ff25c +size 454554 diff --git a/imcui/datasets/sacre_coeur/mapping/17295357_9106075285.jpg b/imcui/datasets/sacre_coeur/mapping/17295357_9106075285.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3d38e80b2a28c7d06b28cc9a36b97d656b60b912 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/17295357_9106075285.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54dff1885bf44b5c0e0c0ce702220832e99e5b30f38462d1ef5b9d4a0d794f98 +size 535133 diff --git a/imcui/datasets/sacre_coeur/mapping/32809961_8274055477.jpg b/imcui/datasets/sacre_coeur/mapping/32809961_8274055477.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a8ca6a7b7f3dfc851bf3c06c1e469d12503418d1 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/32809961_8274055477.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c2bf90ab409dfed11b7c112c120917a51141010335530c8607b71010d3919fa +size 458230 diff --git a/imcui/datasets/sacre_coeur/mapping/44120379_8371960244.jpg b/imcui/datasets/sacre_coeur/mapping/44120379_8371960244.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c2d24e63b8141978e577f9c393d5214afe25e241 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/44120379_8371960244.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc7bbe1e4b47eeebf986e658d44704dbcce902a3af0d4853ef2e540c95a77659 +size 357768 diff --git a/imcui/datasets/sacre_coeur/mapping/51091044_3486849416.jpg b/imcui/datasets/sacre_coeur/mapping/51091044_3486849416.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..38d612bf9b595f66c788a3dd7f16ada85f34f99e --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/51091044_3486849416.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe4d46cb3feab25b4d184dca745921d7de7dfd18a0f3343660d8d3de07a0c054 +size 491952 diff --git a/imcui/datasets/sacre_coeur/mapping/60584745_2207571072.jpg b/imcui/datasets/sacre_coeur/mapping/60584745_2207571072.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bfda213d8d27cb0d8840682497d48c88828604b6 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/60584745_2207571072.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:186e6b8a426403e56a608429f7f63d5a56d32a895487f3da7cb9ceaff97f563f +size 470664 diff --git a/imcui/datasets/sacre_coeur/mapping/71295362_4051449754.jpg b/imcui/datasets/sacre_coeur/mapping/71295362_4051449754.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a8c27b023e9ac5a205e3e5e9ae83df437fcf1e44 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/71295362_4051449754.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7e7a59fab6f497bd94b066589a62cb74ceffe63ff3c8cd77b9b6b39bfc66bae +size 368510 diff --git a/imcui/datasets/sacre_coeur/mapping/93341989_396310999.jpg b/imcui/datasets/sacre_coeur/mapping/93341989_396310999.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2bab9cbcccfb317965c25ff0a5892560ac89c0ab --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping/93341989_396310999.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2932ea734fb26d5417aba1b1891760299ddf16facfd6a76a3712a7f11652e1f6 +size 364593 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..065d5ac91a7e84130a452f7455235f6e47939ad1 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b73c04d53516237028bd6c74011d2b94eb09a99e2741ee2c491f070a4b9dd28 +size 134199 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5ab52d7abbd300abcc377758aee8fed52be247b0 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85a3bac5d7072d1bb06022b052d4c7b27e7216a8e02688ab5d9d954799254a06 +size 127812 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..25e5aedf11797e410e45e55a79fbae55d03acb1d --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3d1ccde193e18620aa6da0aec5ddbbe612f30f2b398cd596e6585b9e114e45f +size 133556 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..43890e27b5281a7c404fe7ff57460cb2825b3517 --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4d238d8a052162da641b0d54506c85641c91f6f95cdf471e79147f4c373162d +size 115076 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9fef9b14bb670469d8b95d20e51641fff886c1ba --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2110adbf1d114c498c5d011adfbe779c813e894500ff429a5b365447a3d9d106 +size 134430 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dbcdda46077574aa07505e0af4404b0121efdc53 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbf08e8eadcadeed24a6843cf79ee0edf1771d09c71b7e1d387a997ea1922cfb +size 133104 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..59a633f7b02c3837eef980845ef309dbb882b611 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ede5a1cf1b99a230b407e24e7bf1a7926cf684d889674d085b299f8937ee3ae3 +size 114747 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..40e6ec78de130ca828d21dccc8ec18c5208a7047 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5489c853477a1e0eab0dc7862261e5ff3bca8b18e0dc742fe8be04473e993bb2 +size 82274 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c29068ed9422942a412c604d693b987ada0c148d --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2991086430d5880b01250023617c255dc165e14a00f199706445132ad7f3501e +size 79432 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..87e1fd29e80b566665e3c0b0f1a9e5f512a746f1 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dcca8fbd0b68c41fa987982e56fee055997c09e6e88f7de8d46034a8683c931e +size 81912 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9fa21bc56b688822d8f72529f34be3fde624bb47 --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbd874af9c4d1406a5adfa10f9454b1444d550dbe21bd57619df906a66e79571 +size 66472 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..04514db419ca03a60e08b8edb9f74e0eb724b7b9 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6fc41beb78ec8e5adef2e601a344cddfe5fe353b4893e186b6036452f8c8198 +size 82027 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..af0f8cbad53e36cbb2efd00b75a78e54d3df07d1 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc55ac4f079176709f577564d390bf3e2c4e08f53378270465062d198699c100 +size 81684 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ceddbfe43526b5b904d91d43b1261f5f8950158d --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5ef044c01f3e94868480cab60841caa92a410c33643d9af6b60be14ee37d60f +size 66563 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e0fb1c291af7ed2e45cf25909a17d6196a7927cc --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b5cd34c3b6ff6fed9c32fe94e9311d9fcd94e4b6ed23ff205fca44c570e1827 +size 96685 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e27c62e234992f2e424c826e9fdf93772fe5e411 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:225c06d54c04d2ed14ec6d885a91760e9faaf298e0a42e776603d44759736b30 +size 104189 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a097e96b8d913f91bd03ff69cbf8ba8c168532b5 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9bda2da8868f114170621d93595d1b4600dff8ddda6e8211842b5598ac866ed3 +size 101098 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..20c7db4102c7c878ada35792e34fa49b1db177d2 --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de7d6445cadee22083af4a99b970674ee2eb4286162ec65e9e134cb29d1b2748 +size 83143 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..72bcfa0bc7488c9e6f2f7fa48414751fbab323aa --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f14f08b0d72e7c3d946a50df958db3aa4058d1f9e5acb3ebb3d39a53503b1126 +size 96754 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6cc5f76d9d7435ff36d504fda63719e84cc801d2 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5ec4c6fa41d84c07c8fffad7f19630fe2ddb88b042e1a80b470a3566320cb77 +size 101953 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..92e85895809de28aee13c9e588c4c4b642234a6c --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:095ca3d56fc0ce8ef5879a4dfcd4f2af000e7df41a602d546f164896e5add1b5 +size 82961 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6f5675e6886adc9ab72f023ebaaaaea34addbd50 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fed7e7f7adc94c88b9755fce914bd745855392fdce1057360731cad09adc2c3a +size 119729 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..708cf8448bf95d8d6f5b7bfc2cf8887b90423038 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a8f76fdbfe1223bc2342ec915b129f480425d2e1fd19a794f6e991146199928 +size 125780 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..85c17020b2e6b48021409dce19ec33bb3a63f3b7 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:335726de47840b164cd71f68ad9caa30b496a52783d4669dc5fab45da8e0427f +size 111548 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..22831f09d5d8d80dce6056e67d29c0cf0b5633ca --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffc1822c65765801d9320e0dff5fecc42a013c1d4ba50855314aed0789f1d8f2 +size 87725 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1f38d0ce7d34fb720efe81d0b7221278f3a2b2a8 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af42b452bca3f3e5305700f2a977a41286ffcc1ef8d4a992450c3e98cd1c1d05 +size 119644 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..684dd21fbc6acb244f36bdb9b576681214de06e6 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2452d8839721d7c54755be96e75213277de99b9b9edc6c47b3d5bc94583b42c1 +size 111275 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..22bf9bdad97bfa626738eb88110a0ba12a967ed9 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7d580e865a744f52057137ca79848cc173be19aaa6db9dcaa60f9be46f4c465 +size 87490 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4702aed5bddf1b3db6fcf7479c39ec33514ec44a --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55c35e80d3783c5a93a724997ae62613460dcda6a6672593a8e9af6cc40f43c0 +size 98363 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5deeb08577e2b4517b23e1179f81f59b6ac3b1f5 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23d8d3c4dcfea64e6779f8e216377a080abfdd4a3bc0bf554783c8b7f540d27f +size 102149 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3d6d066deabecff0cf4c1de14e199ab947244460 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:516d60da5e943828a525407524cf9c2ee96f580afb713be63b5714e98be259b7 +size 92591 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7f21448e6c3b3b7a9c40879d9743e91037738266 --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b0cf682dbbdd398ff1aa94e0b3ca7e4dafac1f373bbc889645ee6e442b78284 +size 79136 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..87d31072cf1bea18293fba156dcdf676c12da898 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88dd8ab393c2b7fd7f1dbbff1a212698dccee8e05f3f8b225fdd8f7ff110f5f1 +size 98588 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..992d761ff6b17d26d8e92a4e63df4206e611324c --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:693ef906b4a3cdedcf12f36ce309cc8b6de911cc2d06ec7752547119ffeee530 +size 93063 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8768bc44b96253a42b3736ac6b2ebfe4b9a6f8dc --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06d0296ad6690beb85c52bf0df5a68b7fd6ffcf85072376a2090989d993dfbf8 +size 79729 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..05fb0e416b83a4189619c894a9b0d0e06d0641b4 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bfc70746d378f4df6befc38e3fc12a946b74f020ed84f2d14148132bbf6f90c7 +size 73581 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e59b0eb3d7cda3065c3e85e29ed0b75bcb8c3e08 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56d78dbcbb813c47bd13a9dfbf00bb298c3fc63a5d3b807d53aae6b207360d70 +size 79424 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..11c63767e7494af427722008d7eb2c93fbcd8399 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b97eed0c96feac3e037d1bc072dcb490c81cad37e9d90c459cb72a8ea98c6406 +size 78572 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e9d45d0f296d87d726dfcdd742dadad727bda193 --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0da47e47bf521578a1c1a72dee34cfec13a597925ade92f5552c5278e79a7a14 +size 62148 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c3282a63d04c032c509ee4fe15ec468d0b4b8f15 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29e6c31d2c3f255cd95fa119512305aafd60b7daf8472a73b263ed4dae7184ac +size 75286 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..78dee48f38cb0595157636b72531eb8a3967ea06 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af952a441862a0cadfc77c5f00b9195d0684e3888f63c5a3e8375eda36802957 +size 78315 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d5765a113572929f7e99cd267a53c5457fb46272 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c68e15b9a7820af38f0a0d914454bb6b7add78a0946fd83d909fde2ca56f721 +size 62828 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..41b2570b6f773ad1a8c2b06587bf3921a11827ee --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ccabe0811539e04ddb47879a1a864b0ccb0637c07de89e33b71df655fd08529 +size 103833 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8b75de67cf95e5598bdbde4dc6bb6371e99c158b --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f629597738b0cfbb567acdecfbe37b18b9cbbdaf435ebd59cd219c28db199b38 +size 109762 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4d592f7c0490225cfea3546cc1d4ca5b5f384715 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e44cb9873856cd2c785905bc53ef5347ad375ed539b1b16434ceb1f38548db95 +size 109015 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6b380dddec3ccaa824f39d0d2f86a466904921f3 --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4c260b40e202d4c2144e4885eb61a52009ee2743c8d04cd01f1beb33105b4c0 +size 95253 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cf31081cae7339d043d9b37174d32c2f39bd46e7 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1c256473100905130d2f1d7e82cd2d951792fc5ff4180d2ba26edcc0e8d17f0 +size 103940 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..52a36e0065b37f73d28c6d8e3b332d62d36605b7 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65d4c53b714fd5d1ba73604abcdb5a12e9844d2f4d0c8436d192c06fe23a949c +size 108605 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b5d80a6ac1c5f37b1608a45b4588f15cc7e6f9a4 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4434b4fe8eba5ca92044385410c654c85ac1f67956eb8e757c509756435b33c +size 95080 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7b7f2992ee6dfe32f9476b48a7c29c93c40aa5b1 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8e7e6ec0df2dbd9adbfa8b312cd6c5f04951df86eb4e4bd3f16acdb945b2d7b +size 106398 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2356d37020783c7f06a3c75b0ac93f53244b2a21 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a85ee7b1041a18f0adc7673283d25c7c19ab4fdbe02aa8b4aaf37d49e66c2fcc +size 109233 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..580ea14ffadcbef569d9947566cd6601ea5e3a31 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ba97ee2207a37249b3041a22318b8b85ac4ac1fcec3e2be6eabab9efcb526d7 +size 111988 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1bf37416ec0e9490efa4118ca15e3a5d3bc8f13b --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea908aa57fff053cb7689ebcdcb58f84cf2948dc34932f6cab6d7e4280f5929f +size 93144 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5083e2f231c4489ad251789cf17fd279e94de123 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c83f58489127bbf5aefd0d2268a3a0ad5c4f4455f1b80bcb26f4e4f98547a52 +size 106249 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..61d8dd667a41288b5466a590a4f7917113b4f8eb --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:daaf1a21a08272e0870114927af705fbac76ec146ebceefe2e46085240e729af +size 112103 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1f2337cd550356b50310d8f66900bb49a5cc1f78 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57e4c53f861c5bab1f7d67191b95e75b75bdb236d5690afb44c4494a77472c29 +size 92118 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3e3139d8c7e911e5c1597f8d4cc211ba0bc62c8e --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2367c7a1fc593fe02cefa52da073ce65090b11e7916b1b6b6af0a703d574070 +size 79924 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7b96616f4dfcefeedc876ded2f7a05bbf5989198 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d93b4acc0d8fb7c150a6b5bcb50b346890af06557b3bfb19f4b86aa5d58c43ed +size 81568 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8417a80887bd923f92e37187b193ac38626f7815 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da12abfef15e8a462c01ad88947d0aca45bae8d972366dd55eecbeb5febb25cb +size 80924 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6444b8d7c844fc82804fba969839318da0c922cd --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42280a19b1792682a6743dfdbca8e4fb84c1354c999d68b103b1a9a460e47ca5 +size 63425 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ff18e6d4b0eabbac042eefdcfa4fadca6921c73 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9ff7cbd7f10f3f6e8734aab05b20a1fd1e0d4b00a0b44c14ed4f34a3f64642c +size 80202 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..959468fa7db4f3dde517e0ee4c59f7dc87ece238 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cc498c25f22f75f099a9b2cf54477e5204e3f91bd0a3d0d4498f401847f4ac8 +size 80296 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..16054efff56ce617e17ff93d2c8029a7264f5951 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63eb14c5563ee17dadd4bed86040f229fdee7162c31a22c1be0cb543922442f7 +size 62355 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot135.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d73dfa10ebdeb97740b773790fd1fa313bb0556a --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a21943a12405413ce1e3d5d0978cd5fc2497051aadf1b32ee5c1516980b063a +size 79615 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0d6a78bb7cc2274b2714c97706cac35bd601e933 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7e5a61e764ba79cb839b4bbe3bd7ecd279edf9ccc868914d44d562f2b5f6bb7 +size 81016 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot225.jpg new file mode 100644 index 0000000000000000000000000000000000000000..83b348efda25d253fa197b29138ee2034124c4b6 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot225.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58a09610d7ae581682b7d0c8ce27795cfd00321874fd61b3f2bbe735331048c6 +size 81638 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8d8f611ec766ae6c4c0ea023cff1d17b0851def6 --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot270.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65b3cbda1b6975f9a7da6d4b41fe2132816e64c739f1de6d76cd77495396bfc0 +size 68811 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot315.jpg new file mode 100644 index 0000000000000000000000000000000000000000..93754c8836ed911f93106c77c94047ef9c265343 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot315.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:228fc8095ed6614f1b4047dc66d2467c827229124c45a645c1916b80c6045c72 +size 78909 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot45.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bf980bf62804f81ed9384c09897091609b63d933 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot45.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b6a6c9006f4c7e38d5b4b232c787b4bf1136d3776c241ba6aadb3c60a5edf5e +size 81485 diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..522bc81a65f743be92e2ecc73d4c1ef41b8fdfbd --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot90.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2650e275a279ccb1024c2222e2955f2f7c8cef15200712e89bf844a64d547513 +size 68109 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d27fe8d6f813e1525d581e2665feff4704710e98 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65632a80a46b26408c91314f2c78431be3df41d8c740e1b1b2614c98f17091a0 +size 21890 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..26c82557bab1bb66463603bcbb23fb179ab1456a --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b95394b52baf736f467b28633536223abaf34b4219fe9c9541cd902c47539de7 +size 40252 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4e76cf942182757709234e6b44fdef9ea255035f --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de85408453e630cc7d28cf91df73499e8c98caffef40b1b196ef019d59f09486 +size 52742 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3e87b7854fded8fa9bd423e47b93da3e4e76414d --- /dev/null 
+++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35bc6e9a0c6b570e9d19c91d124422f17e26012908ee089eede78fd33ce7c051 +size 65170 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4cdcbbf48d27668ecc6e2c4ddf109a6cf1405e69 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6120fd5688d41a61a6d00605b2355a1775e0dba3b3ade8b5ad32bae35de897f8 +size 80670 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e60569c28a50e261f661a20816fe5ab4dd265754 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea90074a3c2680086fad36ea3a875168af69d56b055a1166940fbd299abb8e9f +size 96216 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..022cf34ec13b5accbfcaf5c6ac9adc9ed63f4387 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4371f89589919aa8bc7f4f093bf88a87994c45de67a471b5b1a6412eaa8e54e4 +size 112868 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ae239cdbed0fa5094181cfa85977a4cc5dd401b1 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:394013dede3a72016cfefb51dd45b84439523d427de47ef43f45bbd2ce07d1b1 +size 18346 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7816ec7253dac16d13da7ab2bb05eb918cf90ec9 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd82ebb523f88dddb3cfcdeb20f149cb18a70062cd67eb394b6442a8f39ec9b1 +size 29287 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..75421d467f758d274c43191085d6b99ec6987741 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8927ec9505d0043ae8030752132cf47253ebc034d93394e4b972961a5307f151 +size 37883 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.6.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..4b77784210fb86c4db0f587c70061dff44d72e06 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27f0c0465ba88b9296383ed21e9b4fa4116d145fae52913706a134a58247488b +size 44461 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..13b6b9d3b1d66e9356a0b97567f221882ea9bc7e --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1344408ec4e79932b9390a3340d2241eedc535c67a2811bda7390a64ddadaa75 +size 52812 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c1a7d765ecf17c513f6c716b8ba1b9421d8ad5c1 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc4b253faf50d53822fdca31acfd269c5fd3df28ef6bd8182a753ca0ff626172 +size 63930 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..735e7c4475c84772447f96b0ca40703482044e94 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b0822b92d430c4d42076b0469a74e8cc6b379657b9b310b5f7674c6a47af6e8f +size 73652 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..06204bf54d2df8157c21f061496e644b6fa47b23 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bcef8bb651e67a969b5b17f55dd9b0827aaa69f42efe0da096787cae0ebc51e +size 19404 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d16aaf7bdbc453b42fc2d8c72fd9cf90bd5e3317 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdebf79043e7d889ecebf7f975d57d1202f49977dff193b158497ebcfc809a8f +size 33947 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c24e10ded3cca68b8851073f8eb19f36ee07acc2 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89bf1005e2bb111bf8bd9914bff13eb3252bf1427426bc2925c4f618bd8428b6 +size 44233 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.6.jpg 
b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cd4f55bd2b3587f56a8c60ab0f4c6c0298d9b197 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b69b20d8695366ccbf25d8f40c7d722e4d3a48ecb99163f46c7d9684d7e7b2e5 +size 55447 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..da7c59860bdb2812ea138b55965ebc0711c80f5f --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f92f6225593780ad721b1c14dcb8772146433b5dbf7f0b927a1e2ac86bb8c4bd +size 65029 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..807d37f1572761bf8d876384e33577f26ca257b4 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f91146bc682fc502cbefd103e44e68a6865fbf69f28946de2f63e14aed277eb5 +size 78525 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c8384e6688273d4f1f31a80a3ef47a4e89a21b8 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:406f548692e16ac5449d1353d1b73b396285d5af82928e271186ac4692fa6b61 +size 92311 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..52b5e6a247240106fd5e5c886b7a22cf7de939c3 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f1667f20a37d2ebc6ecf94c418078492ab465277a97c313317ff01bbafad6ae +size 19279 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1ecf2d221eb38ae1bebfce5e722f0999325ac519 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:107c4cc346f8b1f992e369aedf3abdafea4698539194f40eade5b5776eef5dd5 +size 36526 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6976776119ba2f20a89e7cee4e04fdea6894677a --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:407f9290be844f5831a52b926b618eff02a4db7bc6597f063051eeeaf39d151b +size 47460 diff --git 
a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bcd235542d03cdbb3480a44ee181c01ca48a6558 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:967e490e89b602ecaf87863d49d590da4536c10059537aaec2d43f4038e3d771 +size 62509 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b987b854cdb4d833cf6aa74df88927a975c44362 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5afc12af54be3cb0dba0922a7f566ee71135cea8369b2e4842e7ae42fe64b53 +size 76306 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..877bb2c923b487518c640125c5197b0edf0e0bfb --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61f266200d9a1f2e9aefb1e3713aa70ccd0eed2d5c5af996aef120dcf18b0780 +size 91762 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4d0c041790cf9f3de32e608c04e589917d546556 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12aaae2e7a12b49ec263d9d2a29ff730ff90fbed1d0c5ce4789550695a891cbe +size 107595 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f9313cfc25615fe53efa05ffcdfe54b62ccdcdb --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7e8dc5991d5af4b134fe7534c596aad5b88265981bba3d90156fff2fa3b7bd8 +size 19569 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e50f64360fe0805698b11202f12caaf6f939049c --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad8c1fc561fc5967d406c8ca8d0cce65344838edc46cc72e85515732a1d1a266 +size 34434 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8ad560b70e0923b134bdf4fd35c75036b9c1feb6 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:132a899e4725b23eb0d0740bfda5528e68fdfcafef932b6f48b1a47e1d906bdb +size 44062 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7713f99755d6fb12f24d59ace7c6bb30af8d0140 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6ca571a824efff8b35a3aa234756d652aa9a54c33dfd98c2c5854eeb8bc2987 +size 56056 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cd312a40e414dabc167d982e7283703e8d533bae --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66177f53a20ae430a858f7aeb2049049da83b910c277b5344113508a41859179 +size 65465 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f40be8bd5671519bfa985a592d4b4a2984582d4d --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6db2f05714a37dfd8ab21d52d05afd990b47271c1469e0c965d35dcd81e2c39 +size 77708 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9ecb09015abbcbf32907ff8584f9427161e3ed64 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a3a58d711cfe35f937f5b9e9914c7dcb427065d685a816c4dc5da610403d981 +size 92547 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e3cd992394f9ace830fe1df7911f04c80e1abe0a --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bea2d09de8c92e49916c72f57ef8d0bc39fb63b51cdd2568706cbe01028a4480 +size 18119 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cf490019463c6e01cc23fc2962c468c633a53aaa --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0312626c40b38df7b64471873e2efaebc90a5a34cfe9778da7b7a9a6793975be +size 28961 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f7ac21a0279982bb6f0a8ddf715f794b1577fc71 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.5.jpg @@ -0,0 
+1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ab9561483b9b02fd62304ce34c337107494050635e8a16bbf4b6426f602dbb +size 36523 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bfe516659198de7a9fd0d8bfe1f961804d9c0e00 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36fe588a2cb429b6655f5fff6253fec5d3272fa98125ffd1f84bdf0bab0f63e5 +size 44134 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..44574a69046cb62cedeffee56c34ae91cf160c7f --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04b725bd69e0078cc6becbe04958d8492d7845a290d72ade34db03eeaebb9464 +size 52784 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7ccebe81457b6d7a75583ad0bd3e1a080d3be65b --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2febc81d0fede4f6fe42ca4232395798b183358193e1d63a71264b1f57a126af +size 62516 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e78525e313a8a97bcbc16e035b73b4e69949b1f6 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0cd73a22d565e467441ed367a794237e5aeb51c64e12d31eba36e1109f009f5 +size 72003 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..508aa2e17c3eafd0cbfdbeb5f7e8383afd1eff49 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8fd1041d9aa8b9236dc9a0fe637d96972897295236961984402858a8688dea4 +size 19103 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f749213b5d42af17b5e8fa85be5f2a0099900482 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27d421ef6ed9d78b08d0b4796cfc604d3511a274e2c4be4bd3e910efcc236f07 +size 34317 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0af327e6da018baaf5e862a25b002fec00ee1481 --- /dev/null +++ 
b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5480d172fc730905eb0de44d779ec45ecd11bd20907310833490eb8e816674e +size 45635 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..87a0920fc65ff0f224724e182baffa5efabb2727 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c139a0636fd9bb1c78c6999e09bd5f7229359a618754244553f1f3ff225e503 +size 55257 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c9429d3921a2f46d2bc2fa1b845c7b5ca8c34b4d --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c7f19e2d7bf07bd24785cd33a4b39a81551b0af81d6478a8e6caffd9c8702a4 +size 68433 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..816eb35ce7690a8f338698a32bf882f77c226515 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a54a79a52a214e511e7677570a548d84949989279e9d45f1ef2a3b64ec706d4 +size 83428 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..37c041c8c84ee938aa69efdba406855e71853961 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44b0fccc48763c01af6114d224630f0e9e9efb3545a4a44c0ea5d0bfc9388bc6 +size 97890 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a2fac4209480c24b542cdfae90539ca30cb34730 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a0621b195be135487a03013a9bbefd1c8d06219e89255cff3e9a7f353ddf512 +size 20967 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b9bb9631981f5525ccab16e8c57e9a5fc1cd8e23 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f463bf0ca416815588402256bc118649a4021a3613fda9fee7a498fe3c088211 +size 36897 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.5.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..ef911222845da2a6392efd777b08888b386e36ff --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98f2e11e0dfdc72f62b23da6b67f233a57625c9bf1b3c71d4cb832749754cc8b +size 46863 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3b3ed4111f137eba6204f7678b434da1a38b8b5f --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3fa19ae2f619f40298b7d5afe1dd9828a870f48fd288d63d5ea348b2a86da89c +size 58545 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fd502a98dd92d22c8246e9a2febad4e63fd49573 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44a1ff7cfbd900bbac8ad369232ff71b559a154e05f1f6fb43c5da24f248952d +size 70581 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e746d3e8df82d3022cb1e3192f51b067f8d22020 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:351b645bd3243f8676059d5096249bde326df7794c43c0910dc63c704d77dc28 +size 83757 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cf18599607edcacf5f706fdb677fe63d2c6d2b37 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:592c76689f52ed651ee40640a85e826597c6a24b548b67183cc0cb556b4727e4 +size 97236 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cf7917d922cdbd30d8a415c3f71c4f1970aaace5 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5240431398627407ddf7f59f7d213ccf910ddf749bfe1d54f421e917b072e53b +size 17605 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7f0604df7dde207aa3571ad68eeb37ea9c9646f6 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3436f60a6f7f69affcd67ac82c29845f31637420be8948baa938907c7e5f0a7 +size 28290 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.5.jpg 
b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fee6de7cbea4e841d6811d89a80baf7c8592b79c --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb8aa81f49cda021a0d92c97ff05f960274ab442f652d6041b163f86e2b35778 +size 37127 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6fcf0e854278c71762b85c0eff8b74a72a8cf78 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8e931c5c4dfe453042cca041f37fa9cdc168fbe83f7a2ccd3f7a8aced93a44e +size 45866 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d2b263a84eb2fa2ecc1ff8e8e516ee47d8676904 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60fc4db91f0f34837601f066cd334840ce391b6d78b524ce28b7cc6a328227e2 +size 53956 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7ec774e14c9d0a09ba1d51f8834f5719a132b89e --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff223aa66f53e55ecd076c15ccb9395bee4d1baa8286d2acf236979787417626 +size 63961 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7c9d39b86fe5e72e9fe4061af79fe9bda7096a86 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c54568031f5558c7dfff0bb2ed2dcc679a97f89045512f44cb24869d7a2783d5 +size 74563 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..554e7360ffea8643309bc60a4397b6f0a33600b1 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d5e5f2d3cf4b6ebb1a33cf667622d5f10f4a6f6905b6d317c8793368768747b +size 18058 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b9c001354e3160e3849485eb79cfd9bd91bd559b --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7613d4ae7a19dd03b69cce7ae7d9084568b96de51a6008a0e7ba1684b2e1d18 +size 28921 diff --git 
a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..654b114c83a846ddc7d3004d4ecb33bb48f9dd95 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:120f479a4a45f31a7d01a4401c7fa96a2028a2521049bd4c9effd141ac3c5498 +size 36174 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..548a5f7d5ed22f5a93fd58723d47a979e3d0a9fe --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d490f0baec9460ad46f4fbfab2fe8fa8f8db37019e1c055300008f3de9421ec +size 43542 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6ff24a52f698f00033277c0fb149c1e3ea92fea --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.7.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:68697e9c9cda399c66c43efad0c89ea2500ef50b86be536cea9cdaeebd3742db +size 53236 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cad1a64516428929de5303d83eb9cd0cc6592f42 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.8.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:161e5e0aa1681fb3de437895f262811260bc7e7632245094663c8d527ed7634d +size 61754 diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f7e9b31632c13cf41edcfe060442ec44acdae202 --- /dev/null +++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.9.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf1f3fd9123cd4d9246dadefdd6b25e3ee8efddd070aa9642b79545b5845898d +size 72921 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/adam.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/adam.png new file mode 100644 index 0000000000000000000000000000000000000000..d203797ac4dbe32b56bf858bf574ccba301b2890 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/adam.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ab84505b1b857968e76515f3a70f8a137de06f1005ba67f7ba13560f26591a1 +size 204675 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cafe.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cafe.png new file mode 100644 index 0000000000000000000000000000000000000000..721255df1cc096fd276591904a6635bbc9335934 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cafe.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b75eb268829f64dc13acb5e14becd4fc151c867a0b7340c8149e6d0aa96b511 +size 677870 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cat.png 
b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cat.png new file mode 100644 index 0000000000000000000000000000000000000000..44a96af0b33b709f46d89d15ad9758f0207eefee --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cat.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b9e70a57eb37ba42064edfd72d85103a3824af4e14fb6e8196453834ceb8eaf +size 858661 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/dum.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/dum.png new file mode 100644 index 0000000000000000000000000000000000000000..d4bb86ef5af0caa7ed1bd233f45f8e79a44a1c45 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/dum.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c025b6776edf6e75dee24bba0e2eb2edd9541129a5c615dcb26f63da9d3e2c5 +size 1227154 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/face.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/face.png new file mode 100644 index 0000000000000000000000000000000000000000..af9eb4b47e1e1adb17f29a6e3d06cd03b51b00dd --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/face.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8819f9fb3a5d3c2742685ad52429319268271f8d68e25a22832373c6791ef629 +size 1653453 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/fox.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/fox.png new file mode 100644 index 0000000000000000000000000000000000000000..0ee9996c5a6778e44ec12d68fabf4a2d8dc45cef --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/fox.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4ea0030c9270f97c5d31fa303f80cb36ccfbdbc02511c1e7c7ebee2b237360f +size 847938 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/girl.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/girl.png new file mode 100644 index 0000000000000000000000000000000000000000..33c2c22d3a73ebc78da6c143a98c9b18468ff9b5 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/girl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:699cbd2b94405170aac86aafb745a2fe158d5b3451f81aa9a2ece9c7b76d1bca +size 1089404 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/graf.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/graf.png new file mode 100644 index 0000000000000000000000000000000000000000..70eb91f5aa6f35270889b83edc84c6a71de08846 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/graf.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33d874bf68a119429df9ed41d65baab5dbc864f5a7c92b939ab37b8a6f04eaa5 +size 947738 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/grand.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/grand.png new file mode 100644 index 0000000000000000000000000000000000000000..dc07235113020505da8521231982426ab737a8c2 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/grand.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e86a53256e15f7fa9ab27c266a7569a155c92f923b177106ed375fac21b64be +size 1106556 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/index.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/index.png new file mode 100644 index 0000000000000000000000000000000000000000..b1400bb6e79c8ac25fd5555648aca002d6af46c4 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/index.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44da42873d5ef49d8ddde2bc19d873a49106961b1421ea71761340674228b840 +size 764609 diff --git 
a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/mag.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/mag.png new file mode 100644 index 0000000000000000000000000000000000000000..64acb056fcd7fda7a95fc3844772f1f628653489 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/mag.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba4102c8d465a122014cf717bfdf3a8ec23fea66f23c98620442186a1da205ff +size 109706 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/pkk.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/pkk.png new file mode 100644 index 0000000000000000000000000000000000000000..44d7844ecea3980342b5564935f9406bc3bac23a --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/pkk.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:965fced32bbcd19d6f16bbe5b8fa1b4705ce2e1354674a6778bf26ddaf73a793 +size 1030148 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/shop.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/shop.png new file mode 100644 index 0000000000000000000000000000000000000000..f58917a71cb265bcd4e404147c28444b811cebe3 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/shop.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20a9b74db30bb47f740cdc9ac80d55c57169dcaaeaa9e6922615680ac8651562 +size 950858 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/there.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/there.png new file mode 100644 index 0000000000000000000000000000000000000000..b39c566fa6861e9e3d9c4521513fdc74bc6241b0 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/there.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0aa0ec27f2a82680c1f69d507b6cf1832fa4543c0a3cee4da0c9e177bf8c5cf4 +size 2194678 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/vin.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/vin.png new file mode 100644 index 0000000000000000000000000000000000000000..0c96c24a8181a0313caa0e05c3c521194c7155c2 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/vin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a5420934a1fc2061b30b69cecb47be520415cf736c1ebe0ecb433afe9f4d3ea +size 971107 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/adam.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/adam.png new file mode 100644 index 0000000000000000000000000000000000000000..df6cee2629c7fe032fbef3ac81521cb28551f4c8 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/adam.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17f571379d9764726f9ba43cac717d4f29e7a21240bed40a95a2251407ae2c5d +size 170216 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cafe.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cafe.png new file mode 100644 index 0000000000000000000000000000000000000000..68ab7b57cd637ba9d704bca9b9867c1ee6a9a4e0 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cafe.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69aaaa6832d73e3b482e7c065bad3fce413dd3d9daa21f62f5890554b46d8d62 +size 668573 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cat.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cat.png new file mode 100644 index 0000000000000000000000000000000000000000..16844477316e600d83398302f7b1b938a24d98f6 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cat.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6f97f05c6f554fc02a4f11af84026e02fd8255453a4130a2c12619923c669ae +size 855468 diff --git 
a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/dum.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/dum.png new file mode 100644 index 0000000000000000000000000000000000000000..6075f218a49b1ab906ac6b8a229c6d374c1b5fd1 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/dum.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e73d572607d2bd6283f178f681067dae54b01ceea3afcd9d932298d026e7202 +size 1124916 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/face.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/face.png new file mode 100644 index 0000000000000000000000000000000000000000..4d8e187deae660c6930b28440f44349334c935c6 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/face.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abc729a77234d4f32fa7faf746a729596fc1dd8c3bcf9c3e301305aedbd3ba1f +size 1052127 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/fox.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/fox.png new file mode 100644 index 0000000000000000000000000000000000000000..faf3b8f755a711d53f0b3f7c59d5637b27f3d879 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/fox.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ceec2ea119120128d44f09aef9354cc39098cf639edb451894f6d4c099950161 +size 737697 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/girl.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/girl.png new file mode 100644 index 0000000000000000000000000000000000000000..3e196cb63dd8977f0ab483e7f8e83da4054f87e1 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/girl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b458d84eaf32b19b52f8871d03913cb872ce87d55d758a1dfc962eff148f929 +size 1044256 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/graf.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/graf.png new file mode 100644 index 0000000000000000000000000000000000000000..4614b40313c575a31bccef6fca21e955ba98dc85 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/graf.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3f3bce412ad9c05e8048a5fdb8a2c32dbba4a6a37db02631b5fff035712d556 +size 1037985 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/grand.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/grand.png new file mode 100644 index 0000000000000000000000000000000000000000..855963559608e2599620426bc341bbccee82d0d8 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/grand.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8af49e086e21804e9ec6d3a117cad972187f6ea8bc05758893e16d56df720631 +size 1156742 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/index.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/index.png new file mode 100644 index 0000000000000000000000000000000000000000..094bba34dd52c273985b48650b403d06b4105354 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/index.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74f4f0cf2ff09fe501e92ebbe89cc5c5a1e0df00aec1ce59f5aebed8ecb6ac5e +size 764123 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/mag.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/mag.png new file mode 100644 index 0000000000000000000000000000000000000000..0583f0bf356cc2b2eaa61394561819b9fcdc2ef6 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/mag.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0fcb58c53ce102e0302f321027995dfba64fe1195f481ad48be12e70f80fe66b +size 114421 
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/pkk.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/pkk.png new file mode 100644 index 0000000000000000000000000000000000000000..ccdc42a0aeee171029441e59eda873d154046ae0 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/pkk.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f608cf8f1498bf65452f54652c755f7364dd0460ee6a793224b179c7426b2c6 +size 1030783 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/shop.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/shop.png new file mode 100644 index 0000000000000000000000000000000000000000..8a8a62b6362e04ae8650c7fcb7efc85d84c2deca --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/shop.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:863f7ecf093d3f7dac75399f84cf2ad78d466a829f8cd2b79d8b85d297fe806c +size 787662 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/there.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/there.png new file mode 100644 index 0000000000000000000000000000000000000000..b6cdefd446bcecfb197ab24860512de6c78fc380 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/there.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:597a5750d05402480df0ddc361585b842ceb692438e3cc69671a7a81a9199997 +size 2208079 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/vin.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/vin.png new file mode 100644 index 0000000000000000000000000000000000000000..cf45bf5859e3392daa77ee273adafb89ed7eb29a --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/vin.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:edfc46c5e1ad9b990a9d771cff268e309f5554360d9017d5153c166c73def86e +size 1173916 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/adam.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/adam.txt new file mode 100644 index 0000000000000000000000000000000000000000..62fc9cce248fa1a8152402c90cd5ea11e8df3571 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/adam.txt @@ -0,0 +1,3 @@ +-0.475719 0.28239 -808.118 +-0.0337691 -2.62434 -12.0121 +0.000234167 8.96558e-05 -2.96958 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cafe.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cafe.txt new file mode 100644 index 0000000000000000000000000000000000000000..a7cb081273eb3092eb5bfa927dc167b548c21f8b --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cafe.txt @@ -0,0 +1,3 @@ +-15.8525 20.632 -111.81 +-0.866106 -0.309191 -335.089 +0.00197221 0.00575211 -5.53643 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cat.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cat.txt new file mode 100644 index 0000000000000000000000000000000000000000..68070704b87b0e675be8130a9f9233f5206fd22a --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cat.txt @@ -0,0 +1,3 @@ +-39.2129 0.823102 5118.35 +1.39233 -0.138102 -1711.78 +0.00592927 0.00421801 -11.1641 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/dum.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/dum.txt new file mode 100644 index 0000000000000000000000000000000000000000..9fcc4dc9b2e9f94c6b340c1fe2132e0c841ab0d3 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/dum.txt @@ -0,0 +1,3 @@ +-0.778836 0.981896 -748.395 +3.80065 -6.24622 -695.732 +0.00344236 0.000197217 -4.13445 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/face.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/face.txt new file mode 100644 index 
0000000000000000000000000000000000000000..44fc3af247ae00c5d20fdd07b712844c4150fe95 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/face.txt @@ -0,0 +1,3 @@ +-0.260478 -0.124431 -3232.91 +-0.0968348 -2.12523 -254.238 +4.85011e-05 -0.000178626 -4.2015 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/fox.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/fox.txt new file mode 100644 index 0000000000000000000000000000000000000000..95663e958f1fc7e10ec61cbda63a506a7bf5188f --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/fox.txt @@ -0,0 +1,3 @@ +-1.39944 -4.86818 -1653 +3.74922 -40.7447 2368.02 +6.58821e-05 -0.016356 -6.58091 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/girl.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/girl.txt new file mode 100644 index 0000000000000000000000000000000000000000..b4e1dabd05273f50281a0d8c06332c90d0e6b5d5 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/girl.txt @@ -0,0 +1,3 @@ +1.80543 -28.3029 -3740.88 +1.55303 -18.1493 -474.282 +0.00279485 -0.0334053 -4.76287 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/graf.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/graf.txt new file mode 100644 index 0000000000000000000000000000000000000000..c088d86285144f45038af4018183d5fc61e332dc --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/graf.txt @@ -0,0 +1,3 @@ + 4.2714590e-01 -6.7181765e-01 4.5361534e+02 + 4.4106579e-01 1.0133230e+00 -4.6534569e+01 + 5.1887712e-04 -7.8853731e-05 1.0000000e+00 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/grand.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/grand.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8cad131c7f7c99e21ce61663588aa6a97c60553 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/grand.txt @@ -0,0 +1,3 @@ +-0.964008 0.742193 -686.412 +3.65869 -3.74358 -1880.34 +0.00239552 0.000705499 -3.33892 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/index.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/index.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a4f3aa9a93608b06eeadb1c1006f27a03908680 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/index.txt @@ -0,0 +1,3 @@ +0.420191 -1.04493 1337.46 +-2.95716 8.01143 1193.82 +-0.00241706 -0.000152038 5.08127 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/mag.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/mag.txt new file mode 100644 index 0000000000000000000000000000000000000000..f88872a57071cad157cf98018bb3cfdfb75adb7e --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/mag.txt @@ -0,0 +1,3 @@ +-0.296887 -16.5762 4405.55 +0.672497 0.906427 398.264 +-0.000777876 0.00319639 3.31196 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/pkk.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/pkk.txt new file mode 100644 index 0000000000000000000000000000000000000000..840ea53a262a6d312b32178b06a6a3e5932fca88 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/pkk.txt @@ -0,0 +1,3 @@ +-2.63017 -5.53715 -1101.86 +-2.06413 -10.0132 1466.65 +-0.00338896 -0.00999799 -2.34239 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/shop.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/shop.txt new file mode 100644 index 0000000000000000000000000000000000000000..2cb0af6b4b1973004cc5f3b2ef9d5221ccc0f52a --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/shop.txt @@ -0,0 +1,3 @@ +-0.0379327 -0.792592 -2778.33 +2.80747 -11.5896 919.535 +0.00281406 0.000153738 -8.54098 diff --git 
a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/there.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/there.txt new file mode 100644 index 0000000000000000000000000000000000000000..5c2031cce2da4817c9a8d122cd4391452d3f6c13 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/there.txt @@ -0,0 +1,3 @@ +0.314825 0.115834 690.506 + 0.175462 0.706365 14.4974 +0.000267118 0.000126909 1.0 diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/vin.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/vin.txt new file mode 100644 index 0000000000000000000000000000000000000000..ea4b033024b230a7aadff1bd8034cf1194b10122 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/vin.txt @@ -0,0 +1,3 @@ +4.90207 -19.4931 2521.64 +7.0392 -28.4826 5753.82 +0.00953653 -0.0350753 4.24836 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/README.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/README.txt new file mode 100644 index 0000000000000000000000000000000000000000..fccf33fdccf7d4ec7c55b62f82e556ec47b0ead4 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/README.txt @@ -0,0 +1,47 @@ +Welcome to WxBS version 1.1 -- Wide (multiple) Baseline Dataset. +It contains 34 very challenging image pairs with manually annotated ground truth correspondences. +The images are organized into several categories: +- WGALBS: with Geometric, Appearance and iLlumination changes +- WGBS: with Geometric (viewpoint) changes +- WLABS: with iLlumination and Appearance changes. The viewpoint change is present, but not significant +- WGSBS: with Geometric and Sensor (thermal camera vs visible) changes +- WGABS: with Geometric and Appearance changes. +Compared to the original dataset from 2015, v.1.1 contains more correspondences, which are also cleaned, and 3 additional image pairs: WGALBS/kyiv_dolltheater, WGALBS/kyiv_dolltheater2, WGBS/kn-church. +We also provide cross-validation errors for each of the GT correspondences. +They are estimated in the following way: +- the fundamental matrix F is estimated with the OpenCV 8pt algorithm (no RANSAC), using all points except one: +F, _ = cv2.findFundamentalMat(corrs_cur[:,:2], corrs_cur[:,2:], cv2.FM_8POINT) +Then the symmetrical epipolar distance is calculated on that held-out point. We have used the kornia implementation of the symmetrical epipolar distance: + +From Hartley and Zisserman, symmetric epipolar distance (11.10): +sed = (x'^T F x)**2 * (1 / ((Fx)_1**2 + (Fx)_2**2) + 1 / ((F^T x')_1**2 + (F^T x')_2**2)) + +https://kornia.readthedocs.io/en/latest/geometry.epipolar.html#kornia.geometry.epipolar.symmetrical_epipolar_distance + + +The labeling was done using [pixelstitch](https://pypi.org/project/pixelstitch/). + +There are two main intended ways of using the dataset. +a) The first is the evaluation of image matchers which estimate the fundamental matrix. One calculates the reprojection error on the GT correspondences and reports the mean error, or the percentage of GT correspondences which are in agreement with the estimated F. For more details see the paper [1]. + +b) For methods like [CoTR](https://arxiv.org/abs/2103.14167), which look for the correspondence in image 2 given a query point in image 1, one can directly calculate the error between the returned point and the GT correspondence. + + +*** +If you are using this dataset, please cite us: + +[1] WxBS: Wide Baseline Stereo Generalizations. D. Mishkin, M. Perdoch, J. Matas and K. Lenc. In Proc. BMVC, 2015. + +@InProceedings{Mishkin2015WXBS, + author = {{Mishkin}, D. and {Matas}, J. and {Perdoch}, M. and {Lenc}, K.
}, + booktitle = {Proceedings of the British Machine Vision Conference}, + publisher = {BMVA}, + title = "{WxBS: Wide Baseline Stereo Generalizations}", + year = 2015, + month = sep +} diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..0066f66d4bd858d6c5b47955fe4763a19eb60dd7 Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/.DS_Store differ diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/01.png new file mode 100644 index 0000000000000000000000000000000000000000..69dd5333685a7d65278630e65d7b07f02fc701d5 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:850e7ce7de02bf28d5d3301d1f5caf95c6c87df0f8c0c20c7af679a1e6a1dd8e +size 845171 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/02.png new file mode 100644 index 0000000000000000000000000000000000000000..350c4b1d0559fc2fb78734d4f782073b7e685c15 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b31024c1d52187acb022d24118bc7b37363801d7f8c6132027ecbb33259bc9e +size 615221 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..a52583361953fd2c94cb5e3477ce5fe806a56a08 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/corrs.txt @@ -0,0 +1,49 @@ +4.907274400000000014e+02 3.917342500000000172e+01 4.622376800000000117e+02 1.865405900000000017e+01 +4.199036600000000021e+02 1.615936800000000062e+02 3.848875699999999824e+02 1.613663899999999956e+02 +4.948897400000000175e+02 1.302602300000000071e+02 4.316474000000000046e+02 1.558661099999999919e+02 +5.940666800000000194e+02 1.390053499999999929e+02 5.624503499999999576e+02 1.447967299999999966e+02 +4.473480099999999879e+02 2.174977000000000089e+02 3.973540699999999788e+02 2.388423799999999915e+02 +5.604767500000000382e+02 2.211028900000000021e+02 5.063893199999999979e+02 2.504398300000000006e+02 +1.570469300000000032e+02 2.428232000000000141e+02 1.659867399999999975e+02 1.769992300000000114e+02 +1.283247700000000009e+02 2.466919000000000040e+02 1.308167500000000132e+02 1.806334600000000137e+02 +1.856345599999999934e+02 2.481083900000000142e+02 1.985409400000000062e+02 1.834373699999999872e+02 +3.945598299999999767e+02 3.011998500000000263e+02 3.165450599999999781e+02 3.465213200000000029e+02 +4.044084700000000225e+02 2.958942599999999743e+02 3.425922899999999913e+02 3.308400700000000256e+02 +4.486278899999999794e+02 4.079485399999999800e+02 3.311905699999999797e+02 4.730156900000000064e+02 +4.880407599999999775e+02 4.070106700000000046e+02 3.615463899999999740e+02 4.750086600000000203e+02 +8.308690199999999493e+01 2.656556800000000180e+02 6.541268700000000536e+01 2.314867700000000070e+02 +1.554153399999999863e+02 3.250390100000000189e+02 1.665729100000000074e+02 2.826928100000000086e+02 +4.981625900000000229e+02 2.744976300000000151e+02 4.228344999999999914e+02 3.175723100000000159e+02 +4.522427599999999757e+02 2.984249800000000050e+02 3.946770700000000147e+02 3.311724100000000135e+02 
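For clarity, here is a minimal sketch of the leave-one-out cross-validation error computation described in the WxBS README added above. It relies only on the two calls the README names (cv2.findFundamentalMat with cv2.FM_8POINT and kornia's symmetrical_epipolar_distance); the file names, variable names, loop structure, and the squared=False choice are illustrative assumptions, not necessarily how the shipped crossval_errors.txt files were produced.

```python
# Illustrative sketch only: leave-one-out cross-validation error for one image pair,
# following the procedure described in the WxBS README above.
# Assumed input: a corrs.txt with rows "x1 y1 x2 y2" (one GT correspondence per row).
import cv2
import numpy as np
import torch
from kornia.geometry.epipolar import symmetrical_epipolar_distance

corrs = np.loadtxt("corrs.txt")  # shape (N, 4)

errors = []
for i in range(len(corrs)):
    held_out = corrs[i]
    rest = np.delete(corrs, i, axis=0)
    # 8-point algorithm (no RANSAC) on all correspondences except the held-out one.
    F, _ = cv2.findFundamentalMat(rest[:, :2], rest[:, 2:], cv2.FM_8POINT)
    pts1 = torch.tensor(held_out[:2], dtype=torch.float32).view(1, 1, 2)
    pts2 = torch.tensor(held_out[2:], dtype=torch.float32).view(1, 1, 2)
    Fm = torch.tensor(F, dtype=torch.float32).view(1, 3, 3)
    # Symmetric epipolar distance of the held-out point w.r.t. the estimated F.
    err = symmetrical_epipolar_distance(pts1, pts2, Fm, squared=False)
    errors.append(err.item())

np.savetxt("crossval_errors.txt", np.asarray(errors))
```

Run per image-pair directory, a loop like this yields one error per row of corrs.txt, which is consistent with the crossval_errors.txt files in this diff having the same row counts as their corresponding corrs.txt files.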
+6.605721899999999778e+02 1.820804599999999880e+02 5.863873800000000074e+02 2.251841799999999978e+02 +4.103960099999999898e+02 3.502666500000000269e+02 3.240082199999999943e+02 4.019584899999999834e+02 +5.709502800000000207e+02 2.934323600000000170e+02 5.069851699999999823e+02 3.328039899999999989e+02 +1.218060499999999990e+02 2.819134199999999737e+02 1.248734800000000007e+02 2.297389900000000011e+02 +4.939615600000000200e+02 3.342423400000000129e+02 4.141000999999999976e+02 3.836088899999999740e+02 +3.687382882065904823e+02 3.495395049102415328e+02 2.896320927956163587e+02 4.006195633380069694e+02 +3.910059262681439236e+02 3.707051060659499058e+02 3.076935963192913732e+02 4.208602672369242441e+02 +3.975209120765460398e+02 3.629260185335294295e+02 3.133447415935700633e+02 4.132146001011354315e+02 +3.849771334305181085e+02 3.630232571276847011e+02 3.022640645851805061e+02 4.129929865609676654e+02 +5.620418645411094758e+02 3.036978250872198828e+02 4.978339578685253741e+02 3.451610927493376835e+02 +5.528727291807808797e+02 2.946790034213228182e+02 4.903915169107766019e+02 3.335120547285135331e+02 +5.608393549856565414e+02 2.664200288681786333e+02 5.007462173737313833e+02 3.035804987027848370e+02 +5.969146416492446861e+02 2.474805025813681141e+02 5.547848104147767572e+02 2.742961106073569795e+02 +6.084887961204792646e+02 2.737853991069012523e+02 5.619036669830582014e+02 3.040658744383520116e+02 +5.981171512046976204e+02 2.816017112173453825e+02 5.513871743253697559e+02 3.152295358749751131e+02 +4.966124691639038247e+02 2.254211022753156897e+02 4.251663048820389577e+02 2.657820691066586960e+02 +4.968032309070188717e+02 2.326700485136870213e+02 4.251663048820389577e+02 2.722625024200451662e+02 +5.078674120076908594e+02 2.578505986048716068e+02 4.311067020859765080e+02 3.022345064944572641e+02 +5.078674120076908594e+02 2.626196421827474978e+02 4.313767201407009111e+02 3.076348675889458946e+02 +4.861205732925768643e+02 2.631919274120925820e+02 4.100452938174706787e+02 3.065547953700481685e+02 +5.143533112736021167e+02 2.296178606238464681e+02 4.443375867674737378e+02 2.692923038180763342e+02 +5.048152241178503346e+02 2.403005182382884186e+02 4.316467381954253710e+02 2.809030801712269749e+02 +4.905080933842227182e+02 1.445381231945409013e+02 4.521681103544822804e+02 1.521044680676722862e+02 +5.090119824663810846e+02 1.451104084238859855e+02 4.699893019662949314e+02 1.521044680676722862e+02 +5.154978817322922851e+02 1.496886902586468295e+02 4.845702769214142904e+02 1.526445041771211493e+02 +4.811607679715859831e+02 2.319070015412268617e+02 4.127454743647149940e+02 2.682122315991786081e+02 +1.659310260943734363e+02 1.904096720602631763e+02 1.750965926567237716e+02 1.086081899286738235e+02 +1.514724884758092287e+02 1.897717954006206185e+02 1.568970252261333371e+02 1.083940773706668779e+02 +1.642300216686599867e+02 1.191801117335130584e+02 1.731695796346612610e+02 9.259963013450544622e+00 +9.197711414192471580e+01 2.817466503197928205e+02 7.713222472068414959e+01 2.509560380918086082e+02 +8.139120388717148558e+01 2.891666808721899997e+02 6.421399252360467358e+01 2.628805601198819772e+02 +6.942022126263745463e+01 2.819445178011901021e+02 5.052287463952045243e+01 2.516185115378127080e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..897f08d9475b367a75f4c0366d8dc21d69560ab4 --- /dev/null +++ 
b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/crossval_errors.txt @@ -0,0 +1,49 @@ +5.946305398998352754e-01 +6.151500492287013122e-01 +4.435131637366677704e-01 +8.607236886172928703e-01 +1.119487091546507093e-02 +1.842455765244689070e-01 +1.139326684695094460e+00 +1.741695769999793031e+00 +1.199652052219880360e+00 +1.070740677484232517e+00 +4.284088135921398921e-02 +2.060204916077549853e-01 +4.055000220707856706e-01 +6.973383799953524198e-01 +2.092896295374734983e+00 +1.402160800861952161e+00 +1.016634654724758224e+00 +1.492331175763036866e+00 +2.803261140423741193e-01 +1.549618715189659435e+00 +2.228131674171225374e+00 +1.047571352924588028e+00 +4.280705770200255778e+00 +1.020399628357264721e-01 +5.543407989816054871e-01 +4.932068148368541627e-02 +7.650761824905847330e-01 +5.087262963715805109e-01 +1.022079507471783000e+00 +8.715133918375695954e-01 +2.017157710400199200e+00 +2.755255421845337782e+00 +1.540080127238647512e+00 +1.175615709542446596e-01 +1.631398780409486493e-01 +5.850237557411903655e-01 +8.175730992894985061e-01 +1.138357037213331369e+00 +8.535338399374122753e-01 +2.322429340540620224e+00 +3.717534098826499322e-01 +1.258508672327652844e+00 +9.834160308617387880e-01 +1.191211366876187006e+00 +3.598025256102476699e-01 +1.039241620045164405e+00 +2.476248683496967473e+00 +2.857957423093941074e+00 +3.709824113185862249e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/01.png new file mode 100644 index 0000000000000000000000000000000000000000..102b10b6b368dc04440b2bf5da2a12be205b00e9 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e79c3270fd5365f355307e7e944f05dfdc43691f87b5aff4506607db42c2a56 +size 616896 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/02.png new file mode 100644 index 0000000000000000000000000000000000000000..a4f26819f208e9693f6dd59ec478fe1d10143837 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d7e20317377179b99d143d6197f0647d5882edfba14c326f354fe33ee912362 +size 523819 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..1717b77ae6409871a49fb834330096a9d4a1e7e8 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/corrs.txt @@ -0,0 +1,31 @@ +8.218366899999999475e+01 1.530866499999999917e+02 2.595849000000000046e+02 2.658549699999999802e+02 +2.744326800000000048e+02 1.119666299999999950e+02 3.892155200000000264e+02 2.346591600000000142e+02 +2.210166600000000017e+02 1.117554300000000040e+02 3.450251400000000217e+02 2.343074599999999919e+02 +2.451667199999999980e+02 1.114037299999999959e+02 3.629618300000000204e+02 2.341902200000000107e+02 +3.003384300000000167e+02 1.625225299999999891e+02 4.159758299999999736e+02 2.712874699999999848e+02 +2.449410799999999995e+02 1.619131099999999890e+02 3.759574400000000196e+02 2.722447000000000230e+02 +2.866947200000000180e+02 3.839937800000000152e+02 5.190999699999999848e+02 4.356803499999999758e+02 +3.922720800000000168e+02 1.174523999999999972e+02 4.723207800000000134e+02 2.384106199999999944e+02 +3.775006799999999885e+02 1.309342300000000137e+02 4.621214800000000196e+02 
2.484926900000000103e+02 +3.396873899999999935e+02 1.175705700000000036e+02 4.378811600000000226e+02 2.388010600000000068e+02 +5.545821399999999812e+01 1.701533900000000017e+02 2.369180700000000002e+02 2.824644000000000119e+02 +9.402142899999999770e+01 1.694848800000000040e+02 2.677825799999999958e+02 2.776578400000000215e+02 +1.542611699999999928e+02 1.653593899999999906e+02 3.074880099999999743e+02 2.772470000000000141e+02 +2.136086400000000083e+02 1.141107899999999944e+02 3.302527099999999791e+02 2.362122400000000084e+02 +2.361495099999999923e+02 3.121274900000000230e+02 4.639961799999999812e+02 3.821199599999999919e+02 +2.772462800000000129e+02 2.915795699999999897e+02 4.827351699999999823e+02 3.661461400000000026e+02 +8.007346900000000289e+01 1.227232300000000009e+02 2.575919299999999907e+02 2.435806399999999883e+02 +3.099508300000000105e+01 9.362195800000000645e+01 2.155214400000000126e+02 2.217527800000000013e+02 +2.801296500000000265e+02 3.661743200000000229e+02 5.140589400000000069e+02 4.218468199999999797e+02 +2.433183999999999969e+02 3.944275499999999965e+02 5.046802700000000073e+02 4.427143500000000245e+02 +6.178881200000000007e+01 1.362729699999999866e+02 2.441865300000000047e+02 2.544456500000000005e+02 +2.756259297043956735e+02 3.740927471033548954e+02 5.114641728026981013e+02 4.273074362362154943e+02 +8.221162108383408906e+01 4.056540777987604542e+02 4.373557405352174214e+02 4.522597703330104650e+02 +7.411897218757627570e+01 4.113189320261409421e+02 4.341119371026341014e+02 4.560026204475296936e+02 +1.251026602340004672e+02 3.761159093274194447e+02 4.495823842426468673e+02 4.300521929868630195e+02 +1.360277362439485103e+02 3.684278928759745213e+02 4.538242810391020612e+02 4.243131561446001569e+02 +3.915916690636100839e+02 1.310116728485639044e+02 4.712393946580673401e+02 2.479437873807888479e+02 +3.764727575077063193e+02 1.158927612926601398e+02 4.608862906124819574e+02 2.377897814899262414e+02 +3.920793758879941038e+02 1.617372027847554818e+02 4.792033208469791816e+02 2.708400751739104635e+02 +4.655489607273053707e+02 1.858487678854069998e+02 5.446299327018380154e+02 2.889468305880199068e+02 +4.021047738999301373e+02 2.268103059826929666e+02 5.263277510484970207e+02 3.166409212476804669e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..b8e8634a7e03b482b4790715a0f21e6220656125 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/crossval_errors.txt @@ -0,0 +1,31 @@ +1.065223937182152358e+00 +5.736516321493103643e-01 +4.019874183288125180e-01 +6.051618580623756294e-02 +1.331781109861698509e+00 +4.707776137115839421e-01 +1.562036957800347237e+00 +5.946548909007817185e-01 +7.176612972602752771e-01 +1.036625416881644890e+00 +4.840386519356732364e+00 +3.143029902439770762e+00 +3.486454358398270337e+00 +1.060173865479107747e+00 +5.061210528868288067e-02 +1.198614553697012336e-01 +1.133641006763047798e+00 +1.628127580495658533e+00 +1.076465803298405577e+00 +6.493052438943768268e-01 +1.691393384139549205e+00 +6.172534509656668611e-01 +5.961055533880751378e-01 +2.768433902281529857e+00 +1.021596472702978486e+00 +1.800412927080836445e+00 +3.333198235820573618e-01 +1.588410312156819382e+00 +2.664948685440807763e-01 +3.174975276305068483e+00 +2.025259979974938673e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/01.png 
b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/01.png new file mode 100644 index 0000000000000000000000000000000000000000..548bcc64b3b7d0b5f025e10a38dc1167ee0a1033 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e263b6c695097b44806b17f0b76cc0ae0f92fe3909293cc78755e3caecf5d0b +size 1452608 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/02.png new file mode 100644 index 0000000000000000000000000000000000000000..26992cc042c17e10c1c3bdc07fa862327e734978 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:920a6f7a68df8766dd8a4ed3f733d67b6432770842fb606467a6f58276afdea1 +size 1368978 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..c220d9b509c047ca449a4a4b6f3121b06af474f3 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/corrs.txt @@ -0,0 +1,38 @@ +5.939102000000000317e+02 2.418589599999999962e+02 4.982037500000000136e+02 1.194755399999999952e+02 +6.916765400000000454e+02 2.234832499999999982e+02 6.342270399999999881e+02 1.034648899999999969e+02 +4.017696900000000255e+02 4.162991499999999974e+02 4.903268100000000231e+02 3.427198700000000144e+02 +3.992128599999999778e+02 5.331783199999999852e+02 4.892717099999999846e+02 5.246871999999999616e+02 +3.671711700000000178e+02 5.252994800000000168e+02 4.202384299999999939e+02 5.129918800000000374e+02 +2.765707600000000070e+02 4.378662699999999859e+02 1.672225400000000093e+02 3.733789699999999812e+02 +5.865748200000000452e+02 3.992576099999999997e+02 5.190209599999999455e+02 3.079799199999999928e+02 +5.466300899999999956e+02 3.986249399999999810e+02 4.676822900000000232e+02 3.068783000000000243e+02 +6.762335600000000113e+02 3.514155400000000213e+02 6.263054899999999634e+02 2.511592899999999986e+02 +3.721860899999999788e+02 2.715075699999999870e+02 2.301838999999999942e+02 1.436160800000000108e+02 +4.851462399999999775e+02 3.550274699999999939e+02 3.759133499999999799e+02 2.528731799999999907e+02 +6.853274400000000242e+02 2.796412799999999947e+02 6.251331599999999753e+02 1.695877499999999998e+02 +3.793781299999999987e+02 4.288431100000000242e+02 4.532810900000000061e+02 3.627667700000000082e+02 +4.257075899999999820e+02 5.276683500000000322e+02 5.130700699999999870e+02 5.100330400000000282e+02 +3.989784000000000219e+02 5.130141899999999850e+02 4.866925800000000208e+02 4.951444099999999935e+02 +3.215433800000000133e+02 3.185416200000000231e+02 1.649607300000000123e+02 2.048264900000000068e+02 +8.231494699999999511e+02 3.366846400000000017e+02 8.477356399999999894e+02 2.366834699999999998e+02 +8.274871000000000549e+02 2.987010599999999840e+02 8.561764399999999569e+02 1.924865200000000129e+02 +5.764467172506857651e+02 2.733143446995878207e+02 4.803264665374170477e+02 1.560546913685867025e+02 +5.413348219803513075e+02 2.441738601895142722e+02 4.316845558332415749e+02 1.199110493870118717e+02 +5.738192965161708798e+02 2.494287016585439289e+02 4.762729739787357630e+02 1.276802434578176815e+02 +4.947578180503157341e+02 2.909897205499603388e+02 3.837182272221796211e+02 1.756465720688796068e+02 +5.071783524316585385e+02 3.748283276240243822e+02 4.097281378070512119e+02 
2.790106323152525647e+02 +5.723861579337083185e+02 2.897954383979081285e+02 5.043096308429479677e+02 1.749709899757660594e+02 +6.676898736674735346e+02 2.924228591324229001e+02 6.120649746945589413e+02 1.820646019534583218e+02 +3.377939878780084655e+02 3.703309743465495671e+02 1.918655051113914283e+02 2.736075751686765898e+02 +3.210982358910493417e+02 3.827509849709947503e+02 1.703373326379488049e+02 2.901381361750700307e+02 +2.563512952586955862e+02 3.888591869174432532e+02 7.230726155352260776e+01 2.982112008526110003e+02 +3.756651435405149186e+02 4.833721106941258654e+02 4.502639317766756903e+02 4.503528604518788825e+02 +3.910914280535924945e+02 5.100965190759362713e+02 4.729039071343397609e+02 4.902707117403917891e+02 +3.956541319236577010e+02 5.083583461730543149e+02 4.776702177359532584e+02 4.843128234883749315e+02 +3.895705267635707969e+02 4.492604674750670029e+02 4.737975903721422810e+02 3.952423941207230200e+02 +3.889951219559405331e+02 4.280560699476575905e+02 4.665334998202267798e+02 3.603993349145116554e+02 +2.811901416325709420e+02 4.424291250164491203e+02 1.737224299407419039e+02 3.796104651314472562e+02 +2.423185076331632786e+02 4.371700215929998308e+02 1.180053812948370933e+02 3.731101427894250264e+02 +2.720438748091809202e+02 4.632368820396614524e+02 1.656744118030001118e+02 4.102548418866948623e+02 +2.715865614680114390e+02 4.502034518163306984e+02 1.638171768481366257e+02 3.904443357014843059e+02 +2.503214911036295973e+02 4.671240454396022415e+02 1.325537217746011436e+02 4.173742425470049398e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..3ce5479cdf0f30db0390c01acb5ac9e487c86724 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/crossval_errors.txt @@ -0,0 +1,38 @@ +1.792424591998240890e+00 +3.095348507057420928e+00 +4.094406044190157523e-01 +7.302793957725348672e-01 +4.238650623157072528e-01 +4.165253741566223411e+00 +1.883763528268577736e-01 +8.053748309044221898e-01 +1.595531486099816265e-01 +1.697912170300142742e+00 +8.506116392505531643e-01 +3.122924043029775554e+00 +7.353410671850451052e-01 +8.858066508520040516e-01 +2.139081656124904196e+00 +3.931447770611433690e+00 +3.191792767739562997e+00 +3.623124474113926574e+00 +1.426124920876860269e+00 +1.237541469506047598e+00 +8.962695327020072655e-01 +1.530807091557143451e+00 +1.467530857533184907e+00 +1.121653280350474846e+00 +9.875670003547415421e-01 +2.657680183199366830e+00 +2.536812969979429511e+00 +2.003125655478974920e+00 +1.223910714981061520e+00 +9.496930900765404582e-01 +2.299513490038819441e+00 +4.825846860592630100e-03 +1.709008522450382817e+00 +4.024040354942465036e+00 +2.351401261463401671e+00 +2.432080444353363458e+00 +1.727733901591856469e+00 +1.847889070782418930e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/01.png new file mode 100644 index 0000000000000000000000000000000000000000..e9f3507f029d45e4861a98c323f9f3b1b95ba255 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5ef0c6fb8618c1cf867df09c6f6bcfef0e9c85b3e8164ecd2a865da34bf3027 +size 553707 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/02.png new file mode 
100644 index 0000000000000000000000000000000000000000..d7145c3d436cb1842459cf55db663758c9e88c73 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ec6c1ba02df41d23ae4761ed875969c80e154dcd303e69640ba9f6de4fff7ba +size 608585 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..769ed6a898959ff816ae096a10846687fd933f85 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/corrs.txt @@ -0,0 +1,28 @@ +4.097915199999999913e+02 1.833357399999999870e+02 6.136951299999999492e+02 2.662372500000000173e+02 +3.972475499999999897e+02 1.823978700000000117e+02 5.958756700000000137e+02 2.663544800000000237e+02 +2.113397699999999872e+02 2.108091400000000135e+02 3.278392200000000116e+02 2.986441300000000183e+02 +3.930648499999999785e+02 2.325574699999999950e+02 5.847751600000000280e+02 3.389658099999999763e+02 +4.474652500000000259e+02 2.239596699999999885e+02 6.665026699999999664e+02 3.319995900000000120e+02 +2.731753100000000245e+02 2.438032200000000103e+02 4.108272499999999923e+02 3.488628499999999804e+02 +5.092126000000000090e+02 2.347558500000000095e+02 7.640383399999999483e+02 3.494393400000000156e+02 +1.692692700000000059e+02 2.209019199999999898e+02 1.972320100000000025e+02 3.109740199999999959e+02 +4.040470799999999940e+02 2.020930700000000115e+02 6.080679300000000467e+02 2.948421700000000101e+02 +1.872901500000000112e+02 1.894243800000000135e+02 3.020479000000000269e+02 2.720321700000000078e+02 +2.597743300000000204e+02 2.072379799999999932e+02 3.954828400000000101e+02 2.965339299999999980e+02 +2.937157100000000014e+02 2.189660000000000082e+02 4.434426500000000146e+02 3.158611599999999839e+02 +3.841037499999999909e+02 2.189624900000000025e+02 5.730518299999999954e+02 3.188016799999999762e+02 +4.726703999999999724e+02 2.223183999999999969e+02 7.063619899999999916e+02 3.270757899999999836e+02 +2.800920800000000099e+02 2.620916100000000029e+02 4.103583199999999920e+02 3.753575799999999845e+02 +4.936205800000000181e+02 2.189293600000000026e+02 7.385987199999999575e+02 3.250548099999999749e+02 +2.900925279496414078e+02 2.417676896878288915e+02 4.357913877379355085e+02 3.470031760735771513e+02 +3.921982208998750252e+02 2.025229845743038197e+02 5.874506446412149216e+02 2.962461188747516303e+02 +3.710130261040782216e+02 2.011337914729400893e+02 5.575755385475616777e+02 2.931013708648933971e+02 +4.272753467093090194e+02 2.219716879933959319e+02 6.358011452927852361e+02 3.269074119708694184e+02 +3.553846037137363396e+02 2.283047722814643521e+02 4.919289238417711658e+02 3.302140225398936764e+02 +3.828211674656698733e+02 2.285907842234612701e+02 5.292728064588377492e+02 3.313933030435904925e+02 +1.417557024889432000e+02 2.020893680292453212e+02 2.413374582723404842e+02 2.847164876775290736e+02 +1.349735569233231445e+02 1.758650718421810666e+02 1.759893184350587489e+02 2.526156470557064608e+02 +4.226454386301443265e+02 2.434042994222771199e+02 6.254660433374846207e+02 3.575691427733916044e+02 +4.158557961245051047e+02 2.434042994222771199e+02 6.163289032236831417e+02 3.587112852876167608e+02 +4.084757499227233097e+02 2.391499171974719218e+02 6.052881922528397354e+02 3.505595196120268611e+02 +3.981436852402288196e+02 2.373787061090442876e+02 5.900596253965038613e+02 3.471330920693512780e+02 diff --git 
a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..03e080f6eefd0d10531cfc16ecfaa0d16bb932a3 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/crossval_errors.txt @@ -0,0 +1,28 @@ +2.795827317827273895e+00 +7.504814977516608421e-01 +2.358292759495422164e+00 +1.292287445027972881e+00 +2.551228559052934219e+00 +1.430468433774237103e-01 +1.822288632049907875e+00 +2.256415147302962332e+00 +4.115527888980367588e-01 +2.543433883252152139e+00 +9.771310714237624317e-01 +1.339901065391976509e+00 +5.753504995346042650e-01 +2.368484738426400060e+00 +8.237638325866948330e-01 +4.048545226987346202e-01 +2.386360497959467142e-01 +1.256918929674968322e+00 +7.735100630893360085e-01 +1.320422016321886716e+00 +7.705570551665671397e-01 +2.159805887096448274e+00 +2.559822007875808048e+00 +4.642055664900872181e+00 +9.192902876777044874e-01 +1.471532469196588755e+00 +2.965823050434360231e-01 +4.094884609445577084e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/01.png new file mode 100644 index 0000000000000000000000000000000000000000..88f10f54a598c0677cdd8c83f6cf1c982153de44 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc4ee14c70d6fb6bc44a0be7507f90b62c67773d93afc75a4f586d3cd01f6108 +size 917944 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/02.png new file mode 100644 index 0000000000000000000000000000000000000000..f9339231ab07fa31f4bb302d50e8ddc68ad7c591 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ffbab080ab2c02eb02c03e7128697c49b9c6cd13f8c99d72263379ca9c5ff83 +size 565060 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..4ada597fe23f7f3ffb92f7874c0f6399f709efca --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/corrs.txt @@ -0,0 +1,44 @@ +3.259731699999999819e+02 3.123227999999999724e+02 1.918973599999999919e+02 3.167674200000000155e+02 +4.166598599999999806e+02 3.097243000000000279e+02 2.435228899999999896e+02 3.129422900000000141e+02 +4.208318499999999744e+02 3.786820200000000227e+02 2.431711899999999957e+02 3.723449800000000209e+02 +3.264614700000000198e+02 3.788186200000000099e+02 1.913160300000000120e+02 3.726870000000000118e+02 +3.147381399999999871e+02 3.924176800000000185e+02 1.795927000000000078e+02 3.835896999999999935e+02 +4.289209500000000048e+02 3.902881199999999922e+02 2.436401199999999960e+02 3.814891799999999762e+02 +3.367984099999999899e+02 2.758902199999999993e+02 2.027032199999999875e+02 2.864261300000000006e+02 +4.080620700000000056e+02 2.735939700000000130e+02 2.440993699999999933e+02 2.813162800000000061e+02 +3.733191699999999855e+02 1.544756099999999890e+02 2.283057300000000112e+02 1.825394499999999880e+02 +3.746969199999999773e+02 1.133440200000000004e+02 2.360974199999999996e+02 1.435132199999999898e+02 +3.700463199999999802e+02 2.459151500000000112e+02 2.259659100000000080e+02 2.595727899999999977e+02 
+3.808317799999999806e+02 2.440394200000000069e+02 2.326482000000000028e+02 2.576970600000000218e+02 +4.701611300000000142e+02 4.732083600000000274e+02 2.428049599999999941e+02 4.495536500000000046e+02 +5.705985799999999699e+02 5.078243400000000065e+02 1.030767199999999946e+02 4.655914399999999773e+02 +3.791130499999999870e+02 4.844928400000000011e+02 7.086453899999999351e+01 4.475752100000000269e+02 +2.952947100000000091e+02 4.967646399999999858e+02 1.516005400000000058e+01 4.524554299999999785e+02 +3.762500000000000000e+02 3.383506600000000049e+02 2.201671800000000019e+02 3.373724599999999896e+02 +4.003322800000000257e+02 4.686663500000000226e+02 1.770076899999999966e+02 4.441851300000000151e+02 +4.388674500000000194e+02 4.084516699999999787e+02 3.164035999999999831e+02 3.992031600000000253e+02 +4.069875999999999863e+02 2.193011199999999974e+02 2.671130600000000186e+02 2.328415200000000027e+02 +3.385181699999999978e+02 1.586970499999999902e+02 2.117699599999999975e+02 1.866726999999999919e+02 +3.301161099999999919e+02 2.897237499999999955e+02 1.969587899999999934e+02 2.989700900000000274e+02 +3.902007600000000025e+02 1.570547500000000127e+02 2.436632899999999893e+02 1.806637200000000121e+02 +3.722350200000000200e+02 1.264741499999999945e+02 2.335182900000000075e+02 1.539469799999999964e+02 +3.896441404448885351e+02 1.684838064643168423e+02 2.422721730666990538e+02 1.898239388128006908e+02 +4.044041567098074665e+02 1.982381249666137535e+02 2.636038469030721672e+02 2.123862861397337838e+02 +3.938612879491511194e+02 1.373237721272657268e+02 2.508868875006189967e+02 1.619286730267743053e+02 +3.727755504278383114e+02 1.384952019895608828e+02 2.318114483969391699e+02 1.654155812500275999e+02 +3.610612518048868083e+02 1.661409467397265303e+02 2.197098257397659609e+02 1.910546123033606705e+02 +3.111402209243742618e+02 4.100159911502355499e+02 1.790477254239340255e+02 3.982068908671177496e+02 +4.722420884743182796e+02 4.605639990494716471e+02 3.538258560724232211e+02 4.464215475977354117e+02 +4.315213324560577917e+02 4.083185007618920963e+02 2.414684863698230401e+02 3.977764028605943167e+02 +4.345945970612095266e+02 3.925680196604894832e+02 3.090551033939925105e+02 3.827093226322763257e+02 +7.209615812075318217e+02 5.667660675231618370e+02 7.917544657994415047e+01 5.005743734609804392e+02 +6.606642064527560478e+02 5.493360788255572515e+02 5.787154836054548923e+01 4.867895001037628049e+02 +6.609687386484872604e+02 5.669989461779662179e+02 5.745990518108130374e+01 4.981376862110609522e+02 +7.197434524246070850e+02 5.520768685871379375e+02 7.959145443812062126e+01 4.914611294519540934e+02 +2.786691872047798029e+02 4.833365024492061366e+02 3.202105196908468088e+01 4.448035153264768269e+02 +5.282652496148941736e+02 5.076146177348316542e+02 8.452211168355722748e+01 4.630119753430569176e+02 +3.647074203222595088e+02 5.080405495819479711e+02 2.595156529689131730e+01 4.593702833397409222e+02 +3.460325768818356664e+02 3.378550919150705454e+02 2.019291652358627971e+02 3.379567249800489890e+02 +3.951785592414013308e+02 3.276982555607603445e+02 2.298960110465138200e+02 3.285692522603899306e+02 +3.938679997118128995e+02 3.548923657997199825e+02 2.287225769565564519e+02 3.514512170145588925e+02 +3.506195352353951193e+02 3.240942168543921866e+02 2.048627504607562742e+02 3.270046734737800307e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/crossval_errors.txt new file mode 100644 index 
0000000000000000000000000000000000000000..73e5fbcaa5d23546a11a3efa73431eb8c9fcec40 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/crossval_errors.txt @@ -0,0 +1,44 @@ +1.616226771291696451e+00 +1.754369060755345267e-02 +9.779284929175396934e-01 +6.576697555919837068e-01 +3.066270550408434770e-01 +1.612998834698591566e-01 +4.201040675301740968e-01 +8.155385261083895054e-01 +3.452953469835793321e+00 +8.020687774441207507e-01 +3.442170255501640352e-01 +8.943909552976530009e-01 +2.589561677756893943e+00 +2.425722120341224919e+00 +7.024860799446467352e-01 +4.603193840964353578e-01 +1.342611948197300231e+00 +1.205190859321352503e+00 +2.079690518854804271e+00 +2.371458487568577578e+00 +8.220301036726821442e-01 +6.511991988937511078e-01 +6.705641250402661901e-01 +1.972532867873080908e+00 +1.110895732960297089e+00 +1.907553063517399838e+00 +6.222314791948166945e-01 +8.391836101199339204e-01 +1.321392328191198562e+00 +1.251759274178413817e+00 +8.858787342394179865e-01 +1.676465379553625290e+00 +1.921688964633968766e+00 +2.509639115201356852e+00 +1.290504458418172407e+00 +8.185896321589704039e-01 +4.325984259078770044e+00 +9.623694333738862516e-01 +4.290098346935177220e-01 +1.840930682367824200e+00 +7.234322594695873354e-01 +5.028036646473726945e-01 +6.933085218108160364e-01 +2.703262904684516910e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..abdf1339cdec7cf4dbd289a9c5df8eb4a976cc4b Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/.DS_Store differ diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/01.png new file mode 100644 index 0000000000000000000000000000000000000000..0b4dea47bd01e2c4104b4f6caf1244aed61877f7 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc091548ebd8d946965f745cd3688ba6bc4494dec3c1098ab504e5187284e393 +size 865043 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/02.png new file mode 100644 index 0000000000000000000000000000000000000000..1c94dc870ee672c88df769e6ec1ea0e313202499 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e56dea63ce9ecbf803681654d917245299884c81f46a865efdfd57615c1efea +size 581008 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..9a58a592c33b046940d869b241be0819a7499221 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/corrs.txt @@ -0,0 +1,29 @@ +2.538888800000000003e+02 4.090991300000000024e+02 9.853851099999999974e+01 5.648856100000000424e+02 +4.446807000000000016e+02 3.349336900000000128e+02 2.318818599999999890e+02 5.031479400000000055e+02 +4.724476900000000228e+02 4.192500200000000063e+02 2.566540499999999838e+02 5.702202499999999645e+02 +5.367548199999999952e+02 1.505413000000000068e+02 2.976936600000000226e+02 3.666731800000000021e+02 +6.043347899999999981e+02 2.343101900000000057e+02 3.508577399999999784e+02 4.230219299999999976e+02 +6.183655300000000210e+01 3.329600899999999797e+02 
3.854480900000000076e+01 4.845407000000000153e+02 +4.316377100000000269e+02 3.135661099999999806e+02 1.981795400000000029e+02 4.685077499999999873e+02 +3.278198600000000056e+02 3.057405299999999784e+02 1.495681899999999871e+02 4.782585100000000011e+02 +5.888589600000000246e+02 3.169766000000000190e+02 3.457091599999999971e+02 4.851171800000000189e+02 +1.975148900000000083e+02 2.537826699999999960e+02 8.469426300000000651e+01 4.480296200000000226e+02 +4.739104947356379398e+02 3.122058046065272379e+02 2.327137685012712893e+02 4.752906051208748295e+02 +2.660098803478572336e+02 3.302083717741842861e+02 1.072749444745419396e+02 5.030318835114017020e+02 +2.654291523747069732e+02 3.964113607133099890e+02 1.093856939172994203e+02 5.476591574439880787e+02 +4.345871269560143446e+02 2.563627422044701802e+02 1.969468277504925879e+02 4.412749581862206014e+02 +4.596464188883149973e+02 2.459448343225024871e+02 2.062400571471764863e+02 4.357385662052174666e+02 +4.576754633430778654e+02 2.848008150714630347e+02 2.068332420022839813e+02 4.513591007230478453e+02 +4.796375394185772620e+02 2.369347518299899207e+02 2.117764491281796779e+02 4.303999025092500688e+02 +4.773850187954491275e+02 2.724119516442582380e+02 2.119741774132155001e+02 4.460204370270804475e+02 +5.055415265845509225e+02 2.225749328575479637e+02 2.179060259642903361e+02 4.260498802384619239e+02 +5.122990884539353829e+02 2.329928407395156569e+02 2.226515048051501822e+02 4.260498802384619239e+02 +4.408767197835879870e+02 4.253480886759214172e+02 2.329805935534139110e+02 5.728758021914462688e+02 +4.960907211153821663e+02 4.011672450316198706e+02 2.494870383241641889e+02 5.275034761405271411e+02 +1.316883174179574780e+02 3.062138219672158357e+02 6.982777892872408643e+01 4.722154469113214645e+02 +4.400374437744409875e+02 2.934501406239173775e+02 1.996894706097435801e+02 4.580269697371936672e+02 +2.510945103708573924e+02 3.324661813641471326e+02 9.511189071275245510e+01 5.078397482162217216e+02 +3.239784450498666502e+02 3.717378655408519421e+02 1.480045826113571650e+02 5.162190666884113170e+02 +4.210246903248170156e+02 1.588813566390604421e+02 1.849670454514543110e+02 4.016497702999811850e+02 +4.220810507607419026e+02 1.702372313252526226e+02 1.853919095185078447e+02 4.061816536818855639e+02 +6.723896457485888050e+02 2.257854111677311266e+02 2.913295806495771103e+02 4.216122987886223541e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..1004f9035428c12151d0b31185f15df88ad36642 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/crossval_errors.txt @@ -0,0 +1,29 @@ +3.031160677158884820e+00 +7.311932777312657450e-01 +3.773723759246768505e+00 +5.271593918761533715e+00 +7.408896420633541702e-01 +1.344973370763895837e-01 +7.609603320007013449e-01 +7.372587473637195465e-01 +3.574722318348853456e+00 +3.117516676048450286e+00 +1.765242167196014789e+00 +2.846135323394646033e+00 +2.075394853079954416e+00 +6.644829410992774132e-01 +4.763596528023891774e-01 +3.482295325426788768e+00 +2.091971221017461691e+00 +4.151898915801082723e-01 +3.310514074093689363e+00 +8.052084405660506761e+00 +8.410137898663668787e-01 +2.015056363810584728e+00 +5.752254255109191305e+00 +1.772150454723667945e+00 +3.570929124981857772e+00 +2.216294034905811061e-01 +5.777857679990370698e+00 +4.744082439652409278e+00 +5.579018075674269106e+00 diff --git 
a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/01.png new file mode 100644 index 0000000000000000000000000000000000000000..9c9bfbab79ef181a08f44691db648d3b462ad6b8 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38b11253931290a4bcd08b36042f785dbbceb74dbc2ea3c634ba733fd96e3a9d +size 621493 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/02.png new file mode 100644 index 0000000000000000000000000000000000000000..f8c500a2985fe355d566b920e99489140c42f97b --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31beee2ca02cff5dcf9ff514bc9f6b1a40d999b72f9990e33caa79f98f07d784 +size 727477 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..fb1a3e32f2abe80f7230b4eb6c0d52a17b241aa8 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/corrs.txt @@ -0,0 +1,30 @@ +4.985831099999999765e+02 1.960997399999999971e+02 2.980889399999999796e+02 1.978388699999999858e+02 +5.990689800000000105e+02 1.914491399999999999e+02 4.239487300000000118e+02 2.013558700000000101e+02 +4.162984799999999836e+02 2.040328700000000026e+02 6.599885299999999688e+01 1.785621299999999962e+02 +5.601637799999999743e+02 2.146042700000000139e+02 3.768953000000000202e+02 2.268169399999999882e+02 +5.480296200000000226e+02 1.210943000000000040e+02 3.666562400000000252e+02 1.103174800000000033e+02 +6.592826199999999517e+02 2.331184900000000084e+02 4.989541699999999764e+02 2.562348799999999756e+02 +6.768396000000000186e+02 2.438759400000000142e+02 5.178985900000000129e+02 2.699328499999999735e+02 +7.861383700000000090e+02 2.580815699999999993e+02 6.324438099999999849e+02 2.911928800000000024e+02 +5.068776199999999790e+02 2.486631399999999985e+02 3.205019399999999905e+02 2.688583800000000110e+02 +5.223340899999999465e+02 1.412895499999999913e+02 3.293438899999999876e+02 1.326423000000000059e+02 +6.823205199999999877e+02 2.695918699999999717e+02 5.282742500000000518e+02 3.020578800000000115e+02 +4.333758100000000013e+02 2.568404199999999946e+02 2.319244199999999978e+02 2.734003999999999905e+02 +4.219256799999999998e+02 2.166940700000000106e+02 7.549474999999999625e+01 1.996641199999999969e+02 +5.458613199999999779e+02 2.142525699999999915e+02 3.584896800000000212e+02 2.266997000000000071e+02 +5.531297799999999825e+02 1.849442500000000109e+02 3.698613100000000031e+02 1.901229199999999935e+02 +6.479109899999999698e+02 2.222157899999999984e+02 4.880514800000000264e+02 2.449804900000000032e+02 +7.624572500000000446e+02 2.641777000000000157e+02 6.088799199999999701e+02 2.968200800000000186e+02 +3.819970700000000079e+02 2.591869899999999802e+02 1.369325900000000047e+02 2.699572099999999750e+02 +5.965183347360991775e+02 2.029132062807596526e+02 4.219686410382408894e+02 2.159518907465023574e+02 +6.023373115966885507e+02 2.511254858522926554e+02 4.348509607094081275e+02 2.762665105431891561e+02 +6.259073538232618148e+02 1.517615823481111761e+02 4.452224995348522611e+02 1.489270616307932187e+02 +6.724091600182804314e+02 2.338288746764192467e+02 5.136538292206985261e+02 2.581492043719323419e+02 +6.047190589338688369e+02 
1.805424661357062064e+02 4.299513950707617482e+02 1.875140898765805559e+02 +6.379103839944622223e+02 2.238172317210368192e+02 4.564564387357852979e+02 2.401400461390186081e+02 +3.975935259747074042e+02 2.453914551063811587e+02 1.686753387917444229e+02 2.514382534722063269e+02 +3.294271181369521173e+02 2.496277855369761198e+02 3.153879244165455020e+01 2.550952280415420432e+02 +3.764118738217325699e+02 2.519385112263915687e+02 1.260106354828275244e+02 2.593616983724337501e+02 +4.084209182343442990e+02 2.160644250742822692e+02 5.376696716789734865e+01 1.973480266865232977e+02 +4.087249287016336439e+02 2.210805977845561756e+02 5.349111045331716241e+01 2.047961579801880703e+02 +4.195173002904047053e+02 2.093761947939171364e+02 7.169765361560885708e+01 1.882447551053774646e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..6402176f76ec6c4f2f202b5b6e28e2ef8798c882 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/crossval_errors.txt @@ -0,0 +1,30 @@ +9.966072664324915342e-01 +1.544877954599090109e-01 +1.216055812894111021e+00 +1.471117641364499684e-01 +2.490423304593254894e+00 +1.247993306307343175e+00 +1.160910691169436904e+00 +1.703539622639616002e+00 +3.285086051858440825e+00 +2.630431498484563591e+00 +1.403033950025595278e+00 +9.816643371412056007e-01 +9.521494556258358957e-03 +2.043210232592532982e+00 +4.043898815586809969e-01 +1.625646765139507899e+00 +2.487338704638109910e+00 +1.876527263518038113e+00 +7.103605442748321952e-01 +1.457778219094958194e+00 +3.647273024353496762e+00 +8.317520446923173383e-01 +7.234716387740043331e-01 +1.089039901934546029e+00 +5.020792994495852035e+00 +5.186133597713616261e+00 +2.199157286183403226e+00 +7.252937165707351586e-02 +7.685001472076117279e-01 +4.437382505674167255e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8f640ed9c57d69f83944dd400ae8326e21693a9b Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/.DS_Store differ diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/01.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd82688419a849ed4652fc5bc94466b92e13579d --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/01.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d87f4222d67b8dd076e4a3c88c4edce0710f8c83824e4102bc72881f7924ffd +size 684132 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/02.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/02.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9d389a772e942aa9363e946c97771b452a306b46 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/02.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c519c699a7dbab7a5cd750b8ac8a34eafc3ee09ae0ebe8d5e3363a0078a9e20 +size 252451 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/corrs.txt new file mode 100644 index 
0000000000000000000000000000000000000000..9e1be7157a9f46fe2e1a51f701b4fbef6d2761be --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/corrs.txt @@ -0,0 +1,64 @@ +1.403517121234611750e+03 5.466632954662579778e+02 9.157400829009418430e+02 4.245773770158315870e+02 +1.446026508108539019e+03 5.782981880235990957e+02 9.655740039421418714e+02 4.516423513744143179e+02 +1.398574169272527115e+03 5.654465129221791813e+02 9.247617410204694579e+02 4.400430766493074088e+02 +1.427879044896787491e+03 5.599802860549647221e+02 9.514510915071646195e+02 4.365780725750759075e+02 +1.460029932428406937e+03 6.078546003356241272e+02 9.741354618952484543e+02 4.774099392736268896e+02 +1.427409688874427957e+03 5.487157415183387457e+02 9.496363418761178536e+02 4.273228494567376856e+02 +1.463550102596102306e+03 5.620923881555819435e+02 9.792167608621792851e+02 4.391187220585413229e+02 +8.721037604001448926e+02 5.140213980344973379e+02 4.820933786500399947e+02 3.755344490808192290e+02 +8.900586507513088463e+02 5.001811700554750928e+02 4.969637109152347989e+02 3.650377439524464194e+02 +8.530266894020330710e+02 4.983108689772288358e+02 4.663483209574807802e+02 3.606641168156244248e+02 +9.132503841215623197e+02 5.286097464448180290e+02 5.249549245908956436e+02 3.882179677776030076e+02 +8.410567625012570261e+02 4.990589894085273386e+02 4.541021649743792068e+02 3.624135676703532454e+02 +8.679890980280031272e+02 5.708785508131833240e+02 4.768450260858535898e+02 4.240817102995434311e+02 +9.024026378677340290e+02 5.263653851509225206e+02 5.078977787572898137e+02 3.869058796365563921e+02 +9.076394808868235486e+02 4.766153764695723112e+02 5.131461313214762185e+02 3.453564218367474723e+02 +8.590116528524209798e+02 4.552939441775650948e+02 4.711593108079849799e+02 3.248003742936840581e+02 +9.293349733944800164e+02 4.908296646642438645e+02 5.293285517277175813e+02 3.576025778198489888e+02 +8.691112786749508814e+02 4.201322839065355765e+02 4.807812905089933224e+02 2.946223470496121308e+02 +7.972917172702948392e+02 4.710044732348335401e+02 4.173636970250743161e+02 3.348597167083745489e+02 +8.002841989954887367e+02 4.840965807825573393e+02 4.199878733071674901e+02 3.471058726914761792e+02 +8.317052571100257410e+02 4.119029591622520456e+02 4.475417242691461297e+02 2.858750927759681417e+02 +8.638744356558614754e+02 3.348465547385064838e+02 4.768450260858535330e+02 2.189585975825915227e+02 +9.970361884849378384e+02 4.624331693670503114e+02 5.835681049941797482e+02 3.363546786606245860e+02 +1.067836733301809318e+03 4.949436236196954155e+02 6.476321625989970698e+02 3.691316383654148581e+02 +1.052665187983908254e+03 5.606869866639333395e+02 6.349683372585099050e+02 4.220217324345082375e+02 +9.884432769467641720e+02 5.193298235340000701e+02 5.794862958419262213e+02 3.850917673052189230e+02 +9.930452410428733856e+02 5.395784655568805874e+02 5.831225687213482161e+02 4.016822623175819444e+02 +9.543887426355561274e+02 5.572960273269010258e+02 5.592595279501410914e+02 4.150910185604507205e+02 +9.210245029387642717e+02 5.766242765305596549e+02 5.303966119697288377e+02 4.305451782979943118e+02 +1.004089954873535362e+03 5.761640801209487108e+02 5.908496485901200685e+02 4.332723829575608079e+02 +9.631324744181634969e+02 5.782349639641978456e+02 5.683502101486963056e+02 4.337269170674885572e+02 +9.033069411687438333e+02 5.439503314481843290e+02 5.131243157924741354e+02 4.007731940977264458e+02 +9.132011639753786767e+02 5.540746524596245308e+02 5.217604638811014865e+02 4.114547456810286121e+02 +9.210245029387642717e+02 
5.432600368337679129e+02 5.313056801895843364e+02 4.012277282076541951e+02 +7.784212227798731192e+02 6.446639064584231846e+02 3.978784008196220157e+02 4.852386616555929777e+02 +7.823088294167193908e+02 6.252258732741913718e+02 4.013899429846774183e+02 4.676809508303159646e+02 +7.698684881788110488e+02 6.310572832294609498e+02 3.841833863759060250e+02 4.718948014283824932e+02 +7.578169076045874135e+02 6.641019396426549974e+02 3.715418345817065529e+02 4.999871387488257142e+02 +7.480978910124714503e+02 6.252258732741913718e+02 3.631141333855736093e+02 4.662763339642938263e+02 +8.204073744578138303e+02 6.664345036247627831e+02 4.347495935527038000e+02 5.059567604294198873e+02 +8.145759645025442524e+02 6.116192500452291370e+02 4.294822803051206392e+02 4.585509412011718950e+02 +7.434327630482557652e+02 6.104529680541752441e+02 3.567933574884738164e+02 4.525813195205777220e+02 +8.122434005204364666e+02 6.843174941542561101e+02 4.565211549760472280e+02 5.266748592032467968e+02 +1.353822027092758844e+03 5.167364512312507259e+02 8.780895812280047039e+02 3.982440569932516041e+02 +1.308385355407982388e+03 4.634288600420246667e+02 8.399007632469630380e+02 3.516356570951298863e+02 +1.358473024981751678e+03 4.566312477427273961e+02 8.801944767072747027e+02 3.483279641991341578e+02 +1.334860266468403324e+03 4.949125380598226229e+02 8.585441232062116796e+02 3.783978996172771758e+02 +1.306596510066061910e+03 4.838216969399165350e+02 8.377958677676930392e+02 3.684748209292899901e+02 +1.285130365963017994e+03 4.920503855127500970e+02 8.200546058709886665e+02 3.753909060754629081e+02 +1.325558270690417885e+03 3.636112899628698756e+02 8.537329335393088741e+02 2.716496288828693650e+02 +7.619232704200128410e+02 7.187420337489038502e+02 4.538960459578931363e+02 5.910054474930707329e+02 +7.531660864475670678e+02 7.187420337489038502e+02 4.429703050095473600e+02 5.928264043177950953e+02 +7.573825083602262112e+02 7.550681302271973436e+02 4.721056142051361917e+02 6.602018068325943432e+02 +7.622476105671405548e+02 7.427432046363477411e+02 4.915291536688620795e+02 6.462411378430413151e+02 +7.732751755694796429e+02 7.424188644892201410e+02 5.060968082666565238e+02 6.450271666265584827e+02 +7.658153521855443842e+02 7.657713550824089452e+02 5.127736499573122728e+02 6.984419001518047025e+02 +7.700317740982034138e+02 7.523016972315615476e+02 5.413019735446596314e+02 6.853738526981995847e+02 +7.800863186591597014e+02 7.535990578200720620e+02 6.244590018737362698e+02 7.363606437904802533e+02 +7.917625639557541035e+02 7.516530169373063472e+02 7.913800441401308490e+02 8.219456145525227839e+02 +8.083039114592627357e+02 7.513286767901787471e+02 8.247642525934097648e+02 8.207316433360399515e+02 +5.656885711194121313e+02 7.433783793471844774e+02 1.960135294447161698e+02 6.822218008636290278e+02 +5.542464120866965231e+02 7.429296672282544023e+02 1.741398289424236907e+02 6.815968379921349651e+02 +5.486375106000713231e+02 7.507821293095298643e+02 1.544534984903604595e+02 7.325313120189016445e+02 +5.349517909727056804e+02 7.501090611311348084e+02 1.244552806586450799e+02 7.328437934546486758e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..b8231162755761c1514e8aeba67f9e97753bb132 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/crossval_errors.txt @@ -0,0 +1,64 @@ +2.109604541263442046e-01 
+6.343754191545079024e-01 +9.604626447087906138e-01 +1.112371853549305678e+00 +1.719472206398944270e-01 +1.389983618486540617e+00 +1.772333438562186669e-01 +1.043528763362485901e+00 +1.662313702257255166e+00 +6.873164968361269445e-01 +9.132861587639566903e-01 +2.079079424822051880e+00 +3.929539696712000540e-02 +5.858750028825535777e-02 +1.725362960485958386e+00 +1.498706429453000144e+00 +7.361902992256655898e-01 +1.155112462959435371e+00 +6.905814524075082339e-01 +1.952610849768650514e+00 +3.943618478419495532e-01 +5.791753657862884985e+00 +2.647108694363052184e+00 +3.032848331037355738e+00 +8.897047696970773467e-01 +3.825412735926300156e-01 +1.052341249583254479e+00 +3.080588116791681541e-01 +2.719112836909035394e-02 +1.413006610379454075e+00 +2.183674009730531551e-01 +7.847240469606918678e-01 +1.378023295799093217e+00 +3.976818240761115231e-01 +9.738282951090618256e-02 +9.351098173156836557e-01 +6.955525944570016827e-01 +5.584256764423154440e-01 +2.150741620512164332e-01 +3.337131805125549688e-02 +1.575918828991396126e+00 +2.025164692696442614e+00 +5.579441790759137376e+00 +8.872051410423215101e-01 +7.742449638499517839e-01 +4.386510970436738321e-01 +1.772402648524209701e+00 +7.499468297338771627e-01 +8.822004627697077606e-01 +1.928876693608766679e+00 +5.351425384933634621e-01 +3.008437227703353312e+00 +5.620997066040206214e+00 +5.690692638023513439e-01 +8.616311524736414151e-01 +5.594131258364883230e+00 +3.248873051656249622e+00 +6.803024422893112766e-02 +4.311064498905580855e+00 +2.259541286881223687e+00 +1.085469088010938998e+00 +1.497713528125711679e+00 +3.860337864713685008e+00 +2.744465939206676364e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8f640ed9c57d69f83944dd400ae8326e21693a9b Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/.DS_Store differ diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/01.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..10f1464582f2d63d89f59e7f0fa7ab13252c0bf7 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/01.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db2f587ddd146911e65c48b5c9fa81af5e060f51f36a3b69bbe0359dd9263dd3 +size 297557 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/02.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/02.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd82688419a849ed4652fc5bc94466b92e13579d --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/02.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d87f4222d67b8dd076e4a3c88c4edce0710f8c83824e4102bc72881f7924ffd +size 684132 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..df5e6ba412a763a0c7d1231c55fd8c3dc602ed9c --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/corrs.txt @@ -0,0 +1,50 @@ +2.987967877092588651e+02 4.223206853946798560e+02 9.225885599702967284e+02 5.269050686756181676e+02 
+3.444547977198159288e+02 4.320679010149111150e+02 9.675792986429532903e+02 5.294045541574324716e+02 +3.234213324340537383e+02 4.510493209069404656e+02 9.482975534975291794e+02 5.483292299483117631e+02 +2.987967877092588651e+02 4.751608542833020010e+02 9.258021841612007847e+02 5.761806396028134714e+02 +2.885365607405943820e+02 4.223206853946798560e+02 9.133047567521296060e+02 5.286904154483427192e+02 +6.066166940245287833e+02 4.446378419628800884e+02 1.350580024253208649e+03 4.783373551138423068e+02 +6.152647798956927545e+02 4.460215357022663056e+02 1.362009688343699736e+03 4.783373551138423068e+02 +6.308313344637878117e+02 4.844190369702341741e+02 1.384869016524681683e+03 5.190555334362164785e+02 +6.045411534154493438e+02 4.795761088823824139e+02 1.348794139239069409e+03 5.172696484220772390e+02 +5.723702739747195665e+02 4.771546448384565338e+02 1.301646774865794214e+03 5.208414184503557181e+02 +7.321326865086233511e+02 5.189600829123442054e+02 1.426648962845904407e+03 5.502526159436411035e+02 +6.952341867916571800e+02 5.302602484506650171e+02 1.399037887857698024e+03 5.654962302600466728e+02 +7.623433331518893965e+02 5.261091672325063655e+02 1.467777959963753347e+03 5.508278466725621456e+02 +7.457390082792546764e+02 5.480176514394549940e+02 1.446494422993677745e+03 5.775760755673869653e+02 +7.215243678399956480e+02 4.693777239176708918e+02 1.427511808939285856e+03 4.950304659672286789e+02 +2.137836007884084779e+02 3.473788774328126010e+02 8.906646694904027299e+02 4.566001456861312136e+02 +1.820175798862397301e+02 3.409722849819551129e+02 8.584382324718022801e+02 4.551989962505399490e+02 +2.153852489011228499e+02 3.943605554057680251e+02 8.946345928912448926e+02 5.000357781894623486e+02 +1.798820490692872340e+02 3.858184321379579842e+02 8.593723320955298277e+02 4.977005291301434795e+02 +4.348808888852079235e+01 4.488766776777111431e+02 7.668412146435323393e+02 5.669073554834446895e+02 +4.447047982195066851e+01 4.129908360862074801e+02 7.672572798517172714e+02 5.381466887008497224e+02 +1.941838714601171318e+02 4.986112097318912788e+02 8.817404582750726831e+02 5.998069580836389605e+02 +1.797534379770573310e+02 4.966611511530994107e+02 8.677929218487254275e+02 5.998069580836389605e+02 +1.746832856721984513e+02 5.099215494888841249e+02 8.633550693494330517e+02 6.115355682603401419e+02 +1.731232388091649455e+02 4.650702021766711596e+02 8.595511957786111452e+02 5.703269379097686169e+02 +2.035441526383181383e+02 4.814506942385228285e+02 8.899821843451870791e+02 5.823725375507049193e+02 +1.523935219787211395e+02 4.133482128927837493e+02 8.355361283176038114e+02 5.236464652993014397e+02 +1.963543895283957568e+02 4.074867638861604746e+02 8.762130226334011240e+02 5.146416566339721612e+02 +4.434197361592768516e+02 4.544329741768374902e+02 1.128685088353774518e+03 5.237216515323132171e+02 +4.355340789320897557e+02 4.587908373813356206e+02 1.123810639541448154e+03 5.276676339041963502e+02 +4.498527723182979230e+02 4.556780779495512661e+02 1.134952236826765329e+03 5.241858847525347755e+02 +4.639639484090537849e+02 4.581682854949787611e+02 1.151200399534519192e+03 5.248822345828671132e+02 +4.556632565909620780e+02 4.390766943133678524e+02 1.140987268689645362e+03 5.058486725537838993e+02 +4.550407047046052185e+02 4.463397996541980888e+02 1.140755152079534582e+03 5.132764040773286069e+02 +4.587760160227464894e+02 4.463397996541980888e+02 1.145397484281749939e+03 5.135085206874392725e+02 +4.465324955910612630e+02 4.839004301310629899e+02 1.132863187335768316e+03 5.522719945759380380e+02 
+4.676992597271950558e+02 4.677140810857841871e+02 1.157699664617620783e+03 5.323099661064117072e+02 +1.889113811606345052e+02 3.089447292260057338e+02 8.693428810072523447e+02 4.208949625426010357e+02 +1.834191741745515571e+02 2.252941920533577900e+02 8.635352561421038899e+02 3.351208106881026652e+02 +8.742537370746526904e+01 5.356845514731066942e+02 7.785539052876317783e+02 6.439940125048218533e+02 +3.263139680571970302e+01 5.831726647879529537e+02 7.599659899561384009e+02 6.915572076177609233e+02 +9.107830550091497912e+01 5.159587197884783336e+02 7.823808290323510164e+02 6.254060971733283623e+02 +5.729830842288644135e+02 3.426182336336015055e+02 1.326426310886490228e+03 3.645918379055057130e+02 +5.953251222022986440e+02 4.159330672469500314e+02 1.346409458154053709e+03 4.449439870747970645e+02 +5.432431460913823003e+02 4.486905092352639031e+02 1.285043155030771914e+03 4.921488356311676284e+02 +5.521984180018566803e+02 4.625947472015266158e+02 1.291337134838287966e+03 5.055235427221392683e+02 +6.269042389392343466e+02 4.493975043860908158e+02 1.379059478405543359e+03 4.799542497541052057e+02 +3.778310515211186384e+02 4.781719669646308262e+02 1.051826775237077982e+03 5.583517754714107468e+02 +3.808586594790852473e+02 4.958330133861024933e+02 1.057766404338827897e+03 5.745507639307290901e+02 +3.427612593413391551e+02 5.003744253230523782e+02 1.032387989085895697e+03 5.815703255964338041e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..a14b3cdd5a55383687b953a78e61923fd2c5f038 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/crossval_errors.txt @@ -0,0 +1,50 @@ +1.205625952224563369e+00 +4.939422348912164695e-01 +1.408368284957418615e+00 +4.354895719375530660e+00 +5.845732642066374662e-01 +5.671421046587751258e-01 +3.241250475917350493e-01 +1.176934650482001077e+00 +9.602228893504475282e-01 +1.539524729734165653e+00 +5.977619023052603842e-01 +3.805437463550794264e-01 +2.688021655154999578e+00 +1.321323261517616254e+00 +2.544837266919131480e+00 +1.340388986541366956e+00 +1.049147679116919640e+00 +6.487360989178689863e-01 +1.221299429389517410e+00 +3.649503550987649003e+00 +5.805746815258191695e+00 +1.008171333833599137e+00 +1.074115212606320480e+00 +2.646741835059085446e+00 +3.351911533805484567e+00 +1.617215307823174131e+00 +7.151241593030583488e+00 +2.331406646868857457e+00 +7.210754257250947541e-01 +8.528073969880842764e-01 +8.421739174470552758e-01 +1.636074672456483636e+00 +1.697227528685974207e-01 +1.095537504541505142e-01 +1.492383094478506145e+00 +3.682014845518886692e-01 +3.537468933850831943e-01 +2.031705423621828210e-02 +5.066690777812540070e+00 +3.293129568230977355e+00 +2.760852141563832074e+00 +5.163656053651205724e+00 +1.366074023533330628e+00 +2.186481125400338232e+00 +4.904450483949484019e-01 +2.338855629843699102e+00 +2.778060616070257560e-01 +4.025423249494965994e-01 +7.730484211806706862e-02 +1.283781951389742604e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/01.png new file mode 100644 index 0000000000000000000000000000000000000000..4beb4f11e66b9f66ec03e273cacef6412a9ff615 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:dc924ed57a48a75b7f073db4741b8b7e2d803d460f1a862b835f4d99311651fb +size 691893 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/02.png new file mode 100644 index 0000000000000000000000000000000000000000..5a3959270387f20c9d983f0ca3575b5fe8c42a85 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e26283b4d7ef459bf967312a5a9e8e5b32fa51c0efea20fcc5531b271e88266 +size 496664 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..6871c86a0e69b76610a82bb89ed95538d7bf2414 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/corrs.txt @@ -0,0 +1,22 @@ +3.925862300000000005e+02 2.545839400000000126e+02 2.499319999999999879e+02 3.251010099999999738e+02 +6.381495099999999638e+02 2.215317599999999914e+02 4.726627899999999727e+02 3.199427400000000148e+02 +5.014558400000000233e+02 2.289949200000000076e+02 3.575410899999999970e+02 3.167764099999999985e+02 +2.755296599999999785e+02 2.137739599999999882e+02 9.343351699999999482e+01 2.898795000000000073e+02 +2.447049199999999871e+02 3.068080900000000000e+01 7.066881899999999916e+01 1.679281799999999976e+02 +7.458263799999999719e+02 2.498548299999999927e+02 5.190902899999999818e+02 3.398734400000000164e+02 +5.115282199999999762e+02 6.731880400000000009e+01 3.293148400000000038e+02 2.256605500000000006e+02 +6.698830699999999752e+02 2.581891200000000026e+02 4.624721400000000244e+02 3.405574700000000234e+02 +6.418331899999999450e+02 3.179037400000000275e+02 4.336016300000000001e+02 3.712456199999999740e+02 +5.147312200000000075e+02 2.215231100000000026e+02 3.794336299999999937e+02 3.116288599999999747e+02 +7.477903000000000020e+02 5.901460699999999804e+01 5.822250800000000481e+02 2.456988000000000056e+02 +3.554008000000000038e+02 2.895215299999999843e+02 1.787855700000000070e+02 3.432538400000000252e+02 +6.026838400000000320e+02 2.625073800000000119e+02 4.086288700000000063e+02 3.397562100000000100e+02 +4.577935299999999756e+02 1.795010299999999859e+02 3.282694299999999998e+02 2.821538499999999772e+02 +5.587376199999999926e+02 2.643346999999999980e+02 3.724694900000000075e+02 3.408016200000000140e+02 +4.964148099999999886e+02 2.184439199999999914e+02 3.535551600000000008e+02 3.105630400000000009e+02 +3.180709299999999757e+02 1.195023400000000038e+02 1.410334100000000035e+02 2.347362300000000062e+02 +3.730920300000000225e+02 1.536518200000000078e+01 2.000334100000000035e+02 1.823751499999999908e+02 +4.079947199999999725e+02 9.514830000000000609e+01 2.322432599999999923e+02 2.305287199999999928e+02 +4.558833500000000072e+02 3.438276700000000119e+01 2.814226300000000265e+02 2.027807699999999897e+02 +4.566216911046785185e+02 2.287329680078115075e+02 3.273232766414439538e+02 3.111077682432915594e+02 +3.965061988826488459e+02 2.726309488309927360e+02 2.496704751160283706e+02 3.352331852304391191e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..2180b0f470557d9e7eb64fd1ac1ce6aef0c09fa1 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/crossval_errors.txt @@ -0,0 +1,22 @@ 
+2.223769488158787055e+00 +2.041660990806244691e-01 +2.591624711512738433e+00 +2.980969265769438881e+00 +4.279999916905446788e+00 +3.607338159981161585e-01 +5.093382919492810856e+00 +1.723610403284629444e-01 +1.385046455528339049e+00 +1.359466349711395239e+00 +1.376832687544708378e+01 +2.418395247191616804e+00 +2.410329102105791821e-01 +3.259756579203542781e+00 +5.259098693002325575e+00 +3.273444131738968643e+00 +2.662659910154434151e+00 +5.280286776725191977e+00 +4.117411483340728540e-02 +5.288077931743116224e-03 +3.267937158693952515e+00 +1.394501496347928304e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/01.png new file mode 100644 index 0000000000000000000000000000000000000000..d1b4ece1baf47d952aa3746bfa2335b5e1fa1684 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53886fb0c41be6ded1d7d3918fe3948e6dba991d72599710b2570641b952f22c +size 990530 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/02.png new file mode 100644 index 0000000000000000000000000000000000000000..7deb1d6677c76ec50ab4c324ada4838947c490b0 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ecab06081c7b8611ceca3bf64630b13a85cc709389bf925696ae456c30c3553 +size 759959 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..fea259ddf089697453947ba96caebe9fe232bc8c --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/corrs.txt @@ -0,0 +1,25 @@ +1.768720399999999984e+02 3.979225299999999947e+02 4.708233799999999860e+02 1.959799399999999991e+02 +1.766017400000000066e+02 4.222839299999999980e+02 4.707061499999999796e+02 2.611706399999999917e+02 +1.758983399999999904e+02 4.596776399999999967e+02 4.682442500000000223e+02 3.509988299999999981e+02 +3.112423999999999751e+02 4.943585600000000113e+02 7.164331600000000435e+02 2.956984299999999735e+02 +3.310578300000000240e+02 4.842376600000000053e+02 8.092814100000000508e+02 2.980654000000000110e+02 +3.111968499999999835e+02 5.296323700000000372e+02 7.189657799999999952e+02 4.042986500000000092e+02 +3.314909200000000169e+02 5.161572499999999764e+02 8.047335199999999986e+02 3.949218900000000190e+02 +2.031905600000000049e+02 3.987170399999999972e+02 5.484593099999999595e+02 2.075618299999999863e+02 +2.019465500000000020e+02 4.416759700000000066e+02 5.445906099999999697e+02 3.245060100000000034e+02 +3.006914100000000190e+02 4.595604099999999903e+02 7.968062599999999520e+02 3.468510499999999865e+02 +8.590342699999999354e+01 3.966785199999999918e+02 2.417377799999999866e+02 1.957473699999999894e+02 +6.083863300000000152e+01 3.960206699999999955e+02 1.838174500000000080e+02 1.984437399999999911e+02 +1.653182099999999934e+02 3.722618600000000129e+02 4.425554599999999823e+02 1.297882600000000082e+02 +1.193828900000000033e+02 4.585934100000000058e+02 3.248836699999999951e+02 3.518213600000000270e+02 +1.579033799999999985e+02 4.215282700000000204e+02 4.204320900000000165e+02 2.576441300000000183e+02 +1.842219000000000051e+02 4.180635300000000143e+02 4.893462400000000230e+02 2.474543400000000020e+02 +1.923110000000000070e+02 
4.434994399999999928e+02 5.161926600000000462e+02 3.249730400000000259e+02 +3.200348999999999933e+02 4.505334399999999846e+02 8.501474100000000362e+02 3.175427300000000059e+02 +3.200348999999999933e+02 4.593259400000000028e+02 8.503818800000000238e+02 3.455614899999999921e+02 +3.247481557361764999e+02 4.943084603998072453e+02 7.579714323690594711e+02 3.001010955510471945e+02 +3.436647225311791090e+02 4.846647204651000607e+02 8.483393450439383514e+02 3.027986451831331465e+02 +3.102825458341156377e+02 4.572171529586256611e+02 8.240613983551650108e+02 3.405643400323361334e+02 +3.556460695439168944e+02 4.606928978841127105e+02 9.578515049123923291e+02 3.496653086346104260e+02 +3.555344823310228435e+02 4.505384615107533932e+02 9.556024229066407543e+02 3.187404310555252778e+02 +1.521515380648792188e+02 4.448291019826118600e+02 4.073903440174262300e+02 3.232320381043676889e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..b0ba1873dfcfe6c7e870018858c8694b6d36fa0d --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/crossval_errors.txt @@ -0,0 +1,25 @@ +2.299443554484710273e+00 +1.885525491243613239e+00 +6.086572087327464686e+00 +2.034851019204923617e+00 +8.692972926988611349e+00 +6.310085817588906743e+00 +5.782550545440895640e+00 +9.802227095767416243e-01 +2.634379550856073671e+00 +3.004014133133211306e+00 +7.441118090272435204e+00 +6.820322738563534770e+00 +2.857044245627155199e+00 +3.251653039626808406e-01 +2.033689779695354805e+00 +2.800776884471881445e+00 +2.064570116053137561e+00 +2.820266171514703490e+00 +1.978516375695342111e+00 +4.965627338948699787e+00 +5.329964473568547412e+00 +2.740790105599770765e+00 +6.842949138841069257e+00 +2.904103786325705983e+00 +2.871389478412832652e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/01.png new file mode 100644 index 0000000000000000000000000000000000000000..0a577f2ae6655ebf32fa35fc811ea6030ac40d4c --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7713564379e217f0cfee10f12ee5a9681679afe355c01c60aecf7e1f72cb095 +size 439590 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/02.png new file mode 100644 index 0000000000000000000000000000000000000000..160568f2d168e6271365d7eb44c612dd60cade56 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83a8c09a57a979df54e364175754a043301ac7726ce5f7e2b48ebcb77b8976cd +size 732859 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..a428ea7b70331b541856e006a7ff8c48aa1276d0 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/corrs.txt @@ -0,0 +1,24 @@ +5.728948299999999563e+02 1.939825099999999907e+02 6.367427099999999882e+02 3.378577599999999848e+02 +4.687543299999999817e+02 1.892049900000000093e+02 5.377120499999999765e+02 3.363918300000000272e+02 +4.121361800000000244e+02 1.856385400000000061e+02 4.508359600000000000e+02 
3.342429000000000201e+02 +5.869821899999999459e+02 9.877109900000000664e+01 5.553046600000000126e+02 2.652346800000000258e+02 +6.642043499999999767e+02 9.188649300000000153e+01 5.798319999999999936e+02 2.647494899999999802e+02 +3.670359500000000139e+02 2.320079399999999907e+02 3.654839000000000055e+02 3.947103700000000117e+02 +5.437276100000000270e+02 2.619335500000000252e+02 6.116904100000000426e+02 3.971625900000000229e+02 +4.302309099999999944e+02 2.261345200000000091e+02 4.775478600000000142e+02 3.763801399999999830e+02 +6.861916800000000194e+02 2.373991800000000012e+02 7.011787799999999606e+02 3.622585000000000264e+02 +5.732465300000000070e+02 2.929346600000000080e+02 6.364006900000000542e+02 4.162532899999999927e+02 +3.287964600000000246e+02 2.068985700000000065e+02 2.830723600000000033e+02 3.759810600000000136e+02 +5.876089626368114978e+02 8.289871306694442410e+01 5.532441718240431783e+02 2.492208339372825208e+02 +4.687863450413139503e+02 2.336412410594688822e+02 5.361196246855870413e+02 3.800426423573028387e+02 +6.047161665443178435e+02 2.697475998962042922e+02 6.569702977946423061e+02 3.925444361272050742e+02 +3.965736273678431303e+02 2.421368549034066291e+02 4.186027632485056529e+02 4.033793240611204283e+02 +3.965736273678431303e+02 2.187739168325778110e+02 4.186027632485056529e+02 3.742084719313484129e+02 +4.893174118308301104e+02 2.541723078489850991e+02 5.577894005534176358e+02 3.983786065531595000e+02 +5.751068647856137659e+02 1.106118303041440356e+02 5.480509012947979954e+02 2.721983718604265050e+02 +5.925572401533436278e+02 1.051961965693313061e+02 5.627727784119069838e+02 2.706068175774958036e+02 +6.085032728169588836e+02 1.018866426202791047e+02 5.751073241046200337e+02 2.682194861530997514e+02 +5.723990479182074296e+02 8.533887287501804053e+01 5.400931298801444882e+02 2.495187233286639525e+02 +5.648773343976341721e+02 8.714408411995560755e+01 5.257691413337681752e+02 2.491208347579312772e+02 +5.627712546118737009e+02 1.067005392734459406e+02 5.261670299045008505e+02 2.666279318701690499e+02 +6.623587416242630752e+02 8.052497622185117621e+01 5.790862098119467873e+02 2.546912747481887322e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..1149a1de18179182c045e56198b0d3fec185f407 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/crossval_errors.txt @@ -0,0 +1,24 @@ +2.396012140755679365e+00 +4.480711277911260115e-01 +2.987724645529541867e+00 +4.277000771411354485e+00 +8.508079482254223835e-01 +3.462941354139007277e+00 +4.003849617755437684e+00 +3.053093833108082134e+00 +4.069386060501227753e+00 +1.160811959330638521e+00 +1.068969419958323996e+01 +2.050613156701358797e+00 +2.598761125557944029e-01 +1.744324021645896394e+00 +1.059052432611495398e+00 +1.831625933590129485e+00 +3.212063708729713252e+00 +1.474694824562631679e+00 +3.835499616059664163e+00 +1.586777748179354131e+00 +1.313007222469624091e+00 +2.977945246196925133e+00 +7.088566085793469584e-01 +4.436947495199563996e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/01.png new file mode 100644 index 0000000000000000000000000000000000000000..dd882a8f19e9abcc96b9e04102719bd03f6b6566 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/01.png @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:62a4fd2d8ce7b7b795c5d24b12fea707c760af4fef0c06435927a9d75b5a924a +size 837810 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/02.png new file mode 100644 index 0000000000000000000000000000000000000000..fb568f60b774b17ce78b9af75475e46540a94264 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f540475363f6e65ccd0b9b37f132d1f3b2cfbd670ca39bd055f532566ba02c34 +size 1081407 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..055872e435cfe2718d5eebeb8266829fdbd0fe88 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/corrs.txt @@ -0,0 +1,16 @@ +5.720838800000000219e+02 2.991620199999999841e+02 4.506014900000000125e+02 3.717636600000000158e+02 +5.265157500000000255e+02 2.894499799999999823e+02 3.801778299999999717e+02 3.718808900000000222e+02 +6.661703400000000101e+02 3.181354799999999727e+02 5.339412200000000439e+02 3.701223899999999958e+02 +5.609456800000000385e+02 1.807470600000000047e+02 2.540061200000000099e+02 2.981746999999999730e+02 +4.714216400000000249e+02 2.935240999999999758e+02 2.144760200000000054e+02 3.865360900000000015e+02 +5.766269300000000158e+02 3.185162300000000073e+02 4.876987399999999866e+02 3.795601899999999773e+02 +7.010207799999999452e+02 9.238567299999999705e+01 4.310612300000000232e+02 2.756551900000000046e+02 +5.256166200000000117e+02 2.393388099999999952e+02 3.795722999999999843e+02 3.411927499999999895e+02 +5.705888999999999669e+02 2.400422100000000114e+02 4.503573499999999967e+02 3.406065800000000081e+02 +6.636008900000000494e+02 2.423675100000000100e+02 5.344295200000000250e+02 3.397956300000000169e+02 +7.031309800000000223e+02 1.568639799999999980e+02 4.430190299999999866e+02 3.019154500000000212e+02 +5.248059264207870456e+02 2.251150199168595236e+02 3.790005248323007550e+02 3.295326686123989930e+02 +5.710311283966667588e+02 2.707436201554564263e+02 4.491400054071735326e+02 3.548917961443609670e+02 +5.556861938875076703e+02 1.970657522714686820e+02 2.508518240386408706e+02 3.072043463526618439e+02 +6.934073850204836162e+02 1.585616007408017367e+02 4.255523570510792410e+02 3.026881814837404363e+02 +6.839389858118621532e+02 2.837548791659100971e+02 5.470150264116562084e+02 3.555176696442887305e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..d2e3cfa5881603cd96fd7b37a87205da61ef67ae --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/crossval_errors.txt @@ -0,0 +1,16 @@ +2.089796770444086516e+00 +3.560139367682131351e-01 +1.824197790271693709e+00 +1.579644173121356632e+00 +5.693810774341417913e-01 +5.688931313970732262e+00 +2.318580807601463345e+00 +3.531501782675507961e+00 +2.223367754851889533e+00 +9.060112503001150897e-01 +3.469717531158602153e+00 +3.655183345533559169e+00 +3.059823892691055924e+00 +2.766381135592363094e+00 +3.540309193863722115e+00 +1.130637216535496786e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/01.png 
new file mode 100644 index 0000000000000000000000000000000000000000..88e667f0235eee842466d8d368b25e2e693ed3b1 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18b637e3a4bb1e7211f2d7af844a00ba96bc2529e99701b0e20818eb0b88493f +size 362970 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/02.png new file mode 100644 index 0000000000000000000000000000000000000000..733b0f62a17af04175905d6df843be81c4cebcc3 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b6ed4c9fa8fb005a7afdd4dc23cf3d902ded49ea42782b1cd99fde9511a4e71 +size 396068 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0fc970672499a59f2d23651d1ec8b8670474e96 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/corrs.txt @@ -0,0 +1,29 @@ +1.699193999999999960e+02 2.080499700000000018e+02 1.176834999999999951e+02 3.173487400000000207e+02 +2.292048599999999965e+02 2.058467400000000112e+02 2.624721400000000244e+02 3.187410199999999918e+02 +1.994497399999999914e+02 2.146790100000000052e+02 1.887659500000000037e+02 3.375730399999999918e+02 +1.795626200000000097e+02 1.855245799999999861e+02 1.436619100000000060e+02 2.663481100000000197e+02 +2.182769199999999898e+02 1.848357000000000028e+02 2.359317700000000002e+02 2.660797499999999900e+02 +1.542284699999999873e+02 2.212005000000000052e+02 4.179829199999999645e+01 3.534005700000000161e+02 +2.153751399999999876e+02 2.377072200000000066e+02 2.079303099999999915e+02 4.203798499999999763e+02 +2.204801500000000090e+02 1.973710199999999872e+02 2.409243899999999883e+02 2.938661200000000235e+02 +1.779310399999999959e+02 1.991198400000000106e+02 1.400422100000000114e+02 2.972416799999999739e+02 +2.127223500000000058e+02 2.496951099999999997e+02 1.978181600000000060e+02 4.553668999999999869e+02 +2.283880399999999895e+02 2.374727599999999939e+02 2.469689899999999909e+02 4.207315500000000270e+02 +2.189803200000000061e+02 2.132964200000000119e+02 2.351401899999999898e+02 3.365421499999999924e+02 +1.708922000000000025e+02 2.512869200000000092e+02 7.552689499999999612e+01 4.546662699999999973e+02 +1.787323338049845063e+02 2.414601446406908281e+02 1.001912872441162392e+02 4.245428804908389111e+02 +1.981264334817288955e+02 2.411944720423792887e+02 1.559283702540789989e+02 4.242349408057009441e+02 +1.787323338049845063e+02 2.527512300689324434e+02 9.895952850356457020e+01 4.590321252262854728e+02 +1.819204049847233193e+02 2.527512300689324434e+02 1.085056587428399553e+02 4.587241855411475626e+02 +2.062294477302316977e+02 2.467735966069221831e+02 1.808714847502502323e+02 4.460986584504929624e+02 +2.145981345770460678e+02 2.407959631449119229e+02 2.070463579869730495e+02 4.288540360827697100e+02 +1.973835808866980983e+02 1.868103617409780099e+02 1.870956487489719677e+02 2.717154015176403732e+02 +1.973835808866980983e+02 1.927719464278879400e+02 1.868461254080040419e+02 2.854391852708775446e+02 +1.718061239043292971e+02 1.418526518919348405e+02 1.330619491545299979e+02 1.649613076970173893e+02 +1.828660773958816606e+02 1.412996542173572152e+02 1.584983998654981860e+02 1.625951262355319500e+02 +2.218524134536037309e+02 1.410231553800684026e+02 
2.501879314980578783e+02 1.578627633125611283e+02 +1.626198460074653553e+02 2.383619460711001352e+02 5.223644950679010890e+01 4.145819853141157409e+02 +1.631010251149683086e+02 2.424519684848752945e+02 5.223644950679010890e+01 4.251941023577143710e+02 +1.877643568234823022e+02 2.458244282717295732e+02 1.269324872222991871e+02 4.371435348179112452e+02 +1.746148151192849696e+02 1.935591612449198919e+02 1.314704047950642689e+02 2.889048941075864718e+02 +2.287110309909827777e+02 1.759154723759968704e+02 2.635742274688911948e+02 2.470552098254199223e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..a2c4648311a92a8c241d75278c3062eebbee5f5a --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/crossval_errors.txt @@ -0,0 +1,29 @@ +4.094747428255010924e+00 +5.334159760209342238e+00 +2.691091772186884334e-01 +1.560500265154325140e-02 +3.124908881356045054e+00 +5.995524641871649685e+00 +1.064547777858604327e+00 +6.520279900533545003e+00 +2.616296339503305646e+00 +2.968194822467821403e+00 +3.153913249945764274e+00 +3.083370555874225261e+00 +1.787218154227150801e+00 +5.843889444021538315e-01 +3.743974298750481378e+00 +5.431708954181251325e-01 +1.817303339677495411e-01 +3.599477957313317877e-01 +4.085401545563981385e-01 +8.127999326452226558e-01 +1.391102985596567976e+00 +1.998585421031668696e+00 +9.684885938784187909e-01 +5.681982249348700442e+00 +2.961368819060811841e+00 +3.282659730971194123e-01 +1.145096508198633289e+00 +4.981046127640841981e+00 +4.610934227319352985e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/01.png new file mode 100644 index 0000000000000000000000000000000000000000..34b104d65cf96f3bd7746f39cef0a5d9ba7e0c6a --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:198c31a0d634f2730279b292f5b9657568337707cb08d4afa3c8b5bae6af3200 +size 583920 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/02.png new file mode 100644 index 0000000000000000000000000000000000000000..1c8073bc28774c599bb47b82e4cac84a2f98bc3a --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b78338df059ad86b55ec507003d6b7695364429f1e8fe021d3e10fe758ea7a44 +size 546186 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..17367ca09d5f4a30996ca64f1360e1148a998a5f --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/corrs.txt @@ -0,0 +1,12 @@ +6.312811599999999999e+02 1.731887600000000020e+02 4.330541999999999803e+02 2.733132499999999823e+02 +6.462870199999999841e+02 1.682649600000000021e+02 4.463015599999999949e+02 2.687411500000000046e+02 +2.651443199999999933e+02 2.370269499999999994e+02 1.882136900000000139e+02 3.347089100000000030e+02 +1.669149299999999982e+02 2.277354400000000112e+02 1.235839100000000030e+02 3.306541599999999903e+02 +7.190660699999999679e+02 1.947026299999999992e+02 5.051288099999999872e+02 2.947001999999999953e+02 +5.631307199999999966e+02 2.854144200000000069e+02 
4.356277799999999729e+02 3.759208699999999794e+02 +5.389725700000000188e+02 2.043652099999999905e+02 3.457382200000000125e+02 2.980709100000000262e+02 +1.657813299999999970e+02 2.919350099999999770e+02 1.120262299999999982e+02 3.869109199999999760e+02 +2.623597700000000259e+02 2.942409499999999980e+02 1.863379600000000096e+02 3.861978399999999851e+02 +6.702347700000000259e+02 2.667869200000000092e+02 4.865651399999999853e+02 3.571412700000000200e+02 +6.644581799999999703e+02 1.676787999999999954e+02 4.610729499999999916e+02 2.699134799999999927e+02 +6.771543175316711540e+02 1.770803000551282196e+02 4.715674571393174119e+02 2.789517777805050400e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..2478510d2a7fb6c2df4e60ce18608a1c138b66e9 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/crossval_errors.txt @@ -0,0 +1,12 @@ +2.942755404534845853e-01 +9.185852889156850276e-01 +3.346748789702642668e+00 +2.162153226798289030e+00 +1.122322047852059246e+00 +6.403388640390542896e+00 +5.048211887924723307e+00 +2.993677282433861997e+00 +4.056310620435071179e+00 +2.749595016855273855e+00 +1.479053446042520203e+00 +2.570293849026595190e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..732fed30a1bd60aee698b30b4fb000c59f344ef7 Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/.DS_Store differ diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e7391ff899ad3f4b3f9ad8be103378d2625425bd Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/.DS_Store differ diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/01.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b380871be1e0c70c7176328b680573d1e880be8c --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/01.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5bf7ff2c7b8d279888a54057c2f8a485fb42a6d4be9443550fd58d215ce91670 +size 209291 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/02.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/02.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3e9fdbb4706c2e7db4d6014a8253625c6aa2f579 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/02.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d4ea5306dda72998b7b3eb240098835f4f9f04732a89b166878d6add7956558 +size 233057 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..098d2e39bd91a58f3c335dbbdf75200eb0e7188e --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/corrs.txt @@ -0,0 +1,50 @@ +3.435071438397347947e+02 6.959497241954786659e+02 1.781599902477694854e+02 6.095230902465996223e+02 +3.805172005450424990e+02 7.016261132607098716e+02 2.785537957219011673e+02 
5.854025006196978893e+02 +3.961840343650807199e+02 7.029884466363654383e+02 3.170163575593931569e+02 5.704086205813534889e+02 +4.039039234937952187e+02 7.483995591582153111e+02 2.739904409276225010e+02 6.877520295770918892e+02 +3.955028676772529366e+02 7.309162808373031339e+02 3.163644497316391266e+02 6.551566381893867401e+02 +4.050392013068414485e+02 7.084377801389873639e+02 3.906819420956067006e+02 5.795353301699109352e+02 +3.948217009894252101e+02 7.100271690772522106e+02 3.228835280091801678e+02 5.925734867249930176e+02 +3.283572984315276813e+02 6.949318290264856159e+02 1.321443203711017418e+02 6.181286591705385263e+02 +2.433024337899045406e+02 5.431589549166819779e+02 1.129317308936844029e+02 3.525648714114240079e+02 +3.179748834014487215e+02 5.187761142271981498e+02 4.267970256861472080e+02 2.272259253061831714e+02 +3.057834630567068075e+02 5.299515828765449896e+02 3.988288311006801905e+02 2.686602876550231258e+02 +2.488901681145779321e+02 5.076006455778514237e+02 1.139675899524054330e+02 2.676244285963020957e+02 +2.605736126116221953e+02 4.689944811528353057e+02 7.460494572100719779e+01 1.661102408416440994e+02 +2.554938541346464262e+02 4.867736358222505828e+02 9.221454971926425515e+01 2.137597575428100072e+02 +2.204435206435134091e+02 4.639147226758595366e+02 4.166529727979423114e+00 2.034011669556000470e+02 +2.712411054132713275e+02 4.217527273169604314e+02 2.631312944082292233e+02 7.288292555675411677e+01 +2.930840668642672426e+02 4.293723650324241703e+02 3.594661868692821827e+02 6.148847591082312647e+01 +3.849691714495721726e+02 6.482676180430371460e+02 3.231634799028571479e+02 4.351514987440575624e+02 +4.471865534723270912e+02 7.245437354267030514e+02 6.353598906562697266e+02 5.831930225529338259e+02 +4.217611810111051227e+02 7.257402235425252002e+02 6.091756755608221283e+02 6.154197488242539293e+02 +2.859597798652843039e+02 5.743844768910154244e+02 3.473335246063475097e+02 3.928539205129500260e+02 +2.425870856667291378e+02 5.561380331247266895e+02 1.197322703151500036e+02 3.878184945330563096e+02 +2.022056117577295424e+02 5.366951012426158059e+02 3.312294346097746711e+01 3.858043241410987321e+02 +2.798073496144208434e+02 5.806978845794687913e+02 2.953217305659241561e+02 4.074589025475969493e+02 +2.585312994386380296e+02 5.667866210029953891e+02 1.757320048165831849e+02 3.942236759824524484e+02 +2.620773078012684891e+02 5.866988218085356266e+02 2.083473845664034343e+02 4.362927889930901983e+02 +2.499390484061103734e+02 5.776974159649353169e+02 1.360263251211497959e+02 4.268390557322727545e+02 +2.538995865089028712e+02 3.916563015146545581e+02 1.363968836700531710e+02 1.849164094976345041e+01 +2.400746446444624667e+02 3.985687724468747319e+02 1.254220968370807441e+02 5.580591618186952019e+01 +3.045910400118511348e+02 4.206886794299794587e+02 4.744203181256031030e+02 3.714877856581654214e+01 +2.995218946615563027e+02 4.607810108368566944e+02 4.535682231429553894e+02 1.227521158630012224e+02 +2.004431446330665381e+02 3.985687724468747319e+02 4.420867427308530750e+01 1.029974995636508766e+02 +2.124247609155815724e+02 3.815180108140648372e+02 5.518346110605773447e+01 5.361095881527501206e+01 +1.981389876556598040e+02 4.252969933847928701e+02 3.652632349000464274e+01 1.578714337285127840e+02 +2.064339527743240410e+02 5.008733422437339868e+02 3.103893007351837241e+01 2.994461838738567963e+02 +2.105814353336561737e+02 4.861267375883307977e+02 2.445405797373496171e+01 2.632293873250479237e+02 +2.474479469721640328e+02 4.497210573453043025e+02 1.034725231711362312e+02 
1.468966468955404707e+02 +4.037349278460346795e+02 8.149293409312576841e+02 3.764499032891841352e+02 8.998383716029238713e+02 +3.392049459191438814e+02 7.821174857141945722e+02 1.668952560366415128e+02 8.572170874159660343e+02 +3.935267951118372594e+02 7.883152805885287080e+02 3.080782599059392624e+02 8.252511242757477703e+02 +3.268093561704755530e+02 7.795654525306451887e+02 1.180583679057523341e+02 8.616568045187741518e+02 +3.286322370158679860e+02 7.591491870622503484e+02 1.180583679057523341e+02 7.968369348177757274e+02 +4.011828946624853529e+02 7.573263062168580291e+02 2.858796743918987886e+02 7.204738006494764022e+02 +3.475901978079489254e+02 7.518576636806808438e+02 1.438087271020393700e+02 7.577674243130644527e+02 +3.402986744263793071e+02 7.423786832846403740e+02 1.278257455319301243e+02 7.364567822195855342e+02 +3.282676608467895107e+02 7.358163122412277062e+02 8.520446134497240109e+01 7.320170651167774167e+02 +3.068108708836140295e+02 6.952626313721491442e+02 1.212063200657838706e+02 6.466681953866445838e+02 +2.906379928440490517e+02 5.406954008811891299e+02 3.609134926918036399e+02 3.078481908902177793e+02 +2.981969097652342384e+02 5.297006126321925876e+02 3.754353015445329902e+02 2.742187388102130399e+02 +2.820483145245204355e+02 5.602798674497142883e+02 3.685565499827138183e+02 3.674640377593170797e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..4e030d242bb4adec67a3a3b2eed901416184964f --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/crossval_errors.txt @@ -0,0 +1,50 @@ +4.852766667173328829e-01 +5.773447457216664136e-02 +9.383705325914409867e-01 +3.297959889110597231e+00 +2.931178080395086649e+00 +1.276087969696925839e+00 +1.243149818607734325e+00 +3.525279717474206986e+00 +3.009354776434531864e+00 +2.895347087318727475e+00 +3.539998841178534139e+00 +2.350301461569856531e+00 +9.330662437729901892e-01 +7.953587772446200910e-02 +2.113451162925480187e+00 +3.793867642603225399e+00 +2.037081321239481713e-01 +1.528024863707494685e+00 +6.803045160843667283e+00 +1.168806582172539699e+01 +2.174503807975670888e+00 +2.372178250210668082e+00 +9.428519074525644195e-01 +2.516024716133358208e+00 +3.173357005405610942e-01 +2.952660548568692978e+00 +1.363268908279729741e-01 +6.417730434506982995e+00 +2.527729278270938323e-02 +1.552720762154066403e+00 +5.804936135579450429e+00 +1.385047607413002257e+00 +2.062289662207262175e+00 +1.759740112676018597e+00 +4.297790461315768695e+00 +3.494339002420721041e+00 +7.779047508699438174e-01 +6.133546870807665918e+00 +1.148380196528532515e+00 +1.967728399278650508e+00 +5.405149290397431194e+00 +7.113500860836646789e-01 +1.241186207976585854e+00 +1.504359218668552267e-02 +1.474388622498197021e+00 +3.022153888227849805e+00 +5.445628133836079299e-02 +3.960920553836733138e-01 +4.612436046632942821e-01 +2.592764738552643777e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..27c87b2ee899e9b6aa4d8c3425d5699f448bad8c Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/.DS_Store differ diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/01.png new file mode 100644 index 
0000000000000000000000000000000000000000..477037d16e537445f82fed5ea99ddf3bdb1a96be --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad67e56da8504b8bd3f95b72e736b9d06def8672ac948f05f4f2204d187034e4 +size 760888 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/02.png new file mode 100644 index 0000000000000000000000000000000000000000..18f7fb65fe31d9058c6d5070c78bb3e696a481b3 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81140ccca5518980d3feb3ec30dd56cfe3b6fbde661d89f6e551bd4d1e889168 +size 673034 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..be564efc8d3bd58e5330cf5227328a0510c23fb3 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/corrs.txt @@ -0,0 +1,34 @@ +5.632709800000000087e+02 1.313976600000000019e+02 2.893448599999999828e+02 2.536246800000000121e+02 +5.506097899999999754e+02 1.301081000000000074e+02 2.860623299999999745e+02 2.525695900000000051e+02 +4.947821900000000142e+02 1.264932299999999969e+02 2.713383099999999786e+02 2.487299399999999991e+02 +4.522330799999999726e+02 1.441868100000000084e+02 2.543481299999999976e+02 2.505078000000000031e+02 +3.927916500000000042e+02 1.576599899999999934e+02 2.347196800000000110e+02 2.504593900000000133e+02 +3.001216800000000262e+02 1.267567499999999967e+02 2.091220199999999920e+02 2.363021699999999896e+02 +6.090455600000000231e+02 1.744962800000000129e+02 3.033945299999999747e+02 2.699415500000000065e+02 +4.882655399999999872e+02 2.175810700000000111e+02 2.666683499999999754e+02 2.783823499999999740e+02 +7.410004500000000007e+02 2.281611200000000110e+02 3.389677399999999921e+02 2.931838300000000004e+02 +4.153606199999999831e+02 2.576769300000000271e+02 2.458869300000000067e+02 2.926073499999999967e+02 +5.188150200000000041e+02 1.337616899999999873e+02 2.764965799999999945e+02 2.524814000000000078e+02 +6.467946899999999459e+02 1.982946399999999869e+02 3.093734299999999848e+02 2.772100199999999859e+02 +2.725054383200736652e+02 1.379274566053054514e+02 2.004945675835598990e+02 2.362944654439720296e+02 +2.693994122349491249e+02 1.472455348606790722e+02 1.980667729146825877e+02 2.411500547817267943e+02 +4.464826675509639813e+02 1.676949277233762814e+02 2.507904124676491620e+02 2.574015523786398489e+02 +3.146731428626089269e+02 1.491177463914604004e+02 2.128540067898943562e+02 2.443481654787672142e+02 +5.137143714188497370e+02 1.730026938182093090e+02 2.736338395424262444e+02 2.639282458285761663e+02 +5.632535216372918967e+02 1.650410446759597107e+02 2.879109814641619209e+02 2.639282458285761663e+02 +2.863650570234991619e+02 1.655093765814348217e+02 2.046956399774740021e+02 2.474195504486096411e+02 +1.077564275177346786e+02 1.663604524797651720e+02 1.403882031816047515e+02 2.366652162690947137e+02 +2.072420979386583042e+02 1.381041673897987039e+02 1.768318838209144133e+02 2.330608962058662996e+02 +1.672123607278724648e+02 1.434022208441674593e+02 1.616136435539499416e+02 2.318594561847901332e+02 +2.534226267199915483e+01 2.028581540543052029e+02 1.067478825914727736e+02 2.418714563604246734e+02 +1.683897059399544105e+02 1.701348880022814569e+02 1.616136435539499416e+02 
2.416358834803155844e+02 +1.370441600788676624e+02 1.634171506956903954e+02 1.503214646539008186e+02 2.360283531679565954e+02 +3.235426251126277748e+02 1.920638149445047986e+02 2.137522429888200293e+02 2.553334511431434635e+02 +2.835902178531248978e+02 1.889905528476199663e+02 2.029687883345565069e+02 2.553334511431434635e+02 +2.594431585204583826e+02 1.911857400596805689e+02 1.961305000172186510e+02 2.550704400540150800e+02 +2.541747092115129476e+02 2.052349382168683860e+02 1.924483447694213680e+02 2.600676507474543087e+02 +3.015907529920217485e+02 2.096253126409895629e+02 2.090180433845091272e+02 2.645388392626367704e+02 +4.578482005041917660e+02 2.124169269111976064e+02 2.558986741134780800e+02 2.746111778115074458e+02 +5.788114731647209510e+02 2.259128292328269083e+02 2.944315520780464226e+02 2.855384118611611939e+02 +5.383237661998331305e+02 1.989210245895683329e+02 2.823540828652712662e+02 2.731733838576056428e+02 +5.978057060618289142e+02 2.034196586967780718e+02 2.975946987766304233e+02 2.763365305561895866e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..7533d4e3db031e37eed91a7e7883687beebb63ee --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/crossval_errors.txt @@ -0,0 +1,34 @@ +1.035615702231201585e+00 +1.196879631355693974e+00 +6.112053699849899768e-01 +1.006642457920902262e-01 +3.007740766737664884e-01 +1.879998269793030863e+00 +1.811132848182660293e+00 +3.217420525980498436e+00 +6.823652600275327551e-02 +3.178430665916133491e-02 +2.209381490455120112e+00 +2.370019498457026419e+00 +6.773946495656722355e+00 +3.532300831263593288e+00 +2.910156346175509778e+00 +2.171179401819538946e+00 +2.810952725959120269e+00 +2.118168175101782680e+00 +2.200726898729873926e-01 +1.483904459267001741e+01 +1.082194779711510346e+00 +1.053720698982808024e+01 +5.646140718108199508e+00 +9.828258767995476930e+00 +1.611215092864966536e+01 +1.013231868435379290e+01 +1.197727714496128693e+00 +2.927725272338077112e+00 +3.251141531427588660e+00 +3.606444000992193200e+00 +4.698983654702942658e+00 +9.882536248838456050e-01 +1.239308029690608270e+00 +4.261074506023826203e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/01.png new file mode 100644 index 0000000000000000000000000000000000000000..5a161f26bbc1b77cb2eed80c2c211fc0db7f74ba --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bcc33b07aebaf526136256c3ef9f4ea10f45e03ebd495b44a7b0cce02946674c +size 973773 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/02.png new file mode 100644 index 0000000000000000000000000000000000000000..da48563252d1b9c6c09d9f446098a40aa3405765 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e11e32c52622d12eaa12787ab7635eff03683301a947f17c6db7fee3594baf2 +size 753067 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..b50bbd7590e4aaba33bff7d776da4e8881fd38e4 --- /dev/null +++ 
b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/corrs.txt @@ -0,0 +1,32 @@ +2.360286199999999894e+02 9.481852299999999900e+01 2.781969799999999964e+02 7.274162200000000666e+01 +4.471002300000000105e+02 1.218223899999999986e+02 4.180615500000000111e+02 8.753985900000000697e+01 +7.609214600000000246e+02 2.900361500000000206e+02 6.275490600000000541e+02 1.801840300000000070e+02 +1.597529800000000080e+02 2.783806000000000154e+02 2.159903700000000129e+02 1.841312299999999880e+02 +4.597864999999999895e+02 3.497691100000000120e+02 4.174805000000000064e+02 2.315654000000000110e+02 +4.714313300000000027e+02 2.345709699999999884e+02 4.179397500000000036e+02 1.500723600000000033e+02 +3.982929700000000253e+02 3.084224100000000135e+02 3.645750899999999888e+02 1.984153599999999926e+02 +2.304207800000000077e+02 1.884263599999999883e+02 2.663360099999999875e+02 1.280013800000000117e+02 +6.925498999999999796e+02 1.784217600000000061e+02 5.756890600000000404e+02 1.011625700000000023e+02 +6.969950800000000299e+02 3.943423099999999977e+02 5.771937299999999595e+02 2.589903800000000160e+02 +2.236513400000000047e+02 3.598176199999999767e+02 2.519356899999999939e+02 2.340100099999999941e+02 +4.623742799999999988e+02 4.189536899999999946e+02 4.180279300000000262e+02 2.751502500000000282e+02 +4.317743100000000140e+02 7.095999299999999721e+01 4.198057999999999765e+02 6.238532299999999964e+01 +4.414981000000000222e+02 3.498863400000000183e+02 4.029435700000000224e+02 2.332066700000000026e+02 +4.704547299999999836e+02 3.528171699999999760e+02 4.251006600000000049e+02 2.340272999999999968e+02 +2.306552499999999952e+02 1.557182699999999897e+02 2.644602800000000116e+02 1.066649199999999951e+02 +2.063879599999999925e+02 1.704896699999999896e+02 2.509784500000000094e+02 1.153401900000000069e+02 +2.550397800000000075e+02 1.704896699999999896e+02 2.761836099999999874e+02 1.149884899999999988e+02 +7.186847736100646671e+02 1.612839620694029463e+02 5.970815920021781267e+02 8.742081692709524532e+01 +6.520464260002647734e+02 1.600024553845991022e+02 5.450039535103417165e+02 9.082776523964528792e+01 +7.242379692442146961e+02 1.873412646604144527e+02 5.995151265111425118e+02 1.066457395479133794e+02 +4.569601499328094860e+02 4.444818995056880340e+02 4.138725246013769379e+02 2.920795103148053613e+02 +4.246167624506033462e+02 3.411289695813454728e+02 3.828439599583084032e+02 2.230627029778865733e+02 +5.301583426556967424e+02 3.416153363104012897e+02 4.538906733933718556e+02 2.221927432215388762e+02 +6.667335001808941115e+02 1.035710570582327819e+02 5.605959728884329252e+02 5.568705899556496774e+01 +2.431912216314050852e+02 3.232394024770093779e+02 2.702069573547900063e+02 2.114583240039939938e+02 +6.277595469798195609e+02 1.869733877933344957e+02 5.310581761039995854e+02 1.130581252647631914e+02 +7.397337531878540631e+02 2.504353068495012735e+02 6.119037385776150586e+02 1.497963717166607296e+02 +4.698633348550070536e+02 2.115440656563695256e+02 4.163448779995451901e+02 1.327446834734667220e+02 +4.874772549326754074e+02 1.687674026106035399e+02 4.422008912230889450e+02 1.149003926572182763e+02 +3.245484942142433056e+02 1.505244139587327936e+02 3.507944219398570453e+02 1.152645618575498929e+02 +1.490383620117622741e+02 3.789503133324984105e+02 2.135026334148435865e+02 2.485933319605100849e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/crossval_errors.txt new file mode 100644 index 
0000000000000000000000000000000000000000..0224275375750f200ecedbc5425595014ed1e3c4 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/crossval_errors.txt @@ -0,0 +1,32 @@ +5.191993455264680424e-02 +1.465544064612110287e+00 +5.936445158653792520e-01 +9.030548139514468220e-01 +4.941321813867107782e-01 +8.177992914893557064e-01 +2.908001862198693832e+00 +1.591621936844796747e-01 +3.293545072056112133e-01 +1.623973304517367655e+00 +1.129899579273499599e+00 +3.384341811700881220e+00 +3.831337751029900041e+00 +5.062077129708542955e+00 +1.124576558576982688e+00 +9.286197303745061804e-01 +3.181643372606760334e+00 +3.213996797500379365e+00 +1.887924995137213013e+00 +1.325644623510203290e+00 +1.121749737644925915e+00 +3.512606009773121318e+00 +2.341729310650397800e+00 +2.537510842859704852e+00 +3.453915924287136896e+00 +2.267845496619912193e+00 +2.052534989207430272e+00 +3.696536742617343219e+00 +3.207315167365495157e+00 +8.871948800972830895e-01 +1.100981285460365527e-01 +2.052016565246739255e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/01.png new file mode 100644 index 0000000000000000000000000000000000000000..615009b8f25cc3e32c086e37e6f4382cb14afbf1 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ffd41d1811e888a0a0c842b67735eb00aa5184558a48949783f107fb41616df +size 402968 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/02.png new file mode 100644 index 0000000000000000000000000000000000000000..24681732835bf9bcf421dac2d071a0b5addaae7f --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81c79e7aad63e61801dcbe3d2acd0fdb5d6389f1d36c3afcfd651926c1848ebb +size 607963 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..701ba26d2e27c0c1087836983d9f5bd3bd59174a --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/corrs.txt @@ -0,0 +1,40 @@ +1.958494000000000028e+02 6.222665799999999692e+00 3.766220999999999890e+02 4.925415799999999678e+01 +3.864019000000000119e+02 2.038731400000000065e+02 4.052995300000000043e+02 1.847192300000000103e+02 +3.676445800000000190e+02 1.994182800000000100e+02 3.763429100000000176e+02 1.814366899999999987e+02 +2.598528599999999855e+01 2.009665200000000027e+02 1.552787399999999991e+02 2.763817199999999730e+02 +1.751416399999999953e+02 1.976200100000000077e+02 1.787758900000000040e+02 2.392446999999999946e+02 +2.103825199999999995e+02 1.999307900000000018e+02 1.993615700000000004e+02 2.291325500000000090e+02 +1.947942999999999927e+02 1.565561999999999898e+02 3.251428599999999847e+02 2.015612900000000138e+02 +1.921269800000000032e+02 2.566903699999999731e+02 1.854097700000000088e+02 3.081927499999999895e+02 +2.822226699999999937e+02 2.667928400000000124e+02 3.187434400000000210e+02 3.170346900000000119e+02 +3.175098899999999844e+02 2.493250800000000140e+02 3.567270300000000134e+02 2.835059699999999907e+02 +2.488111800000000073e+02 2.758197999999999865e+02 2.940072200000000180e+02 3.350886199999999917e+02 +2.768154099999999858e+02 
3.384818599999999833e+02 3.104392500000000155e+02 4.314336299999999937e+02 +2.413777599999999950e+02 2.045910700000000020e+02 2.937018600000000106e+02 2.384143899999999974e+02 +3.662948400000000220e+02 2.708717899999999759e+02 3.770028500000000236e+02 3.046951099999999997e+02 +1.327947999999999951e+01 2.718929999999999723e+02 1.429682099999999991e+02 3.506149399999999901e+02 +1.903684800000000052e+02 3.095667300000000068e+02 1.749459200000000010e+02 3.826023400000000265e+02 +1.946819099999999878e+02 9.647915500000000577e+01 3.337697099999999750e+02 1.337759000000000071e+02 +1.952680799999999977e+02 5.459955200000000275e+01 3.610968199999999797e+02 9.506643599999999594e+01 +1.860018000000000029e+02 1.009136499999999970e+01 3.681813099999999963e+02 5.652262199999999837e+01 +3.155240499999999884e+02 2.760003899999999817e+02 3.526113700000000222e+02 3.268587299999999800e+02 +3.825456899999999791e+02 2.486293200000000070e+02 3.975715099999999893e+02 2.614097300000000246e+02 +2.244132299999999987e+02 9.858509200000000305e+01 3.674812400000000139e+02 1.278130100000000056e+02 +2.366054900000000032e+02 1.043295200000000023e+02 4.054648300000000063e+02 1.343780800000000113e+02 +2.180582805126952053e+02 2.673983552561883243e+02 2.153567351223074979e+02 3.225453098178572873e+02 +2.191047260962589576e+02 2.520504866972534046e+02 2.158411780145567604e+02 3.031675941278855930e+02 +2.187559109017377068e+02 3.092561785987379608e+02 2.076056488463187861e+02 3.816473426722710656e+02 +2.476600803710741729e+02 2.496246586406160759e+02 2.962334247218788050e+02 2.989252798449094826e+02 +2.762871664014870703e+02 2.492479864560053784e+02 3.184615294427110257e+02 2.939488384894992805e+02 +2.864573153859758463e+02 2.575347745174407237e+02 3.257603100973126402e+02 3.042334839573470049e+02 +2.480367525556848705e+02 2.571581023328300262e+02 2.959016619648515416e+02 3.098734508268119043e+02 +3.039725719703732238e+02 2.496246586406161327e+02 3.430119734627346588e+02 2.879771088630070608e+02 +1.931669929845307934e+02 1.731478943608864256e+02 3.276824680474070419e+02 2.227384219525362141e+02 +1.931669929845307934e+02 1.671728696567184898e+02 3.248873840844720462e+02 2.149121868563181579e+02 +2.208014822413075535e+02 1.753885286249494015e+02 3.623415091878012504e+02 2.193843211970141738e+02 +2.252827507694335338e+02 1.682931867887499777e+02 3.612234756026272180e+02 2.093220189304481664e+02 +2.472352169056881053e+02 3.121931499331572013e+02 2.880808778718156304e+02 3.891075898947697169e+02 +2.754026383392823618e+02 3.121931499331572013e+02 3.105887966431296832e+02 3.881498061172669622e+02 +2.858350166480209396e+02 3.124539593908756387e+02 3.211244181956597004e+02 3.881498061172669622e+02 +2.521905966023389283e+02 3.234079566150512051e+02 2.943064724255833653e+02 4.068265897785701100e+02 +3.922452753971548418e+02 2.490772611652885473e+02 4.145083365021756663e+02 2.622012393756584743e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..9a806d88120fb851486d0a0f937bf4095df7b301 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/crossval_errors.txt @@ -0,0 +1,40 @@ +1.619268508549998931e+00 +3.886848298702200832e+00 +2.588669894928778081e+00 +4.542110231958436106e-01 +3.446516145153667221e+00 +3.679736430290312832e+00 +8.399536963373055790e-03 +2.585070026123002407e+00 +6.175314108689570203e-01 
+9.191016794758739561e-01 +1.035661036366828158e+00 +1.060660686176270406e+00 +6.119007703997929593e-01 +1.708246958790519354e+00 +6.962075281211588251e-01 +6.766506035606598690e-01 +1.950827625135682419e+00 +1.401927986179450983e-02 +7.156366045273916676e-02 +3.324444405591722163e-01 +6.130206632473679251e-01 +6.648246702775817418e-01 +1.672542182573007796e+00 +3.336871484617892625e-01 +3.435317825533382941e+00 +1.453007621079772382e-01 +3.963916321132272130e-02 +1.382163780362680061e+00 +1.401298777655322736e+00 +5.694420338034895668e-01 +4.020510732517361130e-01 +1.105574600334674251e+00 +5.122245513220390345e-01 +1.421834385649536903e-01 +9.107395820013965970e-01 +5.779344119513706302e-01 +4.792392571718611105e-01 +6.734883886232745365e-01 +1.843982514331961031e-01 +2.359130939770464508e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/01.png new file mode 100644 index 0000000000000000000000000000000000000000..69f2e2915d7f69643463008d672b15481da85b6a --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7ca21ddaf898862d1241f55198ea8bf660b77213ca0a3a6929064358fe74e34 +size 493636 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/02.png new file mode 100644 index 0000000000000000000000000000000000000000..05dd88b5d711c1c21952cd970a1d579ae5db92f0 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:729bbec0906acedc9a489d40c73d13d101777bb9d2c2ee48b4fa901762063407 +size 341760 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..6bdfad54d2f1b131de03c827d8c1db96833e32b3 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/corrs.txt @@ -0,0 +1,32 @@ +5.782488299999999981e+02 3.321297799999999825e+02 2.645256800000000226e+02 2.530823000000000036e+02 +5.151237200000000485e+02 3.282901299999999765e+02 1.738835500000000138e+02 2.547565199999999948e+02 +5.437307200000000194e+02 2.607585700000000202e+02 2.158755600000000072e+02 1.678416599999999903e+02 +4.656277600000000234e+02 1.522212900000000104e+02 1.256151799999999952e+02 5.771545400000000114e+01 +5.945356100000000197e+02 1.904791199999999947e+02 3.027003100000000018e+02 7.952084600000000592e+01 +6.255754600000000210e+02 2.941043500000000108e+02 3.369732999999999947e+02 2.012910899999999970e+02 +4.581635600000000181e+02 2.818820000000000050e+02 9.779877999999999361e+01 2.046366099999999904e+02 +4.146572199999999953e+02 3.733990099999999757e+02 6.323125600000000190e+01 3.186344300000000089e+02 +5.010740599999999745e+02 4.452488299999999981e+02 4.459785200000000316e+01 4.015304100000000176e+02 +5.371559700000000248e+02 3.425732199999999921e+02 2.075248799999999960e+02 2.702041899999999828e+02 +5.083822900000000118e+02 3.983911400000000071e+02 1.642316299999999956e+02 3.398770099999999843e+02 +5.221781200000000354e+02 3.980588000000000193e+02 1.852397300000000087e+02 3.401686300000000074e+02 +5.729335600000000568e+02 3.989966699999999946e+02 2.583335900000000152e+02 3.389691799999999944e+02 +6.058200900000000502e+02 3.091112499999999841e+02 3.082761899999999855e+02 2.216287499999999966e+02 
+4.797849699999999871e+02 2.989894100000000208e+02 1.290178600000000131e+02 2.216258300000000077e+02 +6.343098599999999578e+02 3.721481699999999933e+02 3.581887500000000273e+02 3.003527300000000082e+02 +5.014257600000000252e+02 4.854920099999999934e+02 4.433334700000000339e+01 4.526695399999999836e+02 +3.656204999999999927e+02 3.221445400000000063e+02 2.374288500000000024e+00 2.647601399999999785e+02 +4.079351500000000215e+02 2.986570800000000077e+02 8.063225099999999657e+01 2.342744999999999891e+02 +5.545559500000000526e+02 2.884277000000000157e+02 2.333967100000000130e+02 2.012610500000000116e+02 +5.073271899999999732e+02 2.014827899999999943e+02 2.081594499999999925e+02 1.116769400000000019e+02 +4.988272600000000239e+02 2.918962799999999902e+02 1.504106199999999944e+02 2.125329600000000028e+02 +6.400349200000000565e+02 4.073600000000000136e+02 3.692232700000000136e+02 3.471842700000000264e+02 +4.092644799999999918e+02 3.879359299999999848e+02 5.773772699999999958e+01 3.332048700000000281e+02 +4.194637799999999856e+02 3.879359299999999848e+02 6.793702399999999386e+01 3.333220999999999776e+02 +3.642394613790673361e+02 3.767043016888868578e+02 6.488157981007134367e-01 3.242811770475954631e+02 +3.751359519042346164e+02 3.569291892543239442e+02 1.372472163530062517e+01 3.012426762868147421e+02 +3.844181475367845451e+02 3.476469936217740155e+02 2.368731655888154819e+01 2.887894326323386167e+02 +4.130718818807430353e+02 3.658078111637195207e+02 6.166970970503348326e+01 3.087146224795003491e+02 +5.707847220192514897e+02 2.695261601368335391e+02 2.558068126731624261e+02 1.733098156215094718e+02 +6.054041967536991251e+02 2.711417356244410826e+02 3.075495409571985874e+02 1.757406820509608281e+02 +6.176364111598705904e+02 2.692953636386038738e+02 3.252601392289156479e+02 1.715734824576156257e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..38d7fa5119ff211ac4f06a1f7cd1eee30ea7534f --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/crossval_errors.txt @@ -0,0 +1,32 @@ +8.715988390665808661e-01 +1.187855777575472560e+00 +1.488191769406739251e+00 +1.540662910329874125e+00 +2.503540793373755236e+00 +2.841651356896309100e+00 +2.629872184775364219e-01 +3.090908285052762494e+00 +2.992049923663860245e+00 +9.722442946457870994e-01 +4.463918797905627178e-01 +1.835074225796184066e+00 +2.557828320585105786e+00 +1.158792722156904764e+00 +1.650155615904922879e+00 +1.934096082935798488e+00 +4.625630744759416935e+00 +4.197963781019461216e+00 +5.461443548326146979e-01 +1.679977028297218733e-01 +9.915869307461986359e-01 +1.248967533235761707e+00 +3.646619983191577319e+00 +9.866973171593411696e-01 +2.850378296718754090e-01 +1.203073091924289351e+00 +1.284716092338700877e+00 +4.604175489226451368e-01 +1.112167057576260554e+00 +3.792437685853049967e+00 +3.502812721344579217e+00 +3.510454335316306018e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/01.png new file mode 100644 index 0000000000000000000000000000000000000000..66146c46a4400256843da6a8287d1d2556dea107 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0fe6a0defd25979bb53bd4d7842c3b6e2730487764122463bcbee3e45fdf8f4d +size 597486 diff --git 
a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/02.png new file mode 100644 index 0000000000000000000000000000000000000000..ec411be52ea5df3a9af7358be8640d44f61370f6 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:430b8c56acf2d7032aa3b4774784d82a5846ac3cdd2af66fa4b2386fd477900e +size 705899 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..e1f60a837fe1399550c2af4081806d6ae82fffd4 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/corrs.txt @@ -0,0 +1,35 @@ +3.223486300000000142e+02 2.839303400000000011e+02 2.607185000000000059e+02 2.346901599999999917e+02 +3.738558800000000133e+02 3.103271899999999732e+02 2.862871099999999842e+02 2.504682400000000086e+02 +5.751319499999999607e+02 2.281521899999999903e+02 4.111983200000000238e+02 2.059368800000000022e+02 +6.758038699999999608e+02 2.683759999999999764e+02 4.804883699999999749e+02 2.325014600000000087e+02 +6.953925500000000284e+02 1.296883400000000108e+02 5.107850500000000125e+02 1.445208199999999863e+02 +4.742061899999999923e+02 2.980799700000000030e+01 3.646923199999999952e+02 6.850869000000000142e+01 +2.896083800000000110e+02 7.866834099999999808e+01 2.492576500000000124e+02 9.834715400000000329e+01 +1.757568699999999922e+02 9.239535600000000670e+01 1.770270700000000090e+02 1.072966499999999996e+02 +9.089311700000000371e+01 1.456234200000000101e+02 1.333346799999999917e+02 1.387957499999999982e+02 +4.809985499999999803e+01 3.841322900000000118e+02 1.161294000000000040e+02 2.905360800000000268e+02 +1.514875099999999861e+02 5.156977799999999661e+02 1.701780899999999974e+02 3.710622099999999932e+02 +2.158129700000000071e+02 4.196958200000000261e+02 2.132144700000000057e+02 3.131641799999999876e+02 +1.912347800000000007e+02 3.351159799999999791e+02 1.951712599999999895e+02 2.637563999999999851e+02 +4.668785899999999742e+02 3.555844299999999976e+02 3.098036299999999983e+02 2.781965000000000146e+02 +5.274342500000000200e+02 3.454529099999999744e+02 3.402573100000000181e+02 2.732823900000000208e+02 +4.883730899999999906e+02 1.831799100000000067e+02 3.679458000000000197e+02 1.750723600000000033e+02 +1.623815600000000074e+02 2.169064400000000035e+02 1.714579799999999921e+02 1.866988599999999963e+02 +5.728382100000000321e+01 2.690503600000000120e+02 1.178298000000000059e+02 2.197908099999999934e+02 +6.048134099999999762e+02 2.930939000000000192e+02 4.289413499999999999e+02 2.442742399999999918e+02 +6.595461400000000367e+02 4.517810799999999745e+02 4.601274799999999914e+02 3.389769499999999880e+02 +7.747463400000000320e+02 3.500443799999999896e+02 5.464174000000000433e+02 2.836386800000000221e+02 +3.286211299999999937e+02 3.751343800000000215e+02 2.635514600000000200e+02 2.878494000000000028e+02 +4.516448399999999879e+02 1.776312200000000132e+02 3.483066400000000158e+02 1.717898299999999949e+02 +2.345733999999999924e+02 5.146814100000000280e+02 2.136747499999999889e+02 3.677011800000000221e+02 +6.956982956609938356e+02 4.448380497371611000e+02 4.897555306141204028e+02 3.345909281372238979e+02 +6.850766939846554351e+02 3.374418550097398111e+02 4.905243322592649520e+02 2.730867965256645675e+02 +7.794909311076631866e+02 4.082525328519956247e+02 5.458780507096682868e+02 3.153708870086115894e+02 
+7.755681227403512139e+02 3.093362791412923798e+02 5.515505328082300593e+02 2.611494394884711028e+02 +7.633183634593142415e+02 2.678232060222228483e+02 5.455192559794936642e+02 2.326013958324519422e+02 +8.911064507318195638e+01 3.751099451192026208e+02 1.393658564074330002e+02 2.864113796803452487e+02 +1.193134095552586729e+02 3.668285419547622155e+02 1.526517208245552411e+02 2.806802224808023993e+02 +1.422090535981232620e+02 4.822810448943134816e+02 1.659375852416775388e+02 3.491936017298837100e+02 +1.378247813345959969e+02 4.939724375970528740e+02 1.638535280782073755e+02 3.562272946565955181e+02 +7.790639373305670290e+01 4.257726468310732457e+02 1.193068062090328141e+02 3.174117299869639055e+02 +5.257504287712141888e+01 4.272340709189156769e+02 9.820572742889750373e+01 3.174117299869639055e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..6bfcc6f9edeb89e6ef7afa5a15d6581b2c8d85cb --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/crossval_errors.txt @@ -0,0 +1,35 @@ +2.043379849940813120e+00 +8.169318573230145430e-01 +3.638987395587739204e+00 +2.649559860979942894e+00 +9.965473481391702304e-01 +7.991661399239784025e+00 +3.273218782217096612e-01 +6.738468738498821331e+00 +7.367010900988238964e-01 +1.904385438409952336e-01 +7.762423758884946956e+00 +2.328785188654431959e+00 +1.671697819648116257e+00 +1.777887754345692795e+00 +1.656063175527794584e+00 +2.665111671711846153e+00 +2.313613197174701597e+00 +1.391896839663567587e+00 +1.051739843204129166e+00 +3.623669403787067367e+00 +3.950511428855001395e+00 +2.630721197748733697e+00 +5.315878363056087963e+00 +1.207551898246299604e+00 +2.257441047339381335e-01 +4.487693040202633266e-01 +1.532589474507056217e-02 +7.002406873062206216e+00 +3.276273526147291948e-01 +1.302720833184545457e+00 +9.656081119072771335e-01 +1.766602521388588656e-02 +9.723366725220530249e-01 +4.123774285849783894e+00 +3.611205377490628088e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/01.png new file mode 100644 index 0000000000000000000000000000000000000000..71ba98f2229cb32a4c44658dad6063580745c0d1 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1c8df896d59d5afe51af5c2cd9e94c20fc41272179aa8118a4056092f92c74c +size 822512 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/02.png new file mode 100644 index 0000000000000000000000000000000000000000..c3c8bb78563d97117b17de9da23a6b2e07e89788 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3563b4ef12305c69c89249264f4d66c72c7cd0558da2fbf5e6c73a2d91562f2 +size 697851 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/03.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/03.png new file mode 100644 index 0000000000000000000000000000000000000000..51128df81280589a20ea36fb4cb8c350933e0036 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/03.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1ed226c9076112816b88b2495fbdfa6ef9b8cdba1bf89980e31c3c71722247c +size 730090 diff --git 
a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c73430138216837f251073fca764b5cd41e7aa3 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/corrs.txt @@ -0,0 +1,27 @@ +4.280615799999999922e+02 1.554128499999999917e+02 4.518512200000000121e+02 2.504033300000000111e+02 +4.847582300000000259e+02 1.798080899999999929e+02 5.348963999999999714e+02 2.607818399999999883e+02 +2.941127000000000180e+02 1.590083500000000072e+02 2.959175900000000183e+02 2.538389400000000080e+02 +2.353940299999999866e+02 1.586372800000000041e+02 2.221174000000000035e+02 2.382013699999999972e+02 +2.054092899999999986e+02 1.781485000000000127e+02 1.566796099999999967e+02 2.132441300000000126e+02 +1.970556400000000110e+02 1.930091099999999926e+02 1.424913999999999987e+02 2.384619700000000080e+02 +5.121046999999999798e+02 1.857589800000000082e+02 6.022913099999999531e+02 2.330498299999999858e+02 +2.384227300000000014e+02 3.142075800000000072e+02 9.785242499999999666e+01 3.494049200000000042e+02 +4.045547500000000127e+02 2.504769400000000132e+02 4.488389900000000239e+02 3.224188800000000015e+02 +3.780600299999999834e+02 2.505941699999999912e+02 4.038214100000000144e+02 3.221844100000000140e+02 +2.540943000000000040e+02 1.889834099999999921e+02 2.462995600000000138e+02 2.800431800000000067e+02 +4.550273099999999999e+02 1.833949499999999944e+02 4.811752000000000180e+02 2.812610599999999863e+02 +4.109541699999999764e+02 1.833164400000000001e+02 4.311599299999999744e+02 2.832995700000000170e+02 +3.209515099999999848e+02 1.857987500000000125e+02 3.270717599999999834e+02 2.852663999999999760e+02 +3.323552999999999997e+02 2.624831500000000233e+02 3.413861299999999801e+02 3.458199900000000184e+02 +4.950069700000000239e+02 2.153823399999999992e+02 5.528816199999999981e+02 3.004902900000000159e+02 +3.173656899999999723e+02 3.218481499999999755e+02 2.990111999999999739e+02 3.575231400000000122e+02 +3.194274800000000027e+02 2.037354400000000112e+02 3.261338900000000081e+02 3.078924200000000155e+02 +2.290635804859078348e+02 1.900059737384212326e+02 2.118688394443077527e+02 2.747577352684748462e+02 +2.183714045427048518e+02 1.799618084584426754e+02 2.000627579024160241e+02 2.638261782852417809e+02 +2.433198150768451455e+02 1.994021283551753640e+02 2.306711174554686181e+02 2.909364396036597782e+02 +2.373870513700753691e+02 2.957831966989249395e+02 9.725213989465640907e+01 3.038423278124165563e+02 +5.110142288247588453e+02 1.927576376379860790e+02 6.016781362971241833e+02 2.409929073642294952e+02 +4.989069086051293880e+02 1.915706454595910202e+02 5.583054420985690740e+02 2.691103780860514689e+02 +2.920441067044379224e+02 3.204059275408216081e+02 2.264041560013508274e+02 3.558669926242116617e+02 +3.894701654697320805e+02 2.555707881893844444e+02 4.243882654767081135e+02 3.327558240056685008e+02 +3.573868057796908602e+02 3.125699484897768912e+02 4.077506998783375707e+02 3.483199982751119705e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..214a9e6dc7aa32cf5b5de3fbddc9a24173f7062f --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/crossval_errors.txt @@ -0,0 +1,27 @@ +9.619610733823423521e-02 +2.222023173583103528e+00 +3.603161857524858069e-01 
+3.100090121775084251e-01 +3.417931173654581212e+00 +2.637932478789538404e+00 +4.066589147187809616e-01 +5.500156820911586308e+00 +3.241004267474435263e+00 +4.430177782216134119e+00 +1.294956167617228537e+00 +4.710739887513437196e+00 +2.130034143130785651e-01 +1.209101134344288342e+00 +5.132132907292509927e+00 +1.079070568921344231e+00 +9.880345585358989879e+00 +3.341722142539176765e+00 +2.132997868911963302e-01 +3.638139721764838130e-01 +8.544621529274050165e-01 +8.368403497652689538e+00 +1.494648836718671658e+00 +1.149710096719515517e+00 +7.077658455701806517e+00 +2.806159435480652053e+00 +8.022198499442469100e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/01.png new file mode 100644 index 0000000000000000000000000000000000000000..b2b3c9292b9ab6fd8e3be009edb3c0ff36a2dec7 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24ad7fea46687df48cfff5c5d8b7fec4c8394a3fe8aba5958443ad1e30840bd0 +size 701851 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/02.png new file mode 100644 index 0000000000000000000000000000000000000000..92f4b7dfa6ed91009b747063dd129aa50fde5236 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ddfd13cf66c2d2ee29222e7fd17a65bbe1314d0d9f2633bb6890f8500d6a0a1 +size 798718 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..261f92cdaba758ab740a8394f9a2e1db3dc4100d --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/corrs.txt @@ -0,0 +1,20 @@ +5.382691700000000310e+02 3.201234299999999848e+02 8.401269600000000537e+02 2.890135200000000282e+02 +5.136501799999999776e+02 3.198889699999999721e+02 7.974540399999999636e+02 2.904203200000000038e+02 +1.237602799999999945e+02 3.225853299999999990e+02 2.599307699999999954e+02 3.111280600000000049e+02 +1.537719999999999914e+02 3.224680999999999926e+02 2.903263200000000097e+02 3.092262000000000057e+02 +1.461141399999999919e+02 3.159998600000000124e+02 3.083832499999999754e+02 3.169733100000000263e+02 +1.357277599999999893e+02 4.126277600000000234e+02 2.628877299999999764e+02 4.291835699999999747e+02 +3.249277599999999779e+02 3.920216699999999719e+02 4.928279400000000123e+02 4.026567099999999755e+02 +3.176001600000000167e+02 3.420965499999999793e+02 4.826547699999999850e+02 3.305104000000000042e+02 +3.671542200000000093e+02 3.417739000000000260e+02 5.515446899999999459e+02 3.293574800000000096e+02 +3.581272599999999784e+02 3.794777100000000019e+02 5.398474899999999934e+02 3.856578900000000090e+02 +5.280300999999999476e+02 3.967701400000000262e+02 8.180482600000000275e+02 4.170241500000000201e+02 +1.984237099999999998e+02 3.644407299999999736e+02 3.422599599999999782e+02 3.659925299999999879e+02 +2.103922200000000089e+02 3.919926199999999881e+02 3.558881600000000276e+02 4.075805100000000039e+02 +4.188388800000000174e+02 3.609237299999999777e+02 6.433891999999999598e+02 3.599419500000000198e+02 +5.105621300000000247e+02 3.594185200000000009e+02 7.931084399999999732e+02 3.569712799999999788e+02 +3.755597900000000209e+02 3.891606100000000197e+02 5.768200500000000375e+02 
4.055832399999999893e+02 +3.593878532316947485e+02 3.070974129196879403e+02 5.717608921013747931e+02 2.883844101883552185e+02 +3.098401408430057700e+02 3.068076602156605190e+02 5.045165838838461809e+02 2.903391865900275661e+02 +3.417129382860220517e+02 2.931892831263717198e+02 5.522131280846514301e+02 2.680547356109628936e+02 +3.379461531336655753e+02 3.105744453680169954e+02 5.447849777582964634e+02 2.938577841130378374e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..499b15ab86c373201270fa389256c34de410fc1e --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/crossval_errors.txt @@ -0,0 +1,20 @@ +1.045113346684135858e-01 +5.782519597789786969e-01 +1.288780350433794464e+00 +7.571290543942928997e-01 +2.320711714614160037e+00 +4.092129552015536298e+00 +1.144462481685707411e+00 +7.868003583494459496e-01 +1.517594664920262293e+00 +5.049213364468354559e-01 +5.365094462105434170e-01 +2.255850096218184664e+00 +6.997883488536579266e-01 +2.091201714459292038e+00 +3.756746109812822088e+00 +2.731580148502795180e-01 +2.652923335284723461e+00 +3.001299658395643888e-02 +1.538170234708976736e+00 +6.465584911931659962e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/01.png new file mode 100644 index 0000000000000000000000000000000000000000..0fe8b56b9d4182e98e519a53cf45c89cd81ca05f --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c23e0225a23cca98354330e16bb25b087521cfdc02e260157917a73d394943b5 +size 455694 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/02.png new file mode 100644 index 0000000000000000000000000000000000000000..eed716c239494c2a8d858ac87f3aa15cc3daf798 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a45c8c303c9519f388b4d0a020cc42ec1568f60fb264e1bab1c37ba36071a47 +size 440407 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..98548c1c0fed67b5e082bbe57de6c0f0cd217708 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/corrs.txt @@ -0,0 +1,28 @@ +2.939664000000000215e+02 3.045692700000000173e+02 4.071241999999999734e+02 2.350125999999999920e+02 +2.789837099999999737e+02 2.884094099999999798e+02 3.715412999999999784e+02 2.365269499999999994e+02 +5.225245799999999718e+01 2.785376300000000072e+01 1.143709000000000060e+02 2.453592199999999934e+02 +3.934123499999999751e+01 8.992861499999999353e+01 1.233397599999999983e+02 2.691005200000000173e+02 +3.389435199999999782e+02 4.382933800000000133e+02 4.989929099999999949e+02 3.114054800000000114e+02 +3.443217300000000023e+02 3.877132700000000227e+02 4.141194600000000037e+02 2.683777499999999918e+02 +3.975691499999999792e+02 2.972368299999999977e+02 4.135139300000000162e+02 1.597436500000000024e+02 +2.289945999999999913e+02 3.682853900000000067e+02 3.708185399999999845e+02 3.206088100000000054e+02 +1.120068500000000000e+02 2.817852100000000064e+02 
2.133036899999999889e+02 3.207066800000000057e+02 +1.708717900000000043e+02 1.632734199999999873e+02 2.057512999999999863e+02 2.416852199999999868e+02 +2.271479199999999992e+02 2.101553299999999922e+02 2.645968799999999987e+02 2.313095500000000015e+02 +3.922296799999999735e+02 3.802393799999999828e+02 5.436231699999999591e+02 2.153056500000000142e+02 +3.574577300000000264e+02 3.783975399999999922e+02 4.509628799999999842e+02 2.486417500000000018e+02 +3.725721800000000030e+02 3.528192399999999793e+02 4.507284099999999967e+02 2.180116999999999905e+02 +3.500021800000000098e+02 4.498192399999999793e+02 5.468960200000000214e+02 3.127756200000000035e+02 +1.992346400000000131e+02 3.168652799999999843e+02 3.004045600000000036e+02 3.025635300000000143e+02 +1.692256899999999860e+02 2.660894400000000246e+02 2.471184000000000083e+02 2.895495999999999981e+02 +3.875780399999999872e+02 3.331240500000000111e+02 4.582313399999999888e+02 1.899929500000000075e+02 +3.346446199999999749e+02 4.545085700000000202e+02 5.133672900000000254e+02 3.283676500000000260e+02 +1.324547699999999963e+02 2.970682899999999904e+02 2.345914400000000057e+02 3.211758500000000254e+02 +1.349730299999999943e+02 2.913113200000000091e+01 1.460674099999999953e+02 2.118581499999999949e+02 +3.387018081266798504e+02 4.211557926482452103e+02 4.577099168638121682e+02 2.986015155657120772e+02 +3.576395934174669833e+02 4.358770839507731694e+02 5.373521867041577025e+02 2.962761938185487338e+02 +2.004951893092595014e+02 1.164639536435136620e+02 2.074144164015886247e+02 2.075142798410796274e+02 +2.348958052953859124e+02 1.630370952862693912e+02 2.508620034821227591e+02 2.048542234892101987e+02 +2.470683309520152591e+02 1.805020234023028252e+02 2.703690833958319217e+02 2.037458666759312678e+02 +2.486560516898364881e+02 2.022008734858594607e+02 2.834476937925233528e+02 2.126127211821627156e+02 +1.973197478336170718e+02 1.947915100426937443e+02 2.373400503601197897e+02 2.407649842394475854e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..45575c3a43a37bd1cce73af0af5e68337390ba23 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/crossval_errors.txt @@ -0,0 +1,28 @@ +4.035863378755469855e-01 +9.721997233306670649e-01 +1.633407097610160008e+00 +2.340841776830394072e-01 +3.307562772547593677e-01 +7.631667399378996297e-01 +6.225453458632575376e+00 +1.546114169351287426e+00 +6.222134223220379123e-01 +2.352856144041039210e-02 +3.672563650444910649e-02 +1.676348197120265171e+00 +6.024728066961207995e-01 +1.600202597754185385e+00 +1.452115718380216469e+00 +2.531982995760208022e-01 +6.108426706884196866e-01 +2.584720721075003613e+00 +1.257587290875826991e+00 +1.795838969385847195e+00 +4.973690863625047420e+00 +2.269005600197053329e+00 +2.040850506195554637e+00 +1.557553158164507678e+00 +2.292730519237752285e-01 +7.590197588485040336e-01 +6.090837038593645003e-01 +5.670417249077306376e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/01.png new file mode 100644 index 0000000000000000000000000000000000000000..09954993abacdecd352ae6e8a1a0371ff7c41024 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:33b085670bc7a1f0c584de21d7ddc720bb4c758a00b58ec781567cccc73b6ebe +size 560605 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/02.png new file mode 100644 index 0000000000000000000000000000000000000000..20aa465c068c1365a8855a88db7c6f2a72361231 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07f191bb344878e5f8b78f8b41cf81d5f34363227d44cea390aa6e2b9aa271d0 +size 671791 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..59b10f1baad7423b4d2401478e040780edbff085 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/corrs.txt @@ -0,0 +1,29 @@ +3.627090400000000159e+02 2.702030500000000188e+02 1.724055199999999957e+02 2.529717200000000048e+02 +4.603232100000000173e+02 2.732317499999999768e+02 2.515151700000000119e+02 2.542031900000000064e+02 +3.189187699999999950e+02 2.363893700000000138e+02 3.194758899999999926e+02 2.518294699999999864e+02 +5.270825499999999693e+02 2.679455300000000193e+02 4.677583599999999819e+02 2.706653000000000020e+02 +5.266039399999999659e+02 3.173941199999999867e+02 4.673281600000000253e+02 3.058190299999999979e+02 +3.091980899999999792e+02 3.267136399999999981e+02 3.137508300000000077e+02 3.146115300000000161e+02 +6.601129399999999805e+02 2.463251500000000078e+02 5.610241800000000012e+02 2.537342500000000030e+02 +2.104019000000000119e+02 1.983037699999999859e+02 2.456234200000000101e+02 2.264652399999999943e+02 +3.583122700000000123e+02 1.349441999999999950e+02 1.771539899999999932e+02 1.422758300000000133e+02 +4.208415299999999775e+02 1.127763800000000032e+02 2.445876800000000060e+02 1.253243800000000050e+02 +4.750268300000000181e+02 2.091580500000000029e+02 4.303675099999999816e+02 2.300800999999999874e+02 +6.398789600000000064e+02 2.135752200000000016e+02 5.467690999999999804e+02 2.309007400000000132e+02 +4.855788600000000201e+02 3.579385100000000079e+02 4.400204100000000267e+02 3.348551899999999932e+02 +3.184401599999999917e+02 3.619534899999999880e+02 3.223776799999999980e+02 3.390755899999999770e+02 +3.963968300000000227e+02 2.484256699999999967e+02 2.047155700000000138e+02 2.361374999999999886e+02 +4.147153099999999881e+02 1.501392199999999946e+02 2.245483999999999867e+02 1.525546600000000126e+02 +6.291612800000000334e+02 2.689328499999999735e+02 5.399007500000000164e+02 2.703921000000000276e+02 +7.098627300000000560e+02 3.155754400000000146e+02 5.514003300000000536e+02 2.971717800000000125e+02 +7.261398299999999608e+02 2.930355299999999943e+02 6.026451100000000451e+02 2.853398599999999874e+02 +5.579149099999999635e+02 2.668904299999999807e+02 4.895637500000000273e+02 2.692585000000000264e+02 +8.792362759200152311e+01 3.270246234764344990e+02 1.419145446904389019e+02 3.147235164869259165e+02 +2.207784042665544177e+01 2.922096095729181684e+02 4.256379174071378202e+01 2.857788089346437914e+02 +8.716677946366421281e+01 2.853979764178823757e+02 1.426968340837437381e+02 2.849965195413388983e+02 +1.136564639554701017e+02 2.687473175944615491e+02 1.653832264895864910e+02 2.740444680350699969e+02 +3.570110673672705559e+01 2.384733924609690803e+02 9.497718109214360993e+01 2.529226544158370871e+02 +6.768215978607523766e+02 2.604095934004204196e+02 5.730242994595705568e+02 2.641227329797326320e+02 
+7.547998898712631899e+02 3.500846292125078776e+02 6.176350265001917705e+02 3.257280227024951955e+02 +7.415435802294763334e+02 3.126550490474626827e+02 6.105539587159662460e+02 2.981118583440154453e+02 +7.470020606702121313e+02 2.526117641993693610e+02 6.225917739491496832e+02 2.577497719739296258e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..aab023e8dd2cef9877da7fbebae616534e2d4b0c --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/crossval_errors.txt @@ -0,0 +1,29 @@ +8.374156557111057664e-01 +1.328765030518358758e-01 +9.603049900896819535e-01 +1.439087407591574275e+00 +7.279438076244393319e-01 +9.105730160893195091e-01 +5.490532157496890164e-01 +7.626828820733734526e-02 +1.322878606601286089e+00 +8.847866436993027106e-01 +1.026845569665250979e+00 +1.485877897448205642e-01 +1.937133625975002438e-01 +1.459832159161253751e+00 +2.039198747794767730e+00 +1.901781580371950442e+00 +7.648652934105401036e-01 +7.588993994218221628e-01 +8.633612883264278892e-01 +1.994298324411831080e+00 +4.200922354091934374e-01 +2.263208001745716125e+00 +1.945820659482044857e+00 +1.067577664256973868e+00 +1.035513580979050774e+00 +4.557118879492584318e-01 +1.209804537310066763e+00 +2.390102886096314716e+00 +1.692217931017956589e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/01.png new file mode 100644 index 0000000000000000000000000000000000000000..b8c2731f8b256d6eb9486debd49e4424f5ae793e --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7cde3fbbd4a580b626e23b3d893fef722a0f49b248084d97e62a2373a515d6c3 +size 33920 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/02.png new file mode 100644 index 0000000000000000000000000000000000000000..6cf80e281f3483e87a7c64f52d5e6e4ac77feae0 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a7d3befa0dd5d7f385133211d360842177e899c75ea1b3e86fbccf1d6340efd +size 571399 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..610b18654ece85e6dfff13d32904b048f54888f5 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/corrs.txt @@ -0,0 +1,16 @@ +9.781961300000000392e+01 1.640526199999999903e+02 3.646632700000000114e+02 3.071947599999999738e+02 +5.911208700000000249e+01 1.617282899999999870e+02 2.837176600000000235e+02 3.020257799999999975e+02 +9.678194299999999828e+01 1.289415600000000097e+02 3.626208500000000186e+02 2.492279200000000117e+02 +1.213090100000000007e+02 4.938435900000000345e+01 3.456403599999999869e+02 9.959778400000000431e+01 +1.372421600000000126e+02 9.650363299999999356e+01 3.866751199999999926e+02 1.909491399999999999e+02 +1.627041900000000112e+02 7.731990999999999303e+01 4.574601599999999735e+02 1.513018099999999890e+02 +1.630471799999999973e+02 9.672647999999999513e+01 4.436449600000000260e+02 1.959417599999999879e+02 +1.267521700000000067e+02 1.210687199999999919e+01 3.491089400000000182e+02 
1.299489400000000039e+01 +7.839227300000000120e+00 1.478580799999999940e+01 1.648144100000000094e+02 5.922769000000000261e+01 +3.307260300000000086e+01 1.982403400000000033e+02 2.107729699999999866e+02 3.684140499999999747e+02 +8.604231599999999647e-01 2.169522200000000112e+02 1.594700899999999990e+02 3.949990399999999795e+02 +5.562704200000000299e+01 8.441304300000000183e+01 2.715727800000000229e+02 1.740955600000000061e+02 +1.220638988325962941e+02 1.185185435128226459e+02 3.517891147119119069e+02 2.390017363973832971e+02 +2.390941223160494644e+02 1.172779404370260750e+02 5.516424743042626915e+02 2.473289597137311375e+02 +1.878158618498014221e+02 1.346463834981746004e+02 5.183335810388708751e+02 2.689797403362358637e+02 +9.644884478416841489e+01 2.165869733815735287e+02 2.693040235684538857e+02 4.296418009957009190e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..81a8f649fae74130a4a1e12c0b470148f748c8f0 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/crossval_errors.txt @@ -0,0 +1,16 @@ +4.027252550598703351e-01 +3.191271567188723068e+00 +5.821331243373673026e-01 +2.905143316101345352e+00 +2.687513676074431590e+00 +4.783819982168825424e-01 +5.648388513287787127e+00 +8.213834063298731891e+00 +7.330748034364480858e+00 +1.110741198399469026e-01 +2.536472718509040103e+00 +3.155581166483636846e+00 +3.507286154532603728e+00 +1.179596267993104775e+01 +4.008971269381730984e+00 +9.571523717347277582e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/01.png new file mode 100644 index 0000000000000000000000000000000000000000..cfefe5bcf68ced8df5f98eef309c47ec288e0dc0 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76e1af809211af74946975764fb4ed3eb450018f4400ddd220a5b2b8c28f3325 +size 39460 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/02.png new file mode 100644 index 0000000000000000000000000000000000000000..40b296d4a5a7772c1c77b7bc1db799682897ca30 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3587f375b8f16a462cc146d69470725708d2e6c3e9bc008828c6226e90404ae +size 571656 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e1cfab21473a292a75450a8d91fd13576c763f6 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/corrs.txt @@ -0,0 +1,23 @@ +2.031673099999999863e+02 1.511227499999999964e+02 5.470917500000000473e+02 3.812720199999999977e+02 +1.278246900000000039e+02 1.582333099999999888e+02 3.950975900000000252e+02 3.912282000000000153e+02 +6.981330699999999467e+01 1.378685399999999959e+02 2.838833099999999945e+02 3.628750400000000127e+02 +2.062450300000000070e+01 1.665811500000000080e+02 2.064353399999999965e+02 4.181154300000000035e+02 +5.874876700000000085e+01 9.520412799999999720e+01 2.468054300000000012e+02 2.787631000000000085e+02 +4.306461399999999884e+01 1.194967299999999994e+02 2.168206900000000132e+02 
3.224748599999999783e+02 +4.306461399999999884e+01 9.546390700000000606e+01 2.136284100000000024e+02 2.799354299999999967e+02 +1.363042900000000088e+02 2.272685700000000111e+02 4.315301600000000235e+02 5.132182599999999866e+02 +1.249141999999999939e+02 2.089112900000000081e+02 4.102798199999999724e+02 4.857932799999999816e+02 +7.821797799999999690e+00 1.875146499999999889e+02 1.872661400000000071e+02 4.604397599999999784e+02 +1.791764400000000137e+02 4.797713999999999857e+01 4.946746400000000108e+02 2.018916099999999858e+02 +1.969959000000000060e+02 4.750820600000000127e+01 5.294929300000000012e+02 2.007192699999999945e+02 +6.168053900000000311e+01 4.406037500000000051e+01 2.541400899999999865e+02 1.951090299999999900e+02 +2.074171461111978942e+02 1.130778721436135328e+02 5.534743734514790958e+02 3.064526860507694437e+02 +2.077209801180918589e+02 9.697466977823438583e+01 5.548043328588382792e+02 2.798534979035860601e+02 +1.779452474424851403e+02 6.780660511641553967e+01 4.956211392313551869e+02 2.379597765717722382e+02 +1.803759194976366871e+02 9.059415563346149725e+01 4.969510986387143703e+02 2.665539038299943400e+02 +9.408706153975595043e+01 8.176509711761033827e+01 3.220614365709834601e+02 2.578700518258719399e+02 +1.056327538017258973e+02 8.176509711761033827e+01 3.453357261997688852e+02 2.558751127148332216e+02 +1.065442558224077345e+02 7.052323886253432761e+01 3.566403811623218303e+02 2.445704577522802765e+02 +8.832415683445049126e+01 1.547573300627117305e+02 3.291543664475755122e+02 3.951877991166010133e+02 +7.327830040831864267e+01 1.730162867597605043e+02 2.886199438891764544e+02 4.214042890366270626e+02 +7.969089821058844336e+01 1.811777748717402687e+02 3.093775883080778044e+02 4.366993954505543343e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..4789cb2b13b2f53035fbcadeca812fff6cceef12 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/crossval_errors.txt @@ -0,0 +1,23 @@ +6.956774605531451883e+00 +2.860879616904584033e+00 +1.697843297093116988e+00 +5.438113664493887889e+00 +4.995418497168292449e-01 +1.942702354772605444e+00 +4.400681895413596084e-02 +3.868755083381945958e+00 +1.391979085576086628e-01 +7.054838917641514939e+00 +7.125814072624173656e-01 +1.492414840109906216e+00 +3.395624676571964873e+00 +3.332879502809552030e-01 +2.233287232798145894e+00 +6.365794263520651031e-01 +9.102453491498465610e-01 +1.003026023461308425e+00 +2.251558963684663073e-01 +3.920756389386108598e+00 +1.955627908840252616e+00 +2.169390693519445357e+00 +4.825353733967688852e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/01.png new file mode 100644 index 0000000000000000000000000000000000000000..051db61918b367bb72469d3107bbd28172e95a22 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24a3f33faf8b550c9a1d242cfc309ab8cb137f228796cda303aa5b0ed5b77522 +size 40390 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/02.png new file mode 100644 index 0000000000000000000000000000000000000000..e6a7a888167d82497f909ccd26bded4755f0e491 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/02.png @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:f16337a7f7eaad8a67af97b15e5f4d77ca0641f62719cbf87a3934b2b5d824ce +size 589826 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..a11078c55b58daf033d7ba3cb1e16ae9e664a9d1 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/corrs.txt @@ -0,0 +1,28 @@ +1.628443699999999978e+02 2.167013399999999876e+02 2.607572299999999927e+02 3.298615899999999783e+02 +1.470178799999999910e+02 2.174047400000000039e+02 2.319178400000000124e+02 3.297443599999999719e+02 +1.615533499999999947e+02 1.005133799999999979e+02 4.440848500000000172e+02 1.440032600000000116e+02 +1.929841299999999933e+02 4.930592200000000247e+01 4.686370900000000006e+02 3.985307399999999944e+01 +1.940406900000000121e+02 1.124077800000000025e+02 4.498787300000000187e+02 1.653708399999999870e+02 +1.767331900000000076e+02 1.142835100000000068e+02 4.307095299999999725e+02 1.706269800000000032e+02 +1.740266399999999862e+02 2.291979100000000074e+02 4.231765100000000075e+02 3.930738499999999931e+02 +1.904393000000000029e+02 2.389282800000000009e+02 4.407615099999999870e+02 4.050316500000000133e+02 +4.218699300000000108e+01 9.592236900000000333e+01 1.025002499999999941e+02 1.362164099999999962e+02 +9.343587299999999374e+00 7.190201600000000326e+01 3.459909999999999819e+00 9.663790299999999434e+01 +6.077387399999999928e+01 1.963945599999999914e+01 8.101543499999999653e+01 6.908445300000000344e+00 +4.348383400000000165e+01 1.419482800000000111e+02 1.556304399999999930e+02 2.232398200000000088e+02 +1.144621499999999941e+02 5.671926500000000004e+01 1.966845700000000079e+02 6.646710500000000366e+01 +1.264752399999999994e+02 7.190783600000000320e+01 2.193707699999999932e+02 9.096125299999999925e+01 +1.480104100000000074e+02 5.954450400000000343e+01 2.564196000000000026e+02 6.945707199999999659e+01 +1.214080100000000044e+02 9.816435199999999384e+01 3.516977499999999850e+02 1.410918000000000063e+02 +1.333015800000000013e+02 2.099018100000000118e+02 2.282836100000000101e+02 3.237654600000000187e+02 +1.450249100000000055e+02 2.088467100000000016e+02 2.544266299999999887e+02 3.232965300000000184e+02 +1.690577299999999923e+02 2.489404999999999859e+02 2.669705999999999904e+02 3.816787100000000237e+02 +2.788453099999999907e+01 9.568790199999999402e+01 7.600552500000000578e+01 1.351613099999999861e+02 +1.189651512583989756e+02 4.266819208086550930e+01 2.160874152713540184e+02 3.704794679781866762e+01 +1.507345979079198912e+02 4.911762861873818053e+01 2.834910754242177404e+02 5.065235526903887830e+01 +1.311474202743806927e+02 3.430781138362316085e+01 2.222712373037268776e+02 2.344353832659845693e+01 +3.044997022037620837e+01 7.979086484477357999e+01 1.416168022564780244e+02 1.083209577219725475e+02 +2.806449069218989933e+01 1.184886438575735923e+02 1.416168022564780244e+02 1.817617464555121956e+02 +2.089269163050161637e+00 1.182235905766639945e+02 8.699142220673780912e+01 1.799409004538541694e+02 +7.655388062151530448e+00 7.634517219294892243e+01 1.977627943726929516e-01 1.010375737153404998e+02 +3.469082271492963088e+01 9.754943466571606336e+01 8.820531954117637952e+01 1.380614424157199664e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/crossval_errors.txt new file mode 100644 index 
0000000000000000000000000000000000000000..930b7074944da86e735567c7ded9fac73bdebaba --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/crossval_errors.txt @@ -0,0 +1,28 @@ +7.752240801306259366e-01 +1.148758241413943004e+00 +1.768738052957149032e+00 +1.568654570431663009e+00 +2.271406557678926907e-01 +2.449004580308638268e-01 +2.812550544505385108e+00 +2.360283068848247012e+00 +8.507221766035513166e-01 +1.476137889129510050e+00 +3.842955461712121989e-01 +3.790487856614814244e-01 +1.834805512081233880e+00 +1.705524880090546203e-02 +1.197058223742027971e+00 +1.595218938876604153e-01 +4.930223213755020595e-01 +2.767083777630560126e-01 +1.538363609208045046e+00 +3.359096424522557411e-01 +3.107673526549437959e+00 +2.830308237137624694e+00 +3.095121535859486794e+00 +2.387996241400114794e+00 +2.596366077667401351e-01 +1.379228150193078539e+00 +1.946997708432720353e+00 +5.931413354392515158e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/01.png new file mode 100644 index 0000000000000000000000000000000000000000..c7148ac014161852700da15270e74668497c8b8b --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:989fbd9964db755bb62ff02415a77d5fec359750c4a4674a6d9f30e047cd5f56 +size 42553 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/02.png new file mode 100644 index 0000000000000000000000000000000000000000..8be58f1a21e8e277f17e0594ccd7e1c291439fb5 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bcb2a79c6312164966841bcf36b8f9f01af95fd70b135f118740a0bd6de8ff9e +size 554068 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..31d982d310eec85faadee731bf733a144e7588d5 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/corrs.txt @@ -0,0 +1,26 @@ +1.163067999999999991e+01 1.006243099999999941e+02 4.788471099999999865e+02 2.956166799999999739e+02 +8.383898000000000650e+01 1.648345700000000136e+02 6.089369799999999486e+02 4.227757199999999784e+02 +6.257121200000000272e+01 1.188440099999999973e+02 5.714095300000000179e+02 3.278976799999999798e+02 +5.359749899999999911e+01 1.795072900000000118e+02 5.490653499999999667e+02 4.416253699999999753e+02 +4.060295299999999941e+01 1.382418600000000026e+02 5.267889499999999998e+02 3.491276199999999790e+02 +5.382905699999999882e+01 1.715750299999999982e+02 5.498956699999999955e+02 4.306991699999999810e+02 +9.625899200000000633e+01 1.509786799999999971e+02 5.973007999999999811e+02 3.364266600000000267e+02 +1.111350900000000053e+02 1.272402799999999985e+02 6.381398199999999861e+02 3.165250100000000089e+02 +1.594964899999999943e+02 5.475679999999999836e+01 7.662270399999999881e+02 1.909491399999999999e+02 +1.576062100000000044e+02 1.025320300000000060e+02 7.599642300000000432e+02 2.903702299999999923e+02 +7.635349899999999934e+01 1.282101399999999956e+02 5.923780299999999670e+02 3.443898800000000051e+02 +1.604430800000000090e+02 7.735484999999999900e+01 7.678683099999999513e+02 2.363184300000000064e+02 +1.629253400000000056e+02 1.936663399999999911e+02 6.783636400000000322e+02 3.636074899999999843e+02 +3.278412300000000101e+01 
1.553686600000000055e+02 4.981625900000000229e+02 3.797770399999999995e+02 +1.694398599999999888e+01 1.558463199999999915e+02 4.681584799999999973e+02 3.737583700000000135e+02 +1.595255699999999877e+02 1.682267200000000003e+02 7.001690300000000207e+02 3.496567299999999818e+02 +1.689042300000000125e+02 1.947214400000000012e+02 6.910248299999999517e+02 3.653659900000000107e+02 +1.670709897623286508e+02 3.286760692285372443e+01 7.881821280378179608e+02 1.521775730514374345e+02 +1.598054208230585118e+02 6.878027625124609301e+01 7.676396469214458875e+02 2.198916774720711373e+02 +3.880017453908246239e+01 1.591056886131800923e+02 5.224260410634325353e+02 3.925349736709622448e+02 +4.383920758304722654e+01 1.587697530769157765e+02 5.382471696556165170e+02 3.918470985147803276e+02 +5.324540259844809498e+01 1.587697530769157765e+02 5.568197988725282812e+02 4.069803519507824490e+02 +1.285673750694253954e+02 1.236897702190242683e+02 6.630219832663985926e+02 2.996768242875007218e+02 +1.208212523026747789e+02 1.600470593496146421e+02 6.767368812378181246e+02 4.147277185279446599e+02 +7.859565859050952952e+01 1.699149970106097953e+02 6.024193888738506075e+02 4.321538891512198575e+02 +3.315724796546213327e+01 1.662432062530304790e+02 5.004250372846813093e+02 3.988391511949585038e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..11a7637176f2973f3cd07984b07696dc68999c64 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/crossval_errors.txt @@ -0,0 +1,26 @@ +7.907225210316218167e+00 +2.535624236718984736e+00 +3.086222092834641773e-01 +1.874681247790922489e+00 +6.968237522558005992e+00 +5.594632224115625441e+00 +9.166892174461624521e+00 +5.255621098281488379e+00 +4.583878768840595752e-01 +2.145340407892266654e+00 +2.464833084486542703e+00 +2.055031501699763563e+00 +2.236388676391594998e+00 +7.235284117576637364e+00 +7.775682560729027415e+00 +5.285462960500259655e+00 +2.039091324975429664e+00 +2.681807750177786165e+00 +1.585768370720183507e+00 +3.566932614906344501e+00 +1.104428367277808753e+01 +2.088649663206282092e+00 +1.125499910202593234e+01 +7.554842478097656411e+00 +1.402237521628181938e+00 +3.734888943104968106e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/01.png new file mode 100644 index 0000000000000000000000000000000000000000..a923b184c80023980d26b184a600d9c1f37bd5b2 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9478623f10f039e55d1a34d0eb99a618cd5ffcc4f127b232b3d835db352acd64 +size 43062 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/02.png new file mode 100644 index 0000000000000000000000000000000000000000..fe6e8ac17368220990dd082d6e407a5200df6b45 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf74fb4e2859475a9181a31a2dbca5b36e1fbb890f1b3dfca1beaa341a2d5611 +size 770313 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..2cc8d21fbc51a953aac5bd97e79c66ae4b9d7d68 --- 
/dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/corrs.txt @@ -0,0 +1,31 @@ +9.897087299999999743e+01 4.934491899999999731e+01 3.632952000000000226e+02 1.367832099999999969e+02 +1.261303699999999992e+02 5.910536499999999904e+01 3.965527999999999906e+02 1.532947800000000029e+02 +1.599906200000000069e+02 7.385744900000000257e+01 4.414175299999999993e+02 1.768897900000000050e+02 +1.897854499999999973e+02 1.028151900000000012e+02 4.796774199999999837e+02 2.206693400000000054e+02 +1.413513900000000092e+02 8.731513800000000458e+01 4.165426299999999742e+02 1.930496599999999887e+02 +8.060585000000000377e+01 1.388418299999999874e+02 3.286501700000000028e+02 2.480362199999999859e+02 +7.987434899999999516e+01 1.242497299999999960e+01 3.395829600000000141e+02 9.155710200000000043e+01 +1.868148799999999881e+02 4.807181800000000038e+01 4.802248599999999783e+02 1.500509800000000098e+02 +1.995904700000000105e+02 1.538564800000000048e+02 4.841914199999999937e+02 2.815089199999999892e+02 +9.005026000000000863e+00 5.210730299999999815e+00 2.472646799999999985e+02 7.471320400000000461e+01 +5.654654400000000081e+00 1.297964200000000119e+02 2.329611899999999878e+02 2.294531900000000064e+02 +4.750520800000000321e+01 1.867672300000000121e+02 2.826131100000000060e+02 3.014105599999999754e+02 +2.378445499999999981e+02 8.430482899999999802e+01 5.434962500000000318e+02 2.020185199999999952e+02 +1.966315299999999979e+02 6.298087300000000255e+01 4.909328600000000051e+02 1.670217900000000100e+02 +2.435347500000000025e+02 1.771240099999999984e+02 5.369892899999999827e+02 3.132715299999999843e+02 +2.639605799999999931e+02 1.931162699999999859e+02 5.629881000000000313e+02 3.345121899999999755e+02 +1.291009400000000085e+02 1.613778599999999983e+02 3.893327499999999759e+02 2.781672500000000241e+02 +1.546802099999999882e+02 1.858129000000000133e+02 4.204801499999999805e+02 3.131155600000000163e+02 +1.518268700000000138e+02 2.105879899999999907e+02 4.132891500000000065e+02 3.402469399999999951e+02 +1.056154099999999971e+02 2.148762299999999925e+02 3.533884699999999839e+02 3.404814099999999826e+02 +1.066277800000000013e+02 7.896014099999999303e+01 3.683018799999999828e+02 1.776588500000000010e+02 +2.017403999999999940e+02 1.114033000000000015e+02 4.912942400000000021e+02 2.288863900000000058e+02 +3.921801490628814690e+01 7.235750264493356099e+01 2.786385103920094934e+02 1.611776639627567533e+02 +3.970357384006362622e+01 4.613732022105824626e+01 2.800882859621938223e+02 1.271079380634251947e+02 +6.200007409556292259e+00 5.147846849258843349e+01 2.380447944268485116e+02 1.300074892037938525e+02 +1.396895034996379081e+01 1.014910386714616948e+02 2.460185600628623206e+02 1.952473898620882551e+02 +4.941475251557298520e+01 7.184338185676199373e+01 2.931362660938526687e+02 1.604101363703597372e+02 +4.504472211159377082e+01 3.865131128196424015e+00 2.938611538789447195e+02 7.704804108476133706e+01 +2.638002771336840624e+02 7.635581717826380554e+01 5.788273089375950349e+02 1.954089779812106826e+02 +2.352291152034932509e+02 1.095780984924392101e+02 5.360306840245516469e+02 2.303220140944829382e+02 +1.694489982014259510e+02 3.914686210638734565e+01 4.594472499696317982e+02 1.370439269645612512e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..f559b819b0225783996edbd143c240f10c4aff0d --- /dev/null +++ 
b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/crossval_errors.txt @@ -0,0 +1,31 @@ +3.584996310465787595e+00 +4.946349073889345083e-01 +1.107964495161395613e+00 +5.027808283948981272e-01 +1.233645956010521960e+00 +1.140904304662818447e+00 +1.746155141624057272e+00 +6.075603257546631220e-01 +8.903860399682645976e-01 +1.231921832177196446e-01 +4.462442879723696465e-01 +1.348039545561001029e+00 +1.267108840243132750e+00 +1.770145183486460772e+00 +1.586913035140011363e+00 +7.921288124968347555e-01 +7.384951230528175037e-01 +1.349153872928968045e+00 +5.354941394430628998e-01 +7.480195637740312264e-01 +1.645405748374197286e+00 +6.062212534539174191e-01 +2.481283471472908175e+00 +3.224114288960418850e+00 +9.488628321632434082e-01 +6.451342801452116804e-01 +1.852461708152915998e-01 +1.880127786478289709e+00 +2.137403187048483755e+00 +1.837261225400464437e-01 +7.754601935878026042e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..61b34e0771a875eb567d8c70906d04413e0987ca Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/.DS_Store differ diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/01.png new file mode 100644 index 0000000000000000000000000000000000000000..bd7292fa5d275ccc539d43881c5ac5647f3da656 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da934d6a058f970ddcee0cb0434cabd06e0f23588c0466b925de5b2b290d8ce4 +size 352843 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/02.png new file mode 100644 index 0000000000000000000000000000000000000000..ed05585af37639e01b384253800f6a086740abac --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0371595a5ccab153f60b0e00bb5df8056c21722b53800dc1b484cb64945e869 +size 602365 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b6e2dfb82a70a2658388de80ad1e7613c9196a2 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/corrs.txt @@ -0,0 +1,40 @@ +1.169112800000000050e+02 4.929016299999999973e+01 3.541015600000000063e+02 1.477567500000000109e+02 +1.563096199999999953e+02 6.461383399999999710e+01 3.742968000000000188e+02 1.550746599999999944e+02 +1.300493600000000072e+02 6.308980100000000135e+01 3.624562399999999798e+02 1.543712600000000066e+02 +3.226906299999999987e+02 1.387599099999999908e+02 4.564341099999999756e+02 1.899251099999999894e+02 +2.597951400000000035e+02 1.359850399999999979e+02 4.243488500000000272e+02 1.898659700000000043e+02 +2.609577899999999886e+02 1.679579099999999983e+02 4.245833099999999831e+02 2.069820400000000120e+02 +3.232767900000000054e+02 1.715852299999999957e+02 4.563168800000000260e+02 2.086824400000000139e+02 +3.150161699999999882e+02 2.376190400000000125e+02 4.504939400000000091e+02 2.430843500000000006e+02 +2.615342800000000238e+02 2.722495400000000245e+02 4.229323600000000170e+02 2.602702699999999822e+02 +2.843377100000000155e+02 2.706373300000000199e+02 4.345491799999999785e+02 2.606122799999999984e+02 +3.165208400000000211e+02 
3.306116599999999721e+02 4.491161900000000173e+02 2.890138600000000224e+02 +2.994231100000000083e+02 3.597612500000000182e+02 4.395611599999999726e+02 3.028581100000000106e+02 +2.541969899999999996e+02 4.175856200000000058e+02 4.182032600000000002e+02 3.301561800000000062e+02 +1.484055400000000020e+02 4.088754299999999944e+02 3.732707500000000209e+02 3.302637300000000096e+02 +1.363789299999999969e+02 3.112467300000000137e+02 3.649461499999999887e+02 2.824778499999999894e+02 +1.341020399999999881e+02 1.730726099999999974e+02 3.628068999999999846e+02 2.122596100000000092e+02 +2.236309200000000033e+02 2.367838800000000106e+02 4.060594199999999887e+02 2.489460200000000043e+02 +2.237529999999999859e+02 2.510873799999999960e+02 4.061863299999999981e+02 2.571727500000000077e+02 +2.515199900000000071e+02 2.541112400000000093e+02 4.197573800000000119e+02 2.533223800000000097e+02 +3.215183000000000106e+02 4.430418399999999792e+02 4.511779799999999909e+02 3.412158800000000269e+02 +2.831653699999999958e+02 2.885740200000000186e+02 4.341974799999999846e+02 2.681152200000000221e+02 +2.585463799999999992e+02 2.877533900000000244e+02 4.217707500000000209e+02 2.681152200000000221e+02 +1.490401200000000017e+02 2.910826000000000136e+02 3.710422800000000052e+02 2.746232200000000034e+02 +2.603124900000000252e+02 2.408638799999999947e+02 4.239777799999999957e+02 2.441781799999999976e+02 +2.434460722294301149e+02 2.803174670547233518e+02 4.122604712625906700e+02 2.872925018567745497e+02 +2.733039336263267387e+02 4.925272668931202134e+02 4.241827474016961332e+02 3.639119046943361582e+02 +5.872857669477701847e+01 4.268111770917495846e+02 3.644077504740858444e+02 3.464977785170541438e+02 +8.004596266233613733e+01 4.657705376531507682e+02 3.708833822336325738e+02 3.569927679204574247e+02 +1.109194182015597079e+02 5.466295878749267558e+02 3.603883928302292361e+02 3.824486996648824970e+02 +3.020408096348485287e+02 4.084341202231640864e+02 4.401056527667184355e+02 3.250612044164856229e+02 +2.402938985564013592e+02 3.613888546395853609e+02 4.124167445534841363e+02 3.027314397283934682e+02 +1.256210636964280951e+02 3.562432787163814396e+02 3.628446669459195277e+02 3.047411185503217439e+02 +1.131246650257899944e+02 3.893219810798351546e+02 3.577088210676583913e+02 3.208185491257481772e+02 +1.168000763995070770e+02 3.555081964416379492e+02 3.590486069489438705e+02 3.042945232565599554e+02 +2.511865675635852710e+02 2.673496829800556043e+02 4.184642774916459871e+02 2.594638950006838627e+02 +2.144765395951733922e+02 2.560330578168609463e+02 4.015903118942297851e+02 2.573677502059737776e+02 +2.260691800062507753e+02 2.215311518315115222e+02 4.066210594015339552e+02 2.316899764707751501e+02 +2.487204598853905395e+02 6.543366483336311035e+01 4.200943651318810339e+02 1.384655147177983849e+02 +2.892832720200565291e+02 6.791710231099571615e+01 4.423739670532051491e+02 1.400020389882345455e+02 +2.093993664895408813e+02 1.357977266996203980e+02 3.993512874809929940e+02 1.807199321547924740e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..4c4c00ab7cffad294a2d3b1c0761783529b18c8f --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/crossval_errors.txt @@ -0,0 +1,40 @@ +3.107206334320491425e+00 +6.237904612903644175e-01 +3.844272521878835835e+00 +4.161468221374081677e-01 +6.602655361141566148e-01 +2.933609060206928687e-01 
+1.621433032763675008e+00 +8.156725260303979708e-01 +1.374326397156321988e+00 +5.699350862722842859e-01 +2.645182804425461054e-01 +2.216777645068836122e+00 +2.065341874831398972e-01 +2.487949740495516426e+00 +1.915661557083224054e+00 +1.524632930681501408e+00 +8.505375980072693576e-01 +1.324056400088596597e+00 +2.622759999458317726e+00 +1.273873093214882291e+00 +9.289628752266489986e-01 +1.223324686678771839e-01 +2.363886694428318780e+00 +1.428140686106678992e+00 +7.601475707330846099e+00 +2.027895169522786478e+00 +1.569901890632411812e+00 +1.942497203188160437e+00 +3.283187903146319364e+00 +3.453279298906811512e+00 +2.921981659860365954e-01 +1.235501646779935125e+00 +2.856988289013143945e+00 +1.343701685909615584e+00 +1.528449211300779431e-01 +6.289527516872593926e-01 +6.832133453264218614e-01 +1.883882880515432845e+00 +7.781953177049958370e-01 +7.535427800225099615e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/01.png new file mode 100644 index 0000000000000000000000000000000000000000..c85e935768a0ff0631d7798063fd549a8ad7b237 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd32df9ca232efdd702f4694c74f262286989e39ed26ef8a2960012c58a987e9 +size 1003664 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/02.png new file mode 100644 index 0000000000000000000000000000000000000000..98d9c30ecdf718c3a00932987710d8d07556db66 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd51d44e58fe701ac03c3034c28efb14bb91558650ce8aaff0587c18f3c3c37f +size 882713 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..425014b75a07d62405a16ec51b1a87c0b4418d35 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/corrs.txt @@ -0,0 +1,51 @@ +4.594434400000000096e+02 1.421354699999999980e+02 4.110907700000000204e+02 1.401134500000000003e+02 +4.247887299999999868e+02 1.560969499999999925e+02 3.760649900000000230e+02 1.548364300000000071e+02 +3.429439899999999852e+02 2.414490099999999870e+02 2.926467599999999720e+02 2.486593699999999956e+02 +3.713144500000000221e+02 2.355873400000000117e+02 3.224240199999999845e+02 2.410392099999999971e+02 +4.057173999999999978e+02 2.017435800000000086e+02 3.556643199999999752e+02 2.049583299999999895e+02 +4.831266400000000090e+02 1.719836200000000019e+02 4.367282000000000153e+02 1.692571599999999989e+02 +4.923009200000000192e+02 2.504372400000000027e+02 4.481890500000000088e+02 2.526162500000000080e+02 +4.432739000000000260e+02 2.727330000000000041e+02 3.958687699999999836e+02 2.772373200000000111e+02 +4.377047999999999774e+02 3.714216400000000249e+02 3.913257199999999898e+02 3.837612199999999802e+02 +4.153218800000000215e+02 3.899476099999999974e+02 3.685330099999999902e+02 4.022775100000000066e+02 +5.751513099999999667e+02 3.342953499999999849e+02 5.388746999999999616e+02 3.367766100000000051e+02 +3.755078700000000254e+02 2.774814700000000016e+02 3.261097700000000259e+02 2.857081999999999766e+02 +2.731549099999999726e+02 4.166197399999999789e+02 1.820476999999999919e+02 4.358663999999999987e+02 +3.387235900000000015e+02 4.110980200000000195e+02 
2.598591299999999933e+02 4.305704999999999814e+02 +5.390134799999999586e+02 2.733323599999999942e+02 4.980345800000000054e+02 2.736875699999999938e+02 +2.177295000000000016e+02 1.132552699999999959e+02 1.542236399999999890e+02 1.107051900000000018e+02 +1.145859999999999985e+02 2.526850699999999961e+02 5.192982200000000148e+01 2.623476499999999874e+02 +4.961008100000000240e+02 3.077200500000000147e+02 4.523599899999999820e+02 3.136505300000000034e+02 +4.634293799999999806e+02 3.142657500000000255e+02 4.179203800000000228e+02 3.208705800000000181e+02 +3.886099899999999820e+02 3.253458499999999844e+02 3.399249699999999734e+02 3.348417499999999905e+02 +5.037316900000000146e+02 4.045523200000000088e+02 4.621204399999999737e+02 4.156023400000000265e+02 +7.397184899999999743e+02 4.309706199999999967e+02 7.718919399999999769e+02 4.487039700000000266e+02 +7.284640900000000556e+02 4.433973500000000172e+02 7.557137400000000298e+02 4.653510999999999740e+02 +6.990568700000000035e+02 3.712936900000000264e+02 6.837079499999999825e+02 3.754366200000000049e+02 +4.164350800000000277e+02 2.141025300000000016e+02 3.669003900000000158e+02 2.148543400000000076e+02 +4.680218800000000101e+02 1.948160900000000026e+02 4.207339799999999741e+02 1.938685499999999990e+02 +4.289209500000000048e+02 4.063602700000000141e+02 3.828354699999999866e+02 4.189246400000000108e+02 +4.493195400000000177e+02 3.867823099999999954e+02 4.037029899999999998e+02 3.984088100000000168e+02 +6.950526099999999587e+02 4.259295900000000188e+02 7.128063600000000406e+02 4.427250700000000165e+02 +2.044560836833049393e+02 1.242459178115272778e+02 1.406147127704518027e+02 1.236865742051300145e+02 +2.281308525063912498e+02 1.210389770147745594e+02 1.662544620520748424e+02 1.197919287446302974e+02 +2.348276994643161402e+02 1.438019680727641116e+02 1.722587071370118679e+02 1.445662008445619620e+02 +6.717463228213715638e+02 3.649054578569836167e+02 6.445953051321157545e+02 3.698576462864243695e+02 +6.726109392126528519e+02 4.008951151440707577e+02 6.462311986451652501e+02 4.084647331943905897e+02 +6.521843769686304313e+02 3.636085332700615709e+02 6.234377490300099680e+02 3.709482419617906430e+02 +4.940707018408172075e+02 1.776044948851611878e+02 4.486826560612118442e+02 1.751879062404033505e+02 +4.926196935077252306e+02 2.074307772876081799e+02 4.476141575501331999e+02 2.061743630616842609e+02 +4.855258749903864555e+02 1.848595365506212715e+02 4.403812445520623555e+02 1.831605489769133044e+02 +8.207629282561811124e+01 2.572739565469469198e+02 1.809768513909574494e+01 2.677228259828372074e+02 +7.478616494490586319e+01 2.775595297802331629e+02 6.228466757917601626e+00 2.887637131131075421e+02 +5.388479506127032437e+02 1.739406870605558879e+02 4.964077749761838732e+02 1.681577771226033917e+02 +5.791628298642275468e+02 1.504787491354884708e+02 5.394336916079546427e+02 1.415599377502360596e+02 +5.418219990820780367e+02 1.924458775366653924e+02 4.983634984594461912e+02 1.873238672585739835e+02 +6.054122596184905660e+02 4.225034913737084707e+02 5.760699684299472665e+02 4.370432786270732208e+02 +6.292047167561709102e+02 4.238849888849285890e+02 6.045560925447747422e+02 4.376367395461321053e+02 +6.330922480021150704e+02 3.748893448863988738e+02 6.008324998773894094e+02 3.828873228498626986e+02 +6.940164133197212095e+02 3.739800289861361193e+02 6.688765891054499662e+02 3.792636731276583077e+02 +7.219021009277796566e+02 3.603402904821944048e+02 6.998789256176431763e+02 3.643664464919290822e+02 +7.322076811307578055e+02 4.003501900937566234e+02 
7.123603857719028838e+02 4.086554986522052104e+02 +6.843170437169181923e+02 4.079278225959464521e+02 6.600187786733947632e+02 4.175133090842604133e+02 +6.993143883506772909e+02 4.338220179609293723e+02 7.397286915533125011e+02 4.586266329625896105e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..c7c2bf71fa21c921dc45a9dabc6915ee11230c9f --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/crossval_errors.txt @@ -0,0 +1,51 @@ +1.059984065478942972e+00 +1.172619320842898327e+00 +2.759231593787795367e-01 +1.246239438681858713e+00 +1.478182636100590752e+00 +6.418757997184851294e-02 +5.476242100183045247e-01 +2.837136386542546962e-01 +2.091690129035768120e+00 +3.300151015928541320e-01 +1.270728842324514885e+00 +1.133749249740539478e+00 +1.009659388247202694e+00 +3.648056175214751118e+00 +5.856684777977440998e-02 +8.687130066573967024e-01 +6.579392662125330693e-01 +6.064583546252843016e-01 +6.361053578088456117e-02 +3.291927550264510116e-01 +1.698324996372617024e+00 +1.442326735645127700e+00 +1.464286713857532840e+00 +5.724373708212716627e-01 +6.418198956336912397e-01 +5.189720473721508576e-01 +1.443933549361242852e+00 +1.376498883837392051e-01 +1.606222283955370589e+00 +3.479913123925851837e-01 +5.277859297109415149e-01 +1.946469222702315705e+00 +1.392742694583132490e+00 +2.400295113941870007e+00 +1.841967347264411803e+00 +2.352230930271596021e-01 +4.439526535957881159e-01 +1.145484789065688735e+00 +7.598014255470626477e-01 +7.455699952272273334e-01 +5.072581507499479558e-01 +7.148961590925400067e-01 +1.950846228854876463e+00 +1.470122079859542685e+00 +1.948916047213692104e-01 +1.393201192210905548e-01 +1.750177991124282739e+00 +8.992849163506809740e-01 +2.437138945474842666e-02 +4.008038096484493606e-01 +2.857518653713483348e-01 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/01.png new file mode 100644 index 0000000000000000000000000000000000000000..5a893e4c9ae1e30ab4e76c2808e2648889a7c6d6 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ea892f78feda9190d466ec48584ccbd178f4645076c49add4b43cbe444b2dcf +size 608612 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/02.png new file mode 100644 index 0000000000000000000000000000000000000000..4b0cbf6e83fbc1cdbf92897f3fe73286c50ed7df --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86ef3afd1ac043c1d51702a7efc71774cf402a989072f4fc2fb19c29cbbe9288 +size 584762 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..32daf93f65fb756166ea98850677583674250d46 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/corrs.txt @@ -0,0 +1,38 @@ +4.797559299999999780e+02 3.636057400000000257e+02 5.910282999999999447e+02 2.342230399999999975e+02 +4.575773899999999799e+02 3.698191100000000233e+02 5.583859200000000556e+02 2.408268300000000011e+02 +2.726085100000000239e+02 3.609577899999999886e+02 2.741034999999999968e+02 2.178287100000000009e+02 
+2.530779300000000092e+02 3.872298099999999863e+02 2.463365000000000009e+02 2.571533800000000269e+02 +1.570759799999999871e+02 3.638692599999999970e+02 9.238810100000000602e+01 2.168714699999999880e+02 +3.391323499999999740e+02 3.876095199999999750e+02 3.732686800000000176e+02 2.604542500000000018e+02 +5.019839099999999803e+02 3.695469499999999812e+02 6.249612799999999879e+02 2.458215199999999925e+02 +5.469841999999999871e+02 3.587206800000000158e+02 6.836004000000000360e+02 2.281850000000000023e+02 +4.328263000000000034e+02 3.796009999999999991e+02 5.088481100000000197e+02 2.551237600000000043e+02 +4.344319499999999721e+02 3.635670099999999820e+02 5.068776199999999790e+02 2.286345699999999965e+02 +2.574950999999999794e+02 4.067603899999999726e+02 2.631620700000000284e+02 2.901571500000000015e+02 +4.085600499999999897e+02 4.326710199999999986e+02 4.939809200000000260e+02 3.393498299999999972e+02 +5.832414499999999862e+02 3.753785300000000120e+02 7.470578500000000304e+02 2.571437000000000239e+02 +2.331665999999999883e+02 3.660773300000000177e+02 2.143684700000000021e+02 2.241108900000000119e+02 +3.921076199999999972e+02 3.674550800000000095e+02 4.443483600000000138e+02 2.336271900000000130e+02 +4.930527399999999716e+02 3.986024699999999825e+02 6.086143299999999954e+02 2.861023999999999887e+02 +5.256456699999999955e+02 4.605939799999999877e+02 6.824474400000000287e+02 3.859730500000000006e+02 +4.944595400000000041e+02 4.609456799999999816e+02 6.369771700000000010e+02 3.860902899999999818e+02 +5.826532099999999446e+02 4.119304099999999949e+02 7.677974199999999882e+02 3.165097400000000221e+02 +4.631633807972402792e+02 3.838893306705910504e+02 5.657415136605623047e+02 2.633477841694225390e+02 +4.772016069320450811e+02 3.809024740461644569e+02 5.853869735965520249e+02 2.612353691225419539e+02 +4.999017172776869415e+02 3.826945880208203903e+02 6.215092708982107297e+02 2.637702671787986901e+02 +2.298391962234342714e+02 3.919639079503628523e+02 2.112270426050509968e+02 2.649053558881469144e+02 +2.258501706000224658e+02 3.922488383520351363e+02 2.065109531980616566e+02 2.662808819651854719e+02 +1.987817824411567926e+02 3.950981423687578626e+02 1.648521634363227690e+02 2.678529117675152520e+02 +2.467925551229343455e+02 3.931036295570519314e+02 2.385410604205307550e+02 2.670668968663503620e+02 +2.581897711898251941e+02 3.922488383520351363e+02 2.544578621691196645e+02 2.676564080422240295e+02 +2.522062327547075427e+02 3.617612853731021687e+02 2.428641423769375933e+02 2.187269804447099375e+02 +2.131707677256064812e+02 3.660352413981862583e+02 1.835200173389887652e+02 2.234430698516992493e+02 +1.701462770730936143e+02 3.785721790717661861e+02 1.198528103446330704e+02 2.419144200290740514e+02 +1.983543868386483950e+02 3.688845454149090415e+02 1.644591559857403240e+02 2.281591592586885611e+02 +5.589168186820540996e+02 4.165899106386119115e+02 7.329688054814949965e+02 3.228653612625751634e+02 +6.007144447556318028e+02 4.167355469663943950e+02 7.965020870332186860e+02 3.259219452547543483e+02 +4.800636057737958708e+02 4.608663719209688452e+02 6.134348439460603686e+02 3.839675846978299205e+02 +1.447956348980008556e+02 3.697010709907046930e+02 6.855698710166603860e+01 2.243443305002154773e+02 +1.529894419054618879e+02 3.697010709907046930e+02 8.174821679819848441e+01 2.246825671591009268e+02 +1.623971462473616043e+02 3.748601346620690720e+02 1.061012562379507358e+02 2.358443769023207039e+02 +4.879100783360448190e+02 3.810894645944220542e+02 5.999324939202114138e+02 2.625451796870169119e+02 diff --git 
a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ad28740ce26bd03baeb9766ce5198e11d65d5a8 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/crossval_errors.txt @@ -0,0 +1,38 @@ +3.396145170153651027e-01 +1.985991266213804840e+00 +8.730360763178448558e-01 +1.759074792352649741e+00 +6.690911651420399231e-01 +3.127928775947395046e+00 +2.242026380339626535e+00 +1.505917217336340819e-01 +2.060474379419809488e+00 +1.356362993459239386e+00 +4.635784250486629787e-02 +2.898097811114895173e+00 +7.993129866486650137e-01 +3.993484324439461886e-01 +1.941237862719949803e+00 +2.113590891431526231e+00 +1.301169151744067665e-01 +1.814139350301335440e+00 +1.942498926613547283e+00 +4.498550778761978841e-01 +2.105282325971198798e+00 +1.643782548361214652e-02 +7.447399375311934688e-02 +1.379783288516800654e+00 +8.107935155796172078e-01 +6.991320042190590778e-01 +1.313621562781517982e+00 +2.148165976067879790e-01 +2.133775335796964101e-01 +1.210735840397118235e+00 +9.441267027822466407e-01 +3.666260745439142266e+00 +3.918843643756335204e+00 +1.097652941210957911e+00 +6.105290073251094796e-02 +1.356574778919217295e-01 +1.617277148218513982e+00 +2.955076440033605589e+00 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/02.png new file mode 100644 index 0000000000000000000000000000000000000000..430c0e076263407c36895282121b585e09d5cd98 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c2abc34567f59818457ba0c69eef03868b6b0be276c73388bade0fbfff21970 +size 540896 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/corrs.txt new file mode 100644 index 0000000000000000000000000000000000000000..8a045058924894eba94c7b02ccf8038a8d67f1ef --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/corrs.txt @@ -0,0 +1,36 @@ +4.615342800000000238e+02 2.389287400000000048e+02 4.738835399999999822e+02 2.814227500000000077e+02 +4.576365299999999934e+02 2.757140499999999861e+02 4.694200299999999970e+02 3.442165600000000154e+02 +2.591256500000000074e+02 2.957716599999999971e+02 1.176534300000000002e+02 3.840285000000000082e+02 +2.595849000000000046e+02 2.628464000000000169e+02 1.152990800000000036e+02 3.305369299999999839e+02 +3.461092899999999872e+02 1.858082300000000089e+02 4.895454199999999787e+02 1.801452899999999886e+02 +3.563085899999999810e+02 1.855737599999999929e+02 5.023238499999999931e+02 1.801452899999999886e+02 +3.703001499999999737e+02 2.577946499999999901e+02 3.055337799999999788e+02 3.182941700000000083e+02 +3.900855999999999995e+02 2.519620400000000018e+02 3.353916100000000142e+02 3.058674500000000194e+02 +5.045426400000000058e+02 2.437352999999999952e+02 6.176713799999999992e+02 2.771141700000000014e+02 +5.591871800000000121e+02 2.706719800000000191e+02 6.825259399999999914e+02 3.283976799999999798e+02 +4.768939100000000053e+02 2.398385900000000106e+02 5.635075199999999995e+02 2.751426399999999717e+02 +3.574819499999999834e+02 2.812627499999999827e+02 2.861010600000000181e+02 3.566723400000000197e+02 +4.566492099999999823e+02 2.588013300000000072e+02 4.655793400000000020e+02 3.148266300000000228e+02 +2.816849300000000085e+02 
2.655427700000000186e+02 1.499537899999999979e+02 3.346798600000000192e+02 +2.157655900000000031e+02 2.429049899999999980e+02 1.148688800000000043e+02 2.944754100000000108e+02 +2.927736800000000130e+02 2.799935800000000086e+02 1.727572199999999896e+02 3.558925100000000157e+02 +3.823568500000000085e+02 2.611653699999999958e+02 3.271164600000000178e+02 3.231802799999999820e+02 +4.123609700000000089e+02 2.038438199999999938e+02 6.026451100000000451e+02 2.004868299999999977e+02 +3.833538500000000226e+02 2.794558299999999917e+02 3.290319600000000264e+02 3.525691800000000171e+02 +3.647116899999999760e+02 2.336124299999999891e+02 3.032482499999999845e+02 2.754504200000000083e+02 +3.588877200000000016e+02 2.104272200000000055e+02 5.073648800000000278e+02 2.216458800000000053e+02 +3.409510200000000282e+02 2.089031799999999919e+02 4.814563299999999799e+02 2.183633499999999970e+02 +3.613904200000000060e+02 2.614288799999999924e+02 2.906451500000000010e+02 3.242730700000000184e+02 +4.981131399999999871e+02 2.399558199999999886e+02 6.060632100000000264e+02 2.727979700000000207e+02 +4.695820600000000127e+02 2.615820600000000127e+02 5.372573300000000245e+02 3.162080900000000270e+02 +1.721220706987422773e+02 2.428520120853525555e+02 2.947191743642216011e+01 2.989297157012354091e+02 +1.647649712238389554e+02 2.405529184994452692e+02 1.660567796869992208e+01 2.947568812792713970e+02 +1.954195523692694394e+02 2.331958190245419473e+02 6.946158064691022105e+01 2.806735651051429841e+02 +1.744211642846495636e+02 2.557269361664333474e+02 3.260154325289514077e+01 3.187506792055642109e+02 +1.800922617965541974e+02 2.594054859038849941e+02 4.425070601421123229e+01 3.236189860311888538e+02 +1.558751426916641094e+02 2.331958190245419473e+02 1.826889390910878319e+00 2.846725314261918243e+02 +1.566415072202998715e+02 2.754991410052360266e+02 4.608779005553515162e+00 3.490037287648030428e+02 +1.960326439921780661e+02 2.738131390422373670e+02 7.293894266521351710e+01 3.441354219391784000e+02 +4.117581439734341870e+02 2.195749695714285963e+02 6.025075896224828966e+02 2.308977018826773815e+02 +4.066962080099323202e+02 2.078240467990135585e+02 5.947414811149388925e+02 2.116562987147324293e+02 +4.067502800813609269e+02 2.193798131157576563e+02 5.957846897204299239e+02 2.306658781902932844e+02 diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/crossval_errors.txt new file mode 100644 index 0000000000000000000000000000000000000000..aec67f6c0ec471fac7a21dd8402b563e6906ec51 --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/crossval_errors.txt @@ -0,0 +1,36 @@ +2.698356670634800558e+00 +3.569808760282799365e-01 +2.495950953406511630e+00 +7.742120527396407770e-01 +1.402599750820609226e-01 +2.236933413898831713e+00 +7.785110674558124444e-01 +1.881489831117169276e+00 +3.487995336350284248e+00 +3.396485994228788829e+00 +6.879465990330972947e-01 +5.618151694300682619e-01 +5.303773159161432327e-01 +7.361475832658961882e-02 +8.930819946990574687e-01 +1.779976084615252585e+00 +4.086245552329793029e-01 +5.164113151014938730e+00 +1.462905037131466912e+00 +7.370368629768874191e-01 +1.987278302009876985e+00 +9.721266174071729882e-01 +1.473186291581487783e-01 +6.483443779751860703e-01 +2.697537697590622230e+00 +1.015760142012948029e-02 +1.506916098273485272e+00 +2.212622623304430380e+00 +1.220237821851318932e-01 +4.458375803433419216e-01 +3.622931236761927076e-01 +9.706837058276953645e-01 +1.245315119310226093e+00 
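The WxBS ground-truth files added above share a simple plain-text layout: each `corrs.txt` stores one hand-labelled correspondence per row, and the scene's `crossval_errors.txt` stores one cross-validation error per row, so the two files have the same number of rows. As a minimal sketch of how such a scene could be inspected — assuming the four columns of `corrs.txt` are `x1 y1 x2 y2` in pixels (a point in `01.png` and its match in `02.png`), which is inferred from the layout rather than stated by this patch — one might load them with NumPy:

```python
# Minimal sketch, not part of this patch: load one WxBS scene's ground-truth
# correspondences and their cross-validation errors.
# Assumed layout: corrs.txt rows are "x1 y1 x2 y2" (pixels in 01.png / 02.png),
# crossval_errors.txt holds one error value per correspondence.
from pathlib import Path

import numpy as np

scene = Path("imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab")
corrs = np.loadtxt(scene / "corrs.txt")             # shape (N, 4)
errors = np.loadtxt(scene / "crossval_errors.txt")  # shape (N,)
assert corrs.shape[0] == errors.shape[0]

pts1, pts2 = corrs[:, :2], corrs[:, 2:]             # points in 01.png / 02.png

# Illustrative filter: keep correspondences whose cross-validation error is
# below an arbitrary 2 px threshold before scoring a matcher against them.
reliable = errors < 2.0
print(f"{int(reliable.sum())}/{len(errors)} correspondences under 2 px error")
```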
+1.259635152519284762e-01 +2.739689609504042389e-01 +9.760547549081137475e-02 diff --git a/imcui/datasets/wxbs_benchmark/download.py b/imcui/datasets/wxbs_benchmark/download.py new file mode 100644 index 0000000000000000000000000000000000000000..afbcc20568c56786e547998ac1748173f44474ee --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/download.py @@ -0,0 +1,4 @@ +from wxbs_benchmark.dataset import * # noqa: F403 + +dset = EVDDataset(".EVD", download=True) # noqa: F405 +dset = WxBSDataset(".WxBS", subset="test", download=True) # noqa: F405 diff --git a/imcui/datasets/wxbs_benchmark/example.py b/imcui/datasets/wxbs_benchmark/example.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd81950482920b39184086f525a60ef35c0ab7c --- /dev/null +++ b/imcui/datasets/wxbs_benchmark/example.py @@ -0,0 +1,19 @@ +import os +from pathlib import Path + +ROOT_PATH = Path("/teamspace/studios/this_studio/image-matching-webui") +prefix = "datasets/wxbs_benchmark/.WxBS/v1.1" +wxbs_path = ROOT_PATH / prefix + +pairs = [] +for catg in os.listdir(wxbs_path): + catg_path = wxbs_path / catg + if not catg_path.is_dir(): + continue + for scene in os.listdir(catg_path): + scene_path = catg_path / scene + if not scene_path.is_dir(): + continue + img1_path = scene_path / "01.png" + img2_path = scene_path / "02.png" + pairs.append([str(img1_path), str(img2_path)]) diff --git a/imcui/hloc/__init__.py b/imcui/hloc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf2753de6b5a09703aa58067891c9cd4bb5dee7 --- /dev/null +++ b/imcui/hloc/__init__.py @@ -0,0 +1,68 @@ +import logging +import sys + +import torch +from packaging import version + +__version__ = "1.5" + +LOG_PATH = "log.txt" + + +def read_logs(): + sys.stdout.flush() + with open(LOG_PATH, "r") as f: + return f.read() + + +def flush_logs(): + sys.stdout.flush() + logs = open(LOG_PATH, "w") + logs.close() + + +formatter = logging.Formatter( + fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", + datefmt="%Y/%m/%d %H:%M:%S", +) + +logs_file = open(LOG_PATH, "w") +logs_file.close() + +file_handler = logging.FileHandler(filename=LOG_PATH) +file_handler.setFormatter(formatter) +file_handler.setLevel(logging.INFO) +stdout_handler = logging.StreamHandler() +stdout_handler.setFormatter(formatter) +stdout_handler.setLevel(logging.INFO) +logger = logging.getLogger("hloc") +logger.setLevel(logging.INFO) +logger.addHandler(file_handler) +logger.addHandler(stdout_handler) +logger.propagate = False + +try: + import pycolmap +except ImportError: + logger.warning("pycolmap is not installed, some features may not work.") +else: + min_version = version.parse("0.6.0") + found_version = pycolmap.__version__ + if found_version != "dev": + version = version.parse(found_version) + if version < min_version: + s = f"pycolmap>={min_version}" + logger.warning( + "hloc requires %s but found pycolmap==%s, " + 'please upgrade with `pip install --upgrade "%s"`', + s, + found_version, + s, + ) + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# model hub: https://huggingface.co/Realcat/imcui_checkpoints +MODEL_REPO_ID = "Realcat/imcui_checkpoints" + +DATASETS_REPO_ID = "Realcat/imcui_datasets" diff --git a/imcui/hloc/colmap_from_nvm.py b/imcui/hloc/colmap_from_nvm.py new file mode 100644 index 0000000000000000000000000000000000000000..121ac42182c1942a96d5b1585319cdc634d40db7 --- /dev/null +++ b/imcui/hloc/colmap_from_nvm.py @@ -0,0 +1,216 @@ +import argparse +import sqlite3 +from collections import defaultdict +from 
pathlib import Path + +import numpy as np +from tqdm import tqdm + +from . import logger +from .utils.read_write_model import ( + CAMERA_MODEL_NAMES, + Camera, + Image, + Point3D, + write_model, +) + + +def recover_database_images_and_ids(database_path): + images = {} + cameras = {} + db = sqlite3.connect(str(database_path)) + ret = db.execute("SELECT name, image_id, camera_id FROM images;") + for name, image_id, camera_id in ret: + images[name] = image_id + cameras[name] = camera_id + db.close() + logger.info(f"Found {len(images)} images and {len(cameras)} cameras in database.") + return images, cameras + + +def quaternion_to_rotation_matrix(qvec): + qvec = qvec / np.linalg.norm(qvec) + w, x, y, z = qvec + R = np.array( + [ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w, + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w, + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y, + ], + ] + ) + return R + + +def camera_center_to_translation(c, qvec): + R = quaternion_to_rotation_matrix(qvec) + return (-1) * np.matmul(R, c) + + +def read_nvm_model(nvm_path, intrinsics_path, image_ids, camera_ids, skip_points=False): + with open(intrinsics_path, "r") as f: + raw_intrinsics = f.readlines() + + logger.info(f"Reading {len(raw_intrinsics)} cameras...") + cameras = {} + for intrinsics in raw_intrinsics: + intrinsics = intrinsics.strip("\n").split(" ") + name, camera_model, width, height = intrinsics[:4] + params = [float(p) for p in intrinsics[4:]] + camera_model = CAMERA_MODEL_NAMES[camera_model] + assert len(params) == camera_model.num_params + camera_id = camera_ids[name] + camera = Camera( + id=camera_id, + model=camera_model.model_name, + width=int(width), + height=int(height), + params=params, + ) + cameras[camera_id] = camera + + nvm_f = open(nvm_path, "r") + line = nvm_f.readline() + while line == "\n" or line.startswith("NVM_V3"): + line = nvm_f.readline() + num_images = int(line) + assert num_images == len(cameras) + + logger.info(f"Reading {num_images} images...") + image_idx_to_db_image_id = [] + image_data = [] + i = 0 + while i < num_images: + line = nvm_f.readline() + if line == "\n": + continue + data = line.strip("\n").split(" ") + image_data.append(data) + image_idx_to_db_image_id.append(image_ids[data[0]]) + i += 1 + + line = nvm_f.readline() + while line == "\n": + line = nvm_f.readline() + num_points = int(line) + + if skip_points: + logger.info(f"Skipping {num_points} points.") + num_points = 0 + else: + logger.info(f"Reading {num_points} points...") + points3D = {} + image_idx_to_keypoints = defaultdict(list) + i = 0 + pbar = tqdm(total=num_points, unit="pts") + while i < num_points: + line = nvm_f.readline() + if line == "\n": + continue + + data = line.strip("\n").split(" ") + x, y, z, r, g, b, num_observations = data[:7] + obs_image_ids, point2D_idxs = [], [] + for j in range(int(num_observations)): + s = 7 + 4 * j + img_index, kp_index, kx, ky = data[s : s + 4] + image_idx_to_keypoints[int(img_index)].append( + (int(kp_index), float(kx), float(ky), i) + ) + db_image_id = image_idx_to_db_image_id[int(img_index)] + obs_image_ids.append(db_image_id) + point2D_idxs.append(kp_index) + + point = Point3D( + id=i, + xyz=np.array([x, y, z], float), + rgb=np.array([r, g, b], int), + error=1.0, # fake + image_ids=np.array(obs_image_ids, int), + point2D_idxs=np.array(point2D_idxs, int), + ) + points3D[i] = point + + i += 1 + pbar.update(1) + pbar.close() + + logger.info("Parsing 
image data...") + images = {} + for i, data in enumerate(image_data): + # Skip the focal length. Skip the distortion and terminal 0. + name, _, qw, qx, qy, qz, cx, cy, cz, _, _ = data + qvec = np.array([qw, qx, qy, qz], float) + c = np.array([cx, cy, cz], float) + t = camera_center_to_translation(c, qvec) + + if i in image_idx_to_keypoints: + # NVM only stores triangulated 2D keypoints: add dummy ones + keypoints = image_idx_to_keypoints[i] + point2D_idxs = np.array([d[0] for d in keypoints]) + tri_xys = np.array([[x, y] for _, x, y, _ in keypoints]) + tri_ids = np.array([i for _, _, _, i in keypoints]) + + num_2Dpoints = max(point2D_idxs) + 1 + xys = np.zeros((num_2Dpoints, 2), float) + point3D_ids = np.full(num_2Dpoints, -1, int) + xys[point2D_idxs] = tri_xys + point3D_ids[point2D_idxs] = tri_ids + else: + xys = np.zeros((0, 2), float) + point3D_ids = np.full(0, -1, int) + + image_id = image_ids[name] + image = Image( + id=image_id, + qvec=qvec, + tvec=t, + camera_id=camera_ids[name], + name=name, + xys=xys, + point3D_ids=point3D_ids, + ) + images[image_id] = image + + return cameras, images, points3D + + +def main(nvm, intrinsics, database, output, skip_points=False): + assert nvm.exists(), nvm + assert intrinsics.exists(), intrinsics + assert database.exists(), database + + image_ids, camera_ids = recover_database_images_and_ids(database) + + logger.info("Reading the NVM model...") + model = read_nvm_model( + nvm, intrinsics, image_ids, camera_ids, skip_points=skip_points + ) + + logger.info("Writing the COLMAP model...") + output.mkdir(exist_ok=True, parents=True) + write_model(*model, path=str(output), ext=".bin") + logger.info("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--nvm", required=True, type=Path) + parser.add_argument("--intrinsics", required=True, type=Path) + parser.add_argument("--database", required=True, type=Path) + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--skip_points", action="store_true") + args = parser.parse_args() + main(**args.__dict__) diff --git a/imcui/hloc/extract_features.py b/imcui/hloc/extract_features.py new file mode 100644 index 0000000000000000000000000000000000000000..8b6a5c76f7c8bffa41fb82b4d0ec3dbf09ffcf3e --- /dev/null +++ b/imcui/hloc/extract_features.py @@ -0,0 +1,607 @@ +import argparse +import collections.abc as collections +import pprint +from pathlib import Path +from types import SimpleNamespace +from typing import Dict, List, Optional, Union + +import cv2 +import h5py +import numpy as np +import PIL.Image +import torch +import torchvision.transforms.functional as F +from tqdm import tqdm + +from . import extractors, logger +from .utils.base_model import dynamic_load +from .utils.io import list_h5_names, read_image +from .utils.parsers import parse_image_lists + +""" +A set of standard configurations that can be directly selected from the command +line using their name. Each is a dictionary with the following entries: + - output: the name of the feature file that will be generated. + - model: the model configuration, as passed to a feature extractor. + - preprocessing: how to preprocess the images read from disk. 
+""" +confs = { + "superpoint_aachen": { + "output": "feats-superpoint-n4096-r1024", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + # Resize images to 1600px even if they are originally smaller. + # Improves the keypoint localization if the images are of good quality. + "superpoint_max": { + "output": "feats-superpoint-n4096-rmax1600", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "superpoint_inloc": { + "output": "feats-superpoint-n4096-r1600", + "model": { + "name": "superpoint", + "nms_radius": 4, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1600, + }, + }, + "r2d2": { + "output": "feats-r2d2-n5000-r1024", + "model": { + "name": "r2d2", + "max_keypoints": 5000, + "reliability_threshold": 0.7, + "repetability_threshold": 0.7, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "d2net-ss": { + "output": "feats-d2net-ss-n5000-r1600", + "model": { + "name": "d2net", + "multiscale": False, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + }, + }, + "d2net-ms": { + "output": "feats-d2net-ms-n5000-r1600", + "model": { + "name": "d2net", + "multiscale": True, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + }, + }, + "rord": { + "output": "feats-rord-ss-n5000-r1600", + "model": { + "name": "rord", + "multiscale": False, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + }, + }, + "rootsift": { + "output": "feats-rootsift-n5000-r1600", + "model": { + "name": "dog", + "descriptor": "rootsift", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "sift": { + "output": "feats-sift-n5000-r1600", + "model": { + "name": "sift", + "rootsift": True, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "sosnet": { + "output": "feats-sosnet-n5000-r1600", + "model": { + "name": "dog", + "descriptor": "sosnet", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "hardnet": { + "output": "feats-hardnet-n5000-r1600", + "model": { + "name": "dog", + "descriptor": "hardnet", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1600, + "force_resize": True, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "disk": { + "output": "feats-disk-n5000-r1600", + "model": { + "name": "disk", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + }, + }, + "xfeat": { + "output": "feats-xfeat-n5000-r1600", + "model": { + "name": "xfeat", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + 
"resize_max": 1600, + }, + }, + "aliked-n16-rot": { + "output": "feats-aliked-n16-rot", + "model": { + "name": "aliked", + "model_name": "aliked-n16rot", + "max_num_keypoints": -1, + "detection_threshold": 0.2, + "nms_radius": 2, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1024, + }, + }, + "aliked-n16": { + "output": "feats-aliked-n16", + "model": { + "name": "aliked", + "model_name": "aliked-n16", + "max_num_keypoints": -1, + "detection_threshold": 0.2, + "nms_radius": 2, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1024, + }, + }, + "alike": { + "output": "feats-alike-n5000-r1600", + "model": { + "name": "alike", + "max_keypoints": 5000, + "use_relu": True, + "multiscale": False, + "detection_threshold": 0.5, + "top_k": -1, + "sub_pixel": False, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + }, + }, + "lanet": { + "output": "feats-lanet-n5000-r1600", + "model": { + "name": "lanet", + "keypoint_threshold": 0.1, + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1600, + }, + }, + "darkfeat": { + "output": "feats-darkfeat-n5000-r1600", + "model": { + "name": "darkfeat", + "max_keypoints": 5000, + "reliability_threshold": 0.7, + "repetability_threshold": 0.7, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "dedode": { + "output": "feats-dedode-n5000-r1600", + "model": { + "name": "dedode", + "max_keypoints": 5000, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1600, + "width": 768, + "height": 768, + "dfactor": 8, + }, + }, + "example": { + "output": "feats-example-n2000-r1024", + "model": { + "name": "example", + "keypoint_threshold": 0.1, + "max_keypoints": 2000, + "model_name": "model.pth", + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 768, + "height": 768, + "dfactor": 8, + }, + }, + "sfd2": { + "output": "feats-sfd2-n4096-r1600", + "model": { + "name": "sfd2", + "max_keypoints": 4096, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "conf_th": 0.001, + "multiscale": False, + "scales": [1.0], + }, + }, + # Global descriptors + "dir": { + "output": "global-feats-dir", + "model": {"name": "dir"}, + "preprocessing": {"resize_max": 1024}, + }, + "netvlad": { + "output": "global-feats-netvlad", + "model": {"name": "netvlad"}, + "preprocessing": {"resize_max": 1024}, + }, + "openibl": { + "output": "global-feats-openibl", + "model": {"name": "openibl"}, + "preprocessing": {"resize_max": 1024}, + }, + "cosplace": { + "output": "global-feats-cosplace", + "model": {"name": "cosplace"}, + "preprocessing": {"resize_max": 1024}, + }, + "eigenplaces": { + "output": "global-feats-eigenplaces", + "model": {"name": "eigenplaces"}, + "preprocessing": {"resize_max": 1024}, + }, +} + + +def resize_image(image, size, interp): + if interp.startswith("cv2_"): + interp = getattr(cv2, "INTER_" + interp[len("cv2_") :].upper()) + h, w = image.shape[:2] + if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]): + interp = cv2.INTER_LINEAR + resized = cv2.resize(image, size, interpolation=interp) + elif interp.startswith("pil_"): + interp = getattr(PIL.Image, interp[len("pil_") :].upper()) + resized = PIL.Image.fromarray(image.astype(np.uint8)) + resized = resized.resize(size, resample=interp) + resized = np.asarray(resized, 
dtype=image.dtype) + else: + raise ValueError(f"Unknown interpolation {interp}.") + return resized + + +class ImageDataset(torch.utils.data.Dataset): + default_conf = { + "globs": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"], + "grayscale": False, + "resize_max": None, + "force_resize": False, + "interpolation": "cv2_area", # pil_linear is more accurate but slower + } + + def __init__(self, root, conf, paths=None): + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + self.root = root + + if paths is None: + paths = [] + for g in conf.globs: + paths += list(Path(root).glob("**/" + g)) + if len(paths) == 0: + raise ValueError(f"Could not find any image in root: {root}.") + paths = sorted(list(set(paths))) + self.names = [i.relative_to(root).as_posix() for i in paths] + logger.info(f"Found {len(self.names)} images in root {root}.") + else: + if isinstance(paths, (Path, str)): + self.names = parse_image_lists(paths) + elif isinstance(paths, collections.Iterable): + self.names = [p.as_posix() if isinstance(p, Path) else p for p in paths] + else: + raise ValueError(f"Unknown format for path argument {paths}.") + + for name in self.names: + if not (root / name).exists(): + raise ValueError(f"Image {name} does not exists in root: {root}.") + + def __getitem__(self, idx): + name = self.names[idx] + image = read_image(self.root / name, self.conf.grayscale) + image = image.astype(np.float32) + size = image.shape[:2][::-1] + + if self.conf.resize_max and ( + self.conf.force_resize or max(size) > self.conf.resize_max + ): + scale = self.conf.resize_max / max(size) + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, size_new, self.conf.interpolation) + + if self.conf.grayscale: + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = image / 255.0 + + data = { + "image": image, + "original_size": np.array(size), + } + return data + + def __len__(self): + return len(self.names) + + +def extract(model, image_0, conf): + default_conf = { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "cache_images": False, + "force_resize": False, + "width": 320, + "height": 240, + "interpolation": "cv2_area", + } + conf = SimpleNamespace(**{**default_conf, **conf}) + device = "cuda" if torch.cuda.is_available() else "cpu" + + def preprocess(image: np.ndarray, conf: SimpleNamespace): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + if conf.resize_max: + scale = conf.resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, size_new, "cv2_area") + scale = np.array(size) / np.array(size_new) + if conf.force_resize: + image = resize_image(image, (conf.width, conf.height), "cv2_area") + size_new = (conf.width, conf.height) + scale = np.array(size) / np.array(size_new) + if conf.grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + + # assure that the size is divisible by dfactor + size_new = tuple( + map( + lambda x: int(x // conf.dfactor * conf.dfactor), + image.shape[-2:], + ) + ) + image = F.resize(image, size=size_new, antialias=True) + input_ = image.to(device, non_blocking=True)[None] + data = { + "image": input_, + "image_orig": image_0, + "original_size": np.array(size), + "size": np.array(image.shape[1:][::-1]), + } + return data + + # convert to grayscale 
if needed + if len(image_0.shape) == 3 and conf.grayscale: + image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY) + else: + image0 = image_0 + # comment following lines, image is always RGB mode + # if not conf.grayscale and len(image_0.shape) == 3: + # image0 = image_0[:, :, ::-1] # BGR to RGB + data = preprocess(image0, conf) + pred = model({"image": data["image"]}) + pred["image_size"] = data["original_size"] + pred = {**pred, **data} + return pred + + +@torch.no_grad() +def main( + conf: Dict, + image_dir: Path, + export_dir: Optional[Path] = None, + as_half: bool = True, + image_list: Optional[Union[Path, List[str]]] = None, + feature_path: Optional[Path] = None, + overwrite: bool = False, +) -> Path: + logger.info( + "Extracting local features with configuration:" f"\n{pprint.pformat(conf)}" + ) + + dataset = ImageDataset(image_dir, conf["preprocessing"], image_list) + if feature_path is None: + feature_path = Path(export_dir, conf["output"] + ".h5") + feature_path.parent.mkdir(exist_ok=True, parents=True) + skip_names = set( + list_h5_names(feature_path) if feature_path.exists() and not overwrite else () + ) + dataset.names = [n for n in dataset.names if n not in skip_names] + if len(dataset.names) == 0: + logger.info("Skipping the extraction.") + return feature_path + + device = "cuda" if torch.cuda.is_available() else "cpu" + Model = dynamic_load(extractors, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(device) + + loader = torch.utils.data.DataLoader( + dataset, num_workers=1, shuffle=False, pin_memory=True + ) + for idx, data in enumerate(tqdm(loader)): + name = dataset.names[idx] + pred = model({"image": data["image"].to(device, non_blocking=True)}) + pred = {k: v[0].cpu().numpy() for k, v in pred.items()} + + pred["image_size"] = original_size = data["original_size"][0].numpy() + if "keypoints" in pred: + size = np.array(data["image"].shape[-2:][::-1]) + scales = (original_size / size).astype(np.float32) + pred["keypoints"] = (pred["keypoints"] + 0.5) * scales[None] - 0.5 + if "scales" in pred: + pred["scales"] *= scales.mean() + # add keypoint uncertainties scaled to the original resolution + uncertainty = getattr(model, "detection_noise", 1) * scales.mean() + + if as_half: + for k in pred: + dt = pred[k].dtype + if (dt == np.float32) and (dt != np.float16): + pred[k] = pred[k].astype(np.float16) + + with h5py.File(str(feature_path), "a", libver="latest") as fd: + try: + if name in fd: + del fd[name] + grp = fd.create_group(name) + for k, v in pred.items(): + grp.create_dataset(k, data=v) + if "keypoints" in pred: + grp["keypoints"].attrs["uncertainty"] = uncertainty + except OSError as error: + if "No space left on device" in error.args[0]: + logger.error( + "Out of disk space: storing features on disk can take " + "significant space, did you enable the as_half flag?" 
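+                        # `as_half=True` casts float32 feature arrays to
+                        # float16 before writing (see the cast earlier in this
+                        # loop), which roughly halves the HDF5 file size.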
+ ) + del grp, fd[name] + raise error + + del pred + + logger.info("Finished exporting features.") + return feature_path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image_dir", type=Path, required=True) + parser.add_argument("--export_dir", type=Path, required=True) + parser.add_argument( + "--conf", + type=str, + default="superpoint_aachen", + choices=list(confs.keys()), + ) + parser.add_argument("--as_half", action="store_true") + parser.add_argument("--image_list", type=Path) + parser.add_argument("--feature_path", type=Path) + args = parser.parse_args() + main(confs[args.conf], args.image_dir, args.export_dir, args.as_half) diff --git a/imcui/hloc/extractors/__init__.py b/imcui/hloc/extractors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/hloc/extractors/alike.py b/imcui/hloc/extractors/alike.py new file mode 100644 index 0000000000000000000000000000000000000000..64724dc035c98ce01fa0bbb98b4772a993eb1526 --- /dev/null +++ b/imcui/hloc/extractors/alike.py @@ -0,0 +1,61 @@ +import sys +from pathlib import Path + +import torch + +from .. import MODEL_REPO_ID, logger + +from ..utils.base_model import BaseModel + +alike_path = Path(__file__).parent / "../../third_party/ALIKE" +sys.path.append(str(alike_path)) +from alike import ALike as Alike_ +from alike import configs + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Alike(BaseModel): + default_conf = { + "model_name": "alike-t", # 'alike-t', 'alike-s', 'alike-n', 'alike-l' + "use_relu": True, + "multiscale": False, + "max_keypoints": 1000, + "detection_threshold": 0.5, + "top_k": -1, + "sub_pixel": False, + } + + required_inputs = ["image"] + + def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}.pth".format(Path(__file__).stem, self.conf["model_name"]), + ) + logger.info("Loaded Alike model from {}".format(model_path)) + configs[conf["model_name"]]["model_path"] = model_path + self.net = Alike_( + **configs[conf["model_name"]], + device=device, + top_k=conf["top_k"], + scores_th=conf["detection_threshold"], + n_limit=conf["max_keypoints"], + ) + logger.info("Load Alike model done.") + + def _forward(self, data): + image = data["image"] + image = image.permute(0, 2, 3, 1).squeeze() + image = image.cpu().numpy() * 255.0 + pred = self.net(image, sub_pixel=self.conf["sub_pixel"]) + + keypoints = pred["keypoints"] + descriptors = pred["descriptors"] + scores = pred["scores"] + + return { + "keypoints": torch.from_numpy(keypoints)[None], + "scores": torch.from_numpy(scores)[None], + "descriptors": torch.from_numpy(descriptors.T)[None], + } diff --git a/imcui/hloc/extractors/aliked.py b/imcui/hloc/extractors/aliked.py new file mode 100644 index 0000000000000000000000000000000000000000..4f712bebd7c8a1a8052cff22064f19c0a7b13615 --- /dev/null +++ b/imcui/hloc/extractors/aliked.py @@ -0,0 +1,32 @@ +import sys +from pathlib import Path + +from ..utils.base_model import BaseModel + +lightglue_path = Path(__file__).parent / "../../third_party/LightGlue" +sys.path.append(str(lightglue_path)) + +from lightglue import ALIKED as ALIKED_ + + +class ALIKED(BaseModel): + default_conf = { + "model_name": "aliked-n16", + "max_num_keypoints": -1, + "detection_threshold": 0.2, + "nms_radius": 2, + } + required_inputs = ["image"] + + def _init(self, conf): + conf.pop("name") + self.model = ALIKED_(**conf) + + def _forward(self, data): + 
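+        # `data` follows the convention used by the other extractors in this
+        # diff: an "image" tensor of shape B x C x H x W with values in [0, 1].
+        # The LightGlue ALIKED model returns per-image lists, which are
+        # re-packed below with descriptors transposed to D x N.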
features = self.model(data) + + return { + "keypoints": [f for f in features["keypoints"]], + "scores": [f for f in features["keypoint_scores"]], + "descriptors": [f.t() for f in features["descriptors"]], + } diff --git a/imcui/hloc/extractors/cosplace.py b/imcui/hloc/extractors/cosplace.py new file mode 100644 index 0000000000000000000000000000000000000000..8d13a84d57d80bee090709623cce74453784844b --- /dev/null +++ b/imcui/hloc/extractors/cosplace.py @@ -0,0 +1,44 @@ +""" +Code for loading models trained with CosPlace as a global features extractor +for geolocalization through image retrieval. +Multiple models are available with different backbones. Below is a summary of +models available (backbone : list of available output descriptors +dimensionality). For example you can use a model based on a ResNet50 with +descriptors dimensionality 1024. + ResNet18: [32, 64, 128, 256, 512] + ResNet50: [32, 64, 128, 256, 512, 1024, 2048] + ResNet101: [32, 64, 128, 256, 512, 1024, 2048] + ResNet152: [32, 64, 128, 256, 512, 1024, 2048] + VGG16: [ 64, 128, 256, 512] + +CosPlace paper: https://arxiv.org/abs/2204.02287 +""" + +import torch +import torchvision.transforms as tvf + +from ..utils.base_model import BaseModel + + +class CosPlace(BaseModel): + default_conf = {"backbone": "ResNet50", "fc_output_dim": 2048} + required_inputs = ["image"] + + def _init(self, conf): + self.net = torch.hub.load( + "gmberton/CosPlace", + "get_trained_model", + backbone=conf["backbone"], + fc_output_dim=conf["fc_output_dim"], + ).eval() + + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + self.norm_rgb = tvf.Normalize(mean=mean, std=std) + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + desc = self.net(image) + return { + "global_descriptor": desc, + } diff --git a/imcui/hloc/extractors/d2net.py b/imcui/hloc/extractors/d2net.py new file mode 100644 index 0000000000000000000000000000000000000000..207977c732e14ae6fde1e02d3e7f4335fbdf57e9 --- /dev/null +++ b/imcui/hloc/extractors/d2net.py @@ -0,0 +1,60 @@ +import sys +from pathlib import Path + +import torch + +from .. 
import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +d2net_path = Path(__file__).parent / "../../third_party/d2net" +sys.path.append(str(d2net_path)) +from lib.model_test import D2Net as _D2Net +from lib.pyramid import process_multiscale + + +class D2Net(BaseModel): + default_conf = { + "model_name": "d2_tf.pth", + "checkpoint_dir": d2net_path / "models", + "use_relu": True, + "multiscale": False, + "max_keypoints": 1024, + } + required_inputs = ["image"] + + def _init(self, conf): + logger.info("Loading D2Net model...") + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + logger.info(f"Loading model from {model_path}...") + self.net = _D2Net( + model_file=model_path, use_relu=conf["use_relu"], use_cuda=False + ) + logger.info("Load D2Net model done.") + + def _forward(self, data): + image = data["image"] + image = image.flip(1) # RGB -> BGR + norm = image.new_tensor([103.939, 116.779, 123.68]) + image = image * 255 - norm.view(1, 3, 1, 1) # caffe normalization + + if self.conf["multiscale"]: + keypoints, scores, descriptors = process_multiscale(image, self.net) + else: + keypoints, scores, descriptors = process_multiscale( + image, self.net, scales=[1] + ) + keypoints = keypoints[:, [1, 0]] # (x, y) and remove the scale + + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + keypoints = keypoints[idxs, :2] + descriptors = descriptors[idxs] + scores = scores[idxs] + + return { + "keypoints": torch.from_numpy(keypoints)[None], + "scores": torch.from_numpy(scores)[None], + "descriptors": torch.from_numpy(descriptors.T)[None], + } diff --git a/imcui/hloc/extractors/darkfeat.py b/imcui/hloc/extractors/darkfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..8833041e9a168f465df0d07191245777612da890 --- /dev/null +++ b/imcui/hloc/extractors/darkfeat.py @@ -0,0 +1,44 @@ +import sys +from pathlib import Path + +from .. import MODEL_REPO_ID, logger + +from ..utils.base_model import BaseModel + +darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat" +sys.path.append(str(darkfeat_path)) +from darkfeat import DarkFeat as DarkFeat_ + + +class DarkFeat(BaseModel): + default_conf = { + "model_name": "DarkFeat.pth", + "max_keypoints": 1000, + "detection_threshold": 0.5, + "sub_pixel": False, + } + required_inputs = ["image"] + + def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + logger.info("Loaded DarkFeat model: {}".format(model_path)) + self.net = DarkFeat_(model_path) + logger.info("Load DarkFeat model done.") + + def _forward(self, data): + pred = self.net({"image": data["image"]}) + keypoints = pred["keypoints"] + descriptors = pred["descriptors"] + scores = pred["scores"] + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + keypoints = keypoints[idxs, :2] + descriptors = descriptors[:, idxs] + scores = scores[idxs] + return { + "keypoints": keypoints[None], # 1 x N x 2 + "scores": scores[None], # 1 x N + "descriptors": descriptors[None], # 1 x 128 x N + } diff --git a/imcui/hloc/extractors/dedode.py b/imcui/hloc/extractors/dedode.py new file mode 100644 index 0000000000000000000000000000000000000000..a7108e31340535afcef062c1d8eb495014b70ee1 --- /dev/null +++ b/imcui/hloc/extractors/dedode.py @@ -0,0 +1,86 @@ +import sys +from pathlib import Path + +import torch +import torchvision.transforms as transforms + +from .. 
import MODEL_REPO_ID, logger + +from ..utils.base_model import BaseModel + +dedode_path = Path(__file__).parent / "../../third_party/DeDoDe" +sys.path.append(str(dedode_path)) + +from DeDoDe import dedode_descriptor_B, dedode_detector_L +from DeDoDe.utils import to_pixel_coords + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class DeDoDe(BaseModel): + default_conf = { + "name": "dedode", + "model_detector_name": "dedode_detector_L.pth", + "model_descriptor_name": "dedode_descriptor_B.pth", + "max_keypoints": 2000, + "match_threshold": 0.2, + "dense": False, # Now fixed to be false + } + required_inputs = [ + "image", + ] + + # Initialize the line matcher + def _init(self, conf): + model_detector_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, conf["model_detector_name"]), + ) + model_descriptor_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, conf["model_descriptor_name"]), + ) + logger.info("Loaded DarkFeat model: {}".format(model_detector_path)) + self.normalizer = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + # load the model + weights_detector = torch.load(model_detector_path, map_location="cpu") + weights_descriptor = torch.load(model_descriptor_path, map_location="cpu") + self.detector = dedode_detector_L(weights=weights_detector, device=device) + self.descriptor = dedode_descriptor_B(weights=weights_descriptor, device=device) + logger.info("Load DeDoDe model done.") + + def _forward(self, data): + """ + data: dict, keys: {'image0','image1'} + image shape: N x C x H x W + color mode: RGB + """ + img0 = self.normalizer(data["image"].squeeze()).float()[None] + H_A, W_A = img0.shape[2:] + + # step 1: detect keypoints + detections_A = None + batch_A = {"image": img0} + if self.conf["dense"]: + detections_A = self.detector.detect_dense(batch_A) + else: + detections_A = self.detector.detect( + batch_A, num_keypoints=self.conf["max_keypoints"] + ) + keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"] + + # step 2: describe keypoints + # dim: 1 x N x 256 + description_A = self.descriptor.describe_keypoints(batch_A, keypoints_A)[ + "descriptions" + ] + keypoints_A = to_pixel_coords(keypoints_A, H_A, W_A) + + return { + "keypoints": keypoints_A, # 1 x N x 2 + "descriptors": description_A.permute(0, 2, 1), # 1 x 256 x N + "scores": P_A, # 1 x N + } diff --git a/imcui/hloc/extractors/dir.py b/imcui/hloc/extractors/dir.py new file mode 100644 index 0000000000000000000000000000000000000000..cd7322a922a151b0a5ad5e185fbb312a0b5d12a7 --- /dev/null +++ b/imcui/hloc/extractors/dir.py @@ -0,0 +1,78 @@ +import os +import sys +from pathlib import Path +from zipfile import ZipFile + +import gdown +import sklearn +import torch + +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party/deep-image-retrieval")) +os.environ["DB_ROOT"] = "" # required by dirtorch + +from dirtorch.extract_features import load_model # noqa: E402 +from dirtorch.utils import common # noqa: E402 + +# The DIR model checkpoints (pickle files) include sklearn.decomposition.pca, +# which has been deprecated in sklearn v0.24 +# and must be explicitly imported with `from sklearn.decomposition import PCA`. +# This is a hacky workaround to maintain forward compatibility. 
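+# Pickle resolves classes by module path, so registering the old
+# `sklearn.decomposition.pca` name as an alias for the new private module
+# lets the bundled DIR checkpoints load on recent sklearn versions.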
+sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca + + +class DIR(BaseModel): + default_conf = { + "model_name": "Resnet-101-AP-GeM", + "whiten_name": "Landmarks_clean", + "whiten_params": { + "whitenp": 0.25, + "whitenv": None, + "whitenm": 1.0, + }, + "pooling": "gem", + "gemp": 3, + } + required_inputs = ["image"] + + dir_models = { + "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy", + } + + def _init(self, conf): + # todo: download from google drive -> huggingface models + checkpoint = Path(torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt") + if not checkpoint.exists(): + checkpoint.parent.mkdir(exist_ok=True, parents=True) + link = self.dir_models[conf["model_name"]] + gdown.download(str(link), str(checkpoint) + ".zip", quiet=False) + zf = ZipFile(str(checkpoint) + ".zip", "r") + zf.extractall(checkpoint.parent) + zf.close() + os.remove(str(checkpoint) + ".zip") + + self.net = load_model(checkpoint, False) # first load on CPU + if conf["whiten_name"]: + assert conf["whiten_name"] in self.net.pca + + def _forward(self, data): + image = data["image"] + assert image.shape[1] == 3 + mean = self.net.preprocess["mean"] + std = self.net.preprocess["std"] + image = image - image.new_tensor(mean)[:, None, None] + image = image / image.new_tensor(std)[:, None, None] + + desc = self.net(image) + desc = desc.unsqueeze(0) # batch dimension + if self.conf["whiten_name"]: + pca = self.net.pca[self.conf["whiten_name"]] + desc = common.whiten_features( + desc.cpu().numpy(), pca, **self.conf["whiten_params"] + ) + desc = torch.from_numpy(desc) + + return { + "global_descriptor": desc, + } diff --git a/imcui/hloc/extractors/disk.py b/imcui/hloc/extractors/disk.py new file mode 100644 index 0000000000000000000000000000000000000000..a062a908af68656c29e7ee1e8c5047c92790bcc9 --- /dev/null +++ b/imcui/hloc/extractors/disk.py @@ -0,0 +1,35 @@ +import kornia + +from .. 
import logger + +from ..utils.base_model import BaseModel + + +class DISK(BaseModel): + default_conf = { + "weights": "depth", + "max_keypoints": None, + "nms_window_size": 5, + "detection_threshold": 0.0, + "pad_if_not_divisible": True, + } + required_inputs = ["image"] + + def _init(self, conf): + self.model = kornia.feature.DISK.from_pretrained(conf["weights"]) + logger.info("Load DISK model done.") + + def _forward(self, data): + image = data["image"] + features = self.model( + image, + n=self.conf["max_keypoints"], + window_size=self.conf["nms_window_size"], + score_threshold=self.conf["detection_threshold"], + pad_if_not_divisible=self.conf["pad_if_not_divisible"], + ) + return { + "keypoints": [f.keypoints for f in features][0][None], + "scores": [f.detection_scores for f in features][0][None], + "descriptors": [f.descriptors.t() for f in features][0][None], + } diff --git a/imcui/hloc/extractors/dog.py b/imcui/hloc/extractors/dog.py new file mode 100644 index 0000000000000000000000000000000000000000..b280bbc42376f3af827002bb85ff4996ccdf50b4 --- /dev/null +++ b/imcui/hloc/extractors/dog.py @@ -0,0 +1,135 @@ +import kornia +import numpy as np +import pycolmap +import torch +from kornia.feature.laf import ( + extract_patches_from_pyramid, + laf_from_center_scale_ori, +) + +from ..utils.base_model import BaseModel + +EPS = 1e-6 + + +def sift_to_rootsift(x): + x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS) + x = np.sqrt(x.clip(min=EPS)) + x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS) + return x + + +class DoG(BaseModel): + default_conf = { + "options": { + "first_octave": 0, + "peak_threshold": 0.01, + }, + "descriptor": "rootsift", + "max_keypoints": -1, + "patch_size": 32, + "mr_size": 12, + } + required_inputs = ["image"] + detection_noise = 1.0 + max_batch_size = 1024 + + def _init(self, conf): + if conf["descriptor"] == "sosnet": + self.describe = kornia.feature.SOSNet(pretrained=True) + elif conf["descriptor"] == "hardnet": + self.describe = kornia.feature.HardNet(pretrained=True) + elif conf["descriptor"] not in ["sift", "rootsift"]: + raise ValueError(f'Unknown descriptor: {conf["descriptor"]}') + + self.sift = None # lazily instantiated on the first image + self.dummy_param = torch.nn.Parameter(torch.empty(0)) + self.device = torch.device("cpu") + + def to(self, *args, **kwargs): + device = kwargs.get("device") + if device is None: + match = [a for a in args if isinstance(a, (torch.device, str))] + if len(match) > 0: + device = match[0] + if device is not None: + self.device = torch.device(device) + return super().to(*args, **kwargs) + + def _forward(self, data): + image = data["image"] + image_np = image.cpu().numpy()[0, 0] + assert image.shape[1] == 1 + assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS + + if self.sift is None: + device = self.dummy_param.device + use_gpu = pycolmap.has_cuda and device.type == "cuda" + options = {**self.conf["options"]} + if self.conf["descriptor"] == "rootsift": + options["normalization"] = pycolmap.Normalization.L1_ROOT + else: + options["normalization"] = pycolmap.Normalization.L2 + self.sift = pycolmap.Sift( + options=pycolmap.SiftExtractionOptions(options), + device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"), + ) + keypoints, descriptors = self.sift.extract(image_np) + scales = keypoints[:, 2] + oris = np.rad2deg(keypoints[:, 3]) + + if self.conf["descriptor"] in ["sift", "rootsift"]: + # We still renormalize because COLMAP does not normalize well, + # maybe due to numerical errors + if 
self.conf["descriptor"] == "rootsift": + descriptors = sift_to_rootsift(descriptors) + descriptors = torch.from_numpy(descriptors) + elif self.conf["descriptor"] in ("sosnet", "hardnet"): + center = keypoints[:, :2] + 0.5 + laf_scale = scales * self.conf["mr_size"] / 2 + laf_ori = -oris + lafs = laf_from_center_scale_ori( + torch.from_numpy(center)[None], + torch.from_numpy(laf_scale)[None, :, None, None], + torch.from_numpy(laf_ori)[None, :, None], + ).to(image.device) + patches = extract_patches_from_pyramid( + image, lafs, PS=self.conf["patch_size"] + )[0] + descriptors = patches.new_zeros((len(patches), 128)) + if len(patches) > 0: + for start_idx in range(0, len(patches), self.max_batch_size): + end_idx = min(len(patches), start_idx + self.max_batch_size) + descriptors[start_idx:end_idx] = self.describe( + patches[start_idx:end_idx] + ) + else: + raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}') + + keypoints = torch.from_numpy(keypoints[:, :2]) # keep only x, y + scales = torch.from_numpy(scales) + oris = torch.from_numpy(oris) + scores = keypoints.new_zeros(len(keypoints)) # no scores for SIFT yet + + if self.conf["max_keypoints"] != -1: + # TODO: check that the scores from PyCOLMAP are 100% correct, + # follow https://github.com/mihaidusmanu/pycolmap/issues/8 + max_number = ( + scores.shape[0] + if scores.shape[0] < self.conf["max_keypoints"] + else self.conf["max_keypoints"] + ) + values, indices = torch.topk(scores, max_number) + keypoints = keypoints[indices] + scales = scales[indices] + oris = oris[indices] + scores = scores[indices] + descriptors = descriptors[indices] + + return { + "keypoints": keypoints[None], + "scales": scales[None], + "oris": oris[None], + "scores": scores[None], + "descriptors": descriptors.T[None], + } diff --git a/imcui/hloc/extractors/eigenplaces.py b/imcui/hloc/extractors/eigenplaces.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9953b27c00682c830842736fd0bdab93857f14 --- /dev/null +++ b/imcui/hloc/extractors/eigenplaces.py @@ -0,0 +1,57 @@ +""" +Code for loading models trained with EigenPlaces (or CosPlace) as a global +features extractor for geolocalization through image retrieval. +Multiple models are available with different backbones. Below is a summary of +models available (backbone : list of available output descriptors +dimensionality). For example you can use a model based on a ResNet50 with +descriptors dimensionality 1024. 
+ +EigenPlaces trained models: + ResNet18: [ 256, 512] + ResNet50: [128, 256, 512, 2048] + ResNet101: [128, 256, 512, 2048] + VGG16: [ 512] + +CosPlace trained models: + ResNet18: [32, 64, 128, 256, 512] + ResNet50: [32, 64, 128, 256, 512, 1024, 2048] + ResNet101: [32, 64, 128, 256, 512, 1024, 2048] + ResNet152: [32, 64, 128, 256, 512, 1024, 2048] + VGG16: [ 64, 128, 256, 512] + +EigenPlaces paper (ICCV 2023): https://arxiv.org/abs/2308.10832 +CosPlace paper (CVPR 2022): https://arxiv.org/abs/2204.02287 +""" + +import torch +import torchvision.transforms as tvf + +from ..utils.base_model import BaseModel + + +class EigenPlaces(BaseModel): + default_conf = { + "variant": "EigenPlaces", + "backbone": "ResNet101", + "fc_output_dim": 2048, + } + required_inputs = ["image"] + + def _init(self, conf): + self.net = torch.hub.load( + "gmberton/" + conf["variant"], + "get_trained_model", + backbone=conf["backbone"], + fc_output_dim=conf["fc_output_dim"], + ).eval() + + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + self.norm_rgb = tvf.Normalize(mean=mean, std=std) + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + desc = self.net(image) + return { + "global_descriptor": desc, + } diff --git a/imcui/hloc/extractors/example.py b/imcui/hloc/extractors/example.py new file mode 100644 index 0000000000000000000000000000000000000000..3d952c4014e006d74409a8f32ee7159d58305de5 --- /dev/null +++ b/imcui/hloc/extractors/example.py @@ -0,0 +1,56 @@ +import sys +from pathlib import Path + +import torch + +from .. import logger +from ..utils.base_model import BaseModel + +example_path = Path(__file__).parent / "../../third_party/example" +sys.path.append(str(example_path)) + +# import some modules here + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Example(BaseModel): + # change to your default configs + default_conf = { + "name": "example", + "keypoint_threshold": 0.1, + "max_keypoints": 2000, + "model_name": "model.pth", + } + required_inputs = ["image"] + + def _init(self, conf): + # set checkpoints paths if needed + model_path = example_path / "checkpoints" / f'{conf["model_name"]}' + if not model_path.exists(): + logger.info(f"No model found at {model_path}") + + # init model + self.net = callable + # self.net = ExampleNet(is_test=True) + state_dict = torch.load(model_path, map_location="cpu") + self.net.load_state_dict(state_dict["model_state"]) + logger.info("Load example model done.") + + def _forward(self, data): + # data: dict, keys: 'image' + # image color mode: RGB + # image value range in [0, 1] + image = data["image"] + + # B: batch size, N: number of keypoints + # keypoints shape: B x N x 2, type: torch tensor + # scores shape: B x N, type: torch tensor + # descriptors shape: B x 128 x N, type: torch tensor + keypoints, scores, descriptors = self.net(image) + + return { + "keypoints": keypoints, + "scores": scores, + "descriptors": descriptors, + } diff --git a/imcui/hloc/extractors/fire.py b/imcui/hloc/extractors/fire.py new file mode 100644 index 0000000000000000000000000000000000000000..980f18e63d1a395835891c8e6595cfc66c21db2d --- /dev/null +++ b/imcui/hloc/extractors/fire.py @@ -0,0 +1,72 @@ +import logging +import subprocess +import sys +from pathlib import Path + +import torch +import torchvision.transforms as tvf + +from ..utils.base_model import BaseModel + +logger = logging.getLogger(__name__) +fire_path = Path(__file__).parent / "../../third_party/fire" +sys.path.append(str(fire_path)) + + +import fire_network + + 
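+# FIRe produces a single global descriptor per image for retrieval (see
+# `_forward` below). A minimal usage sketch through the generic extraction
+# entry point; the conf key "fire" is illustrative and assumes a matching
+# entry exists in `extract_features.confs`:
+#
+#     from pathlib import Path
+#     from imcui.hloc import extract_features
+#     extract_features.main(
+#         extract_features.confs["fire"], Path("images/"), Path("outputs/")
+#     )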
+class FIRe(BaseModel): + default_conf = { + "global": True, + "asmk": False, + "model_name": "fire_SfM_120k.pth", + "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params + "features_num": 1000, # TODO:not supported now + "asmk_name": "asmk_codebook.bin", # TODO:not supported now + "config_name": "eval_fire.yml", + } + required_inputs = ["image"] + + # Models exported using + fire_models = { + "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth", + "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth", + } + + def _init(self, conf): + assert conf["model_name"] in self.fire_models.keys() + # Config paths + model_path = fire_path / "model" / conf["model_name"] + + # Download the model. + if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + link = self.fire_models[conf["model_name"]] + cmd = ["wget", "--quiet", link, "-O", str(model_path)] + logger.info(f"Downloading the FIRe model with `{cmd}`.") + subprocess.run(cmd, check=True) + + logger.info("Loading fire model...") + + # Load net + state = torch.load(model_path) + state["net_params"]["pretrained"] = None + net = fire_network.init_network(**state["net_params"]) + net.load_state_dict(state["state_dict"]) + self.net = net + + self.norm_rgb = tvf.Normalize( + **dict(zip(["mean", "std"], net.runtime["mean_std"])) + ) + + # params + self.scales = conf["scales"] + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + + # Feature extraction. + desc = self.net.forward_global(image, scales=self.scales) + + return {"global_descriptor": desc} diff --git a/imcui/hloc/extractors/fire_local.py b/imcui/hloc/extractors/fire_local.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e9ba9f4c3d86280e8232f61263b729ccb933be --- /dev/null +++ b/imcui/hloc/extractors/fire_local.py @@ -0,0 +1,84 @@ +import subprocess +import sys +from pathlib import Path + +import torch +import torchvision.transforms as tvf + +from .. import logger +from ..utils.base_model import BaseModel + +fire_path = Path(__file__).parent / "../../third_party/fire" + +sys.path.append(str(fire_path)) + + +import fire_network + +EPS = 1e-6 + + +class FIRe(BaseModel): + default_conf = { + "global": True, + "asmk": False, + "model_name": "fire_SfM_120k.pth", + "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params + "features_num": 1000, + "asmk_name": "asmk_codebook.bin", + "config_name": "eval_fire.yml", + } + required_inputs = ["image"] + + # Models exported using + fire_models = { + "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth", + "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth", + } + + def _init(self, conf): + assert conf["model_name"] in self.fire_models.keys() + + # Config paths + model_path = fire_path / "model" / conf["model_name"] + config_path = fire_path / conf["config_name"] # noqa: F841 + asmk_bin_path = fire_path / "model" / conf["asmk_name"] # noqa: F841 + + # Download the model. 
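+        # (Fetched once with wget via subprocess, so wget must be on PATH;
+        # the checkpoint is then cached under third_party/fire/model/.)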
+ if not model_path.exists(): + model_path.parent.mkdir(exist_ok=True) + link = self.fire_models[conf["model_name"]] + cmd = ["wget", "--quiet", link, "-O", str(model_path)] + logger.info(f"Downloading the FIRe model with `{cmd}`.") + subprocess.run(cmd, check=True) + + logger.info("Loading fire model...") + + # Load net + state = torch.load(model_path) + state["net_params"]["pretrained"] = None + net = fire_network.init_network(**state["net_params"]) + net.load_state_dict(state["state_dict"]) + self.net = net + + self.norm_rgb = tvf.Normalize( + **dict(zip(["mean", "std"], net.runtime["mean_std"])) + ) + + # params + self.scales = conf["scales"] + self.features_num = conf["features_num"] + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + + local_desc = self.net.forward_local( + image, features_num=self.features_num, scales=self.scales + ) + + logger.info(f"output[0].shape = {local_desc[0].shape}\n") + + return { + # 'global_descriptor': desc + "local_descriptor": local_desc + } diff --git a/imcui/hloc/extractors/lanet.py b/imcui/hloc/extractors/lanet.py new file mode 100644 index 0000000000000000000000000000000000000000..c5f7af8692d9c216bd613fe2cf488e3c148392fa --- /dev/null +++ b/imcui/hloc/extractors/lanet.py @@ -0,0 +1,63 @@ +import sys +from pathlib import Path + +import torch + +from .. import MODEL_REPO_ID, logger + +from ..utils.base_model import BaseModel + +lib_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(lib_path)) +from lanet.network_v0.model import PointModel + +lanet_path = Path(__file__).parent / "../../third_party/lanet" + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class LANet(BaseModel): + default_conf = { + "model_name": "PointModel_v0.pth", + "keypoint_threshold": 0.1, + "max_keypoints": 1024, + } + required_inputs = ["image"] + + def _init(self, conf): + logger.info("Loading LANet model...") + + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + self.net = PointModel(is_test=True) + state_dict = torch.load(model_path, map_location="cpu") + self.net.load_state_dict(state_dict["model_state"]) + logger.info("Load LANet model done.") + + def _forward(self, data): + image = data["image"] + keypoints, scores, descriptors = self.net(image) + _, _, Hc, Wc = descriptors.shape + + # Scores & Descriptors + kpts_score = torch.cat([keypoints, scores], dim=1).view(3, -1).t() + descriptors = descriptors.view(256, Hc, Wc).view(256, -1).t() + + # Filter based on confidence threshold + descriptors = descriptors[kpts_score[:, 0] > self.conf["keypoint_threshold"], :] + kpts_score = kpts_score[kpts_score[:, 0] > self.conf["keypoint_threshold"], :] + keypoints = kpts_score[:, 1:] + scores = kpts_score[:, 0] + + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + keypoints = keypoints[idxs, :2] + descriptors = descriptors[idxs] + scores = scores[idxs] + + return { + "keypoints": keypoints[None], + "scores": scores[None], + "descriptors": descriptors.T[None], + } diff --git a/imcui/hloc/extractors/netvlad.py b/imcui/hloc/extractors/netvlad.py new file mode 100644 index 0000000000000000000000000000000000000000..3ba5f9a2feebf7ed0accd23e318b2a83e0f9df12 --- /dev/null +++ b/imcui/hloc/extractors/netvlad.py @@ -0,0 +1,146 @@ +import subprocess +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +from scipy.io import 
loadmat + +from .. import logger +from ..utils.base_model import BaseModel + +EPS = 1e-6 + + +class NetVLADLayer(nn.Module): + def __init__(self, input_dim=512, K=64, score_bias=False, intranorm=True): + super().__init__() + self.score_proj = nn.Conv1d(input_dim, K, kernel_size=1, bias=score_bias) + centers = nn.parameter.Parameter(torch.empty([input_dim, K])) + nn.init.xavier_uniform_(centers) + self.register_parameter("centers", centers) + self.intranorm = intranorm + self.output_dim = input_dim * K + + def forward(self, x): + b = x.size(0) + scores = self.score_proj(x) + scores = F.softmax(scores, dim=1) + diff = x.unsqueeze(2) - self.centers.unsqueeze(0).unsqueeze(-1) + desc = (scores.unsqueeze(1) * diff).sum(dim=-1) + if self.intranorm: + # From the official MATLAB implementation. + desc = F.normalize(desc, dim=1) + desc = desc.view(b, -1) + desc = F.normalize(desc, dim=1) + return desc + + +class NetVLAD(BaseModel): + default_conf = {"model_name": "VGG16-NetVLAD-Pitts30K", "whiten": True} + required_inputs = ["image"] + + # Models exported using + # https://github.com/uzh-rpg/netvlad_tf_open/blob/master/matlab/net_class2struct.m. + dir_models = { + "VGG16-NetVLAD-Pitts30K": "https://cvg-data.inf.ethz.ch/hloc/netvlad/Pitts30K_struct.mat", + "VGG16-NetVLAD-TokyoTM": "https://cvg-data.inf.ethz.ch/hloc/netvlad/TokyoTM_struct.mat", + } + + def _init(self, conf): + assert conf["model_name"] in self.dir_models.keys() + + # Download the checkpoint. + checkpoint = Path(torch.hub.get_dir(), "netvlad", conf["model_name"] + ".mat") + if not checkpoint.exists(): + checkpoint.parent.mkdir(exist_ok=True, parents=True) + link = self.dir_models[conf["model_name"]] + cmd = ["wget", "--quiet", link, "-O", str(checkpoint)] + logger.info(f"Downloading the NetVLAD model with `{cmd}`.") + subprocess.run(cmd, check=True) + + # Create the network. + # Remove classification head. + backbone = list(models.vgg16().children())[0] + # Remove last ReLU + MaxPool2d. + self.backbone = nn.Sequential(*list(backbone.children())[:-2]) + + self.netvlad = NetVLADLayer() + + if conf["whiten"]: + self.whiten = nn.Linear(self.netvlad.output_dim, 4096) + + # Parse MATLAB weights using https://github.com/uzh-rpg/netvlad_tf_open + mat = loadmat(checkpoint, struct_as_record=False, squeeze_me=True) + + # CNN weights. + for layer, mat_layer in zip(self.backbone.children(), mat["net"].layers): + if isinstance(layer, nn.Conv2d): + w = mat_layer.weights[0] # Shape: S x S x IN x OUT + b = mat_layer.weights[1] # Shape: OUT + # Prepare for PyTorch - enforce float32 and right shape. + # w should have shape: OUT x IN x S x S + # b should have shape: OUT + w = torch.tensor(w).float().permute([3, 2, 0, 1]) + b = torch.tensor(b).float() + # Update layer weights. + layer.weight = nn.Parameter(w) + layer.bias = nn.Parameter(b) + + # NetVLAD weights. + score_w = mat["net"].layers[30].weights[0] # D x K + # centers are stored as opposite in official MATLAB code + center_w = -mat["net"].layers[30].weights[1] # D x K + # Prepare for PyTorch - make sure it is float32 and has right shape. + # score_w should have shape K x D x 1 + # center_w should have shape D x K + score_w = torch.tensor(score_w).float().permute([1, 0]).unsqueeze(-1) + center_w = torch.tensor(center_w).float() + # Update layer weights. + self.netvlad.score_proj.weight = nn.Parameter(score_w) + self.netvlad.centers = nn.Parameter(center_w) + + # Whitening weights. 
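+        # The whitening step is a learned linear projection of the NetVLAD
+        # descriptor down to 4096-D; its weights come from layer 33 of the
+        # exported MATLAB struct and are reshaped for nn.Linear below.
+        # Resulting pipeline (shapes only):
+        #   VGG16 features (B x 512 x H*W) -> NetVLADLayer -> B x 32768
+        #   -> whiten (Linear 32768 -> 4096) -> final L2 normalization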
+ if conf["whiten"]: + w = mat["net"].layers[33].weights[0] # Shape: 1 x 1 x IN x OUT + b = mat["net"].layers[33].weights[1] # Shape: OUT + # Prepare for PyTorch - make sure it is float32 and has right shape + w = torch.tensor(w).float().squeeze().permute([1, 0]) # OUT x IN + b = torch.tensor(b.squeeze()).float() # Shape: OUT + # Update layer weights. + self.whiten.weight = nn.Parameter(w) + self.whiten.bias = nn.Parameter(b) + + # Preprocessing parameters. + self.preprocess = { + "mean": mat["net"].meta.normalization.averageImage[0, 0], + "std": np.array([1, 1, 1], dtype=np.float32), + } + + def _forward(self, data): + image = data["image"] + assert image.shape[1] == 3 + assert image.min() >= -EPS and image.max() <= 1 + EPS + image = torch.clamp(image * 255, 0.0, 255.0) # Input should be 0-255. + mean = self.preprocess["mean"] + std = self.preprocess["std"] + image = image - image.new_tensor(mean).view(1, -1, 1, 1) + image = image / image.new_tensor(std).view(1, -1, 1, 1) + + # Feature extraction. + descriptors = self.backbone(image) + b, c, _, _ = descriptors.size() + descriptors = descriptors.view(b, c, -1) + + # NetVLAD layer. + descriptors = F.normalize(descriptors, dim=1) # Pre-normalization. + desc = self.netvlad(descriptors) + + # Whiten if needed. + if hasattr(self, "whiten"): + desc = self.whiten(desc) + desc = F.normalize(desc, dim=1) # Final L2 normalization. + + return {"global_descriptor": desc} diff --git a/imcui/hloc/extractors/openibl.py b/imcui/hloc/extractors/openibl.py new file mode 100644 index 0000000000000000000000000000000000000000..9e332a4e0016fceb184dd850bd3b6f86231dad54 --- /dev/null +++ b/imcui/hloc/extractors/openibl.py @@ -0,0 +1,26 @@ +import torch +import torchvision.transforms as tvf + +from ..utils.base_model import BaseModel + + +class OpenIBL(BaseModel): + default_conf = { + "model_name": "vgg16_netvlad", + } + required_inputs = ["image"] + + def _init(self, conf): + self.net = torch.hub.load( + "yxgeee/OpenIBL", conf["model_name"], pretrained=True + ).eval() + mean = [0.48501960784313836, 0.4579568627450961, 0.4076039215686255] + std = [0.00392156862745098, 0.00392156862745098, 0.00392156862745098] + self.norm_rgb = tvf.Normalize(mean=mean, std=std) + + def _forward(self, data): + image = self.norm_rgb(data["image"]) + desc = self.net(image) + return { + "global_descriptor": desc, + } diff --git a/imcui/hloc/extractors/r2d2.py b/imcui/hloc/extractors/r2d2.py new file mode 100644 index 0000000000000000000000000000000000000000..66769b040dd10e1b4f38eca0cf41c2023d096482 --- /dev/null +++ b/imcui/hloc/extractors/r2d2.py @@ -0,0 +1,73 @@ +import sys +from pathlib import Path + +import torchvision.transforms as tvf + +from .. 
import MODEL_REPO_ID, logger + +from ..utils.base_model import BaseModel + +r2d2_path = Path(__file__).parents[2] / "third_party/r2d2" +sys.path.append(str(r2d2_path)) + +gim_path = Path(__file__).parents[2] / "third_party/gim" +if str(gim_path) in sys.path: + sys.path.remove(str(gim_path)) + +from extract import NonMaxSuppression, extract_multiscale, load_network + + +class R2D2(BaseModel): + default_conf = { + "model_name": "r2d2_WASF_N16.pt", + "max_keypoints": 5000, + "scale_factor": 2**0.25, + "min_size": 256, + "max_size": 1024, + "min_scale": 0, + "max_scale": 1, + "reliability_threshold": 0.7, + "repetability_threshold": 0.7, + } + required_inputs = ["image"] + + def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + self.norm_rgb = tvf.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + self.net = load_network(model_path) + self.detector = NonMaxSuppression( + rel_thr=conf["reliability_threshold"], + rep_thr=conf["repetability_threshold"], + ) + logger.info("Load R2D2 model done.") + + def _forward(self, data): + img = data["image"] + img = self.norm_rgb(img) + + xys, desc, scores = extract_multiscale( + self.net, + img, + self.detector, + scale_f=self.conf["scale_factor"], + min_size=self.conf["min_size"], + max_size=self.conf["max_size"], + min_scale=self.conf["min_scale"], + max_scale=self.conf["max_scale"], + ) + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + xy = xys[idxs, :2] + desc = desc[idxs].t() + scores = scores[idxs] + + pred = { + "keypoints": xy[None], + "descriptors": desc[None], + "scores": scores[None], + } + return pred diff --git a/imcui/hloc/extractors/rekd.py b/imcui/hloc/extractors/rekd.py new file mode 100644 index 0000000000000000000000000000000000000000..82fc522920e21e171cb269e680506ad7aeeeaf9a --- /dev/null +++ b/imcui/hloc/extractors/rekd.py @@ -0,0 +1,60 @@ +import sys +from pathlib import Path + +import torch + +from .. 
import MODEL_REPO_ID, logger + +from ..utils.base_model import BaseModel + +rekd_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(rekd_path)) +from REKD.training.model.REKD import REKD as REKD_ + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class REKD(BaseModel): + default_conf = { + "model_name": "v0", + "keypoint_threshold": 0.1, + } + required_inputs = ["image"] + + def _init(self, conf): + # TODO: download model + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + if not model_path.exists(): + print(f"No model found at {model_path}") + self.net = REKD_(is_test=True) + state_dict = torch.load(model_path, map_location="cpu") + self.net.load_state_dict(state_dict["model_state"]) + logger.info("Load REKD model done.") + + def _forward(self, data): + image = data["image"] + keypoints, scores, descriptors = self.net(image) + _, _, Hc, Wc = descriptors.shape + + # Scores & Descriptors + kpts_score = ( + torch.cat([keypoints, scores], dim=1).view(3, -1).t().cpu().detach().numpy() + ) + descriptors = ( + descriptors.view(256, Hc, Wc).view(256, -1).t().cpu().detach().numpy() + ) + + # Filter based on confidence threshold + descriptors = descriptors[kpts_score[:, 0] > self.conf["keypoint_threshold"], :] + kpts_score = kpts_score[kpts_score[:, 0] > self.conf["keypoint_threshold"], :] + keypoints = kpts_score[:, 1:] + scores = kpts_score[:, 0] + + return { + "keypoints": torch.from_numpy(keypoints)[None], + "scores": torch.from_numpy(scores)[None], + "descriptors": torch.from_numpy(descriptors.T)[None], + } diff --git a/imcui/hloc/extractors/rord.py b/imcui/hloc/extractors/rord.py new file mode 100644 index 0000000000000000000000000000000000000000..ba71113e4f9a57609879c95bb453af4104dbb72d --- /dev/null +++ b/imcui/hloc/extractors/rord.py @@ -0,0 +1,59 @@ +import sys +from pathlib import Path + +import torch + +from .. 
import MODEL_REPO_ID, logger + +from ..utils.base_model import BaseModel + +rord_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(rord_path)) +from RoRD.lib.model_test import D2Net as _RoRD +from RoRD.lib.pyramid import process_multiscale + + +class RoRD(BaseModel): + default_conf = { + "model_name": "rord.pth", + "checkpoint_dir": rord_path / "RoRD" / "models", + "use_relu": True, + "multiscale": False, + "max_keypoints": 1024, + } + required_inputs = ["image"] + + def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + self.net = _RoRD( + model_file=model_path, use_relu=conf["use_relu"], use_cuda=False + ) + logger.info("Load RoRD model done.") + + def _forward(self, data): + image = data["image"] + image = image.flip(1) # RGB -> BGR + norm = image.new_tensor([103.939, 116.779, 123.68]) + image = image * 255 - norm.view(1, 3, 1, 1) # caffe normalization + + if self.conf["multiscale"]: + keypoints, scores, descriptors = process_multiscale(image, self.net) + else: + keypoints, scores, descriptors = process_multiscale( + image, self.net, scales=[1] + ) + keypoints = keypoints[:, [1, 0]] # (x, y) and remove the scale + + idxs = scores.argsort()[-self.conf["max_keypoints"] or None :] + keypoints = keypoints[idxs, :2] + descriptors = descriptors[idxs] + scores = scores[idxs] + + return { + "keypoints": torch.from_numpy(keypoints)[None], + "scores": torch.from_numpy(scores)[None], + "descriptors": torch.from_numpy(descriptors.T)[None], + } diff --git a/imcui/hloc/extractors/sfd2.py b/imcui/hloc/extractors/sfd2.py new file mode 100644 index 0000000000000000000000000000000000000000..2724ed3046644da4903c54a7f0d9d8d8f585aafe --- /dev/null +++ b/imcui/hloc/extractors/sfd2.py @@ -0,0 +1,44 @@ +import sys +from pathlib import Path + +import torchvision.transforms as tvf + +from .. import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.sfd2 import load_sfd2 + + +class SFD2(BaseModel): + default_conf = { + "max_keypoints": 4096, + "model_name": "sfd2_20230511_210205_resnet4x.79.pth", + "conf_th": 0.001, + } + required_inputs = ["image"] + + def _init(self, conf): + self.conf = {**self.default_conf, **conf} + self.norm_rgb = tvf.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format("pram", self.conf["model_name"]), + ) + self.net = load_sfd2(weight_path=model_path).eval() + + logger.info("Load SFD2 model done.") + + def _forward(self, data): + pred = self.net.extract_local_global( + data={"image": self.norm_rgb(data["image"])}, config=self.conf + ) + out = { + "keypoints": pred["keypoints"][0][None], + "scores": pred["scores"][0][None], + "descriptors": pred["descriptors"][0][None], + } + return out diff --git a/imcui/hloc/extractors/sift.py b/imcui/hloc/extractors/sift.py new file mode 100644 index 0000000000000000000000000000000000000000..05df8a76f18b7eae32ef52cbfc91fb13d37c2a9f --- /dev/null +++ b/imcui/hloc/extractors/sift.py @@ -0,0 +1,216 @@ +import warnings + +import cv2 +import numpy as np +import torch +from kornia.color import rgb_to_grayscale +from omegaconf import OmegaConf +from packaging import version + +try: + import pycolmap +except ImportError: + pycolmap = None +from .. 
import logger + +from ..utils.base_model import BaseModel + + +def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None): + h, w = image_shape + ij = np.round(points - 0.5).astype(int).T[::-1] + + # Remove duplicate points (identical coordinates). + # Pick highest scale or score + s = scales if scores is None else scores + buffer = np.zeros((h, w)) + np.maximum.at(buffer, tuple(ij), s) + keep = np.where(buffer[tuple(ij)] == s)[0] + + # Pick lowest angle (arbitrary). + ij = ij[:, keep] + buffer[:] = np.inf + o_abs = np.abs(angles[keep]) + np.minimum.at(buffer, tuple(ij), o_abs) + mask = buffer[tuple(ij)] == o_abs + ij = ij[:, mask] + keep = keep[mask] + + if nms_radius > 0: + # Apply NMS on the remaining points + buffer[:] = 0 + buffer[tuple(ij)] = s[keep] # scores or scale + + local_max = torch.nn.functional.max_pool2d( + torch.from_numpy(buffer).unsqueeze(0), + kernel_size=nms_radius * 2 + 1, + stride=1, + padding=nms_radius, + ).squeeze(0) + is_local_max = buffer == local_max.numpy() + keep = keep[is_local_max[tuple(ij)]] + return keep + + +def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor: + x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps) + x.clip_(min=eps).sqrt_() + return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps) + + +def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray: + """ + Detect keypoints using OpenCV Detector. + Optionally, perform description. + Args: + features: OpenCV based keypoints detector and descriptor + image: Grayscale image of uint8 data type + Returns: + keypoints: 1D array of detected cv2.KeyPoint + scores: 1D array of responses + descriptors: 1D array of descriptors + """ + detections, descriptors = features.detectAndCompute(image, None) + points = np.array([k.pt for k in detections], dtype=np.float32) + scores = np.array([k.response for k in detections], dtype=np.float32) + scales = np.array([k.size for k in detections], dtype=np.float32) + angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32)) + return points, scores, scales, angles, descriptors + + +class SIFT(BaseModel): + default_conf = { + "rootsift": True, + "nms_radius": 0, # None to disable filtering entirely. + "max_keypoints": 4096, + "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda} + "detection_threshold": 0.0066667, # from COLMAP + "edge_threshold": 10, + "first_octave": -1, # only used by pycolmap, the default of COLMAP + "num_octaves": 4, + } + + required_data_keys = ["image"] + + def _init(self, conf): + self.conf = OmegaConf.create(self.conf) + backend = self.conf.backend + if backend.startswith("pycolmap"): + if pycolmap is None: + raise ImportError( + "Cannot find module pycolmap: install it with pip" + "or use backend=opencv." + ) + options = { + "peak_threshold": self.conf.detection_threshold, + "edge_threshold": self.conf.edge_threshold, + "first_octave": self.conf.first_octave, + "num_octaves": self.conf.num_octaves, + "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy. 
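+                # RootSIFT is therefore applied later in Python instead: when
+                # `conf.rootsift` is set, `_forward` passes the descriptors
+                # through `sift_to_rootsift` (L1-normalize, sqrt, L2-normalize).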
+ } + device = ( + "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "") + ) + if ( + backend == "pycolmap_cpu" or not pycolmap.has_cuda + ) and pycolmap.__version__ < "0.5.0": + warnings.warn( + "The pycolmap CPU SIFT is buggy in version < 0.5.0, " + "consider upgrading pycolmap or use the CUDA version.", + stacklevel=1, + ) + else: + options["max_num_features"] = self.conf.max_keypoints + self.sift = pycolmap.Sift(options=options, device=device) + elif backend == "opencv": + self.sift = cv2.SIFT_create( + contrastThreshold=self.conf.detection_threshold, + nfeatures=self.conf.max_keypoints, + edgeThreshold=self.conf.edge_threshold, + nOctaveLayers=self.conf.num_octaves, + ) + else: + backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"} + raise ValueError( + f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}." + ) + logger.info("Load SIFT model done.") + + def extract_single_image(self, image: torch.Tensor): + image_np = image.cpu().numpy().squeeze(0) + + if self.conf.backend.startswith("pycolmap"): + if version.parse(pycolmap.__version__) >= version.parse("0.5.0"): + detections, descriptors = self.sift.extract(image_np) + scores = None # Scores are not exposed by COLMAP anymore. + else: + detections, scores, descriptors = self.sift.extract(image_np) + keypoints = detections[:, :2] # Keep only (x, y). + scales, angles = detections[:, -2:].T + if scores is not None and ( + self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda + ): + # Set the scores as a combination of abs. response and scale. + scores = np.abs(scores) * scales + elif self.conf.backend == "opencv": + # TODO: Check if opencv keypoints are already in corner convention + keypoints, scores, scales, angles, descriptors = run_opencv_sift( + self.sift, (image_np * 255.0).astype(np.uint8) + ) + pred = { + "keypoints": keypoints, + "scales": scales, + "oris": angles, + "descriptors": descriptors, + } + if scores is not None: + pred["scores"] = scores + + # sometimes pycolmap returns points outside the image. 
We remove them + if self.conf.backend.startswith("pycolmap"): + is_inside = ( + pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]]) + ).all(-1) + pred = {k: v[is_inside] for k, v in pred.items()} + + if self.conf.nms_radius is not None: + keep = filter_dog_point( + pred["keypoints"], + pred["scales"], + pred["oris"], + image_np.shape, + self.conf.nms_radius, + scores=pred.get("scores"), + ) + pred = {k: v[keep] for k, v in pred.items()} + + pred = {k: torch.from_numpy(v) for k, v in pred.items()} + if scores is not None: + # Keep the k keypoints with highest score + num_points = self.conf.max_keypoints + if num_points is not None and len(pred["keypoints"]) > num_points: + indices = torch.topk(pred["scores"], num_points).indices + pred = {k: v[indices] for k, v in pred.items()} + return pred + + def _forward(self, data: dict) -> dict: + image = data["image"] + if image.shape[1] == 3: + image = rgb_to_grayscale(image) + device = image.device + image = image.cpu() + pred = [] + for k in range(len(image)): + img = image[k] + if "image_size" in data.keys(): + # avoid extracting points in padded areas + w, h = data["image_size"][k] + img = img[:, :h, :w] + p = self.extract_single_image(img) + pred.append(p) + pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]} + if self.conf.rootsift: + pred["descriptors"] = sift_to_rootsift(pred["descriptors"]) + pred["descriptors"] = pred["descriptors"].permute(0, 2, 1) + pred["keypoint_scores"] = pred["scores"].clone() + return pred diff --git a/imcui/hloc/extractors/superpoint.py b/imcui/hloc/extractors/superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4c03e314743be2b862f3b8d8df078d2f85bc39 --- /dev/null +++ b/imcui/hloc/extractors/superpoint.py @@ -0,0 +1,51 @@ +import sys +from pathlib import Path + +import torch + +from .. import logger + +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from SuperGluePretrainedNetwork.models import superpoint # noqa E402 + + +# The original keypoint sampling is incorrect. We patch it here but +# we don't fix it upstream to not impact exisiting evaluations. +def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8): + """Interpolate descriptors at keypoint locations""" + b, c, h, w = descriptors.shape + keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s) + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + descriptors = torch.nn.functional.grid_sample( + descriptors, + keypoints.view(b, 1, -1, 2), + mode="bilinear", + align_corners=False, + ) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1 + ) + return descriptors + + +class SuperPoint(BaseModel): + default_conf = { + "nms_radius": 4, + "keypoint_threshold": 0.005, + "max_keypoints": -1, + "remove_borders": 4, + "fix_sampling": False, + } + required_inputs = ["image"] + detection_noise = 2.0 + + def _init(self, conf): + if conf["fix_sampling"]: + superpoint.sample_descriptors = sample_descriptors_fix_sampling + self.net = superpoint.SuperPoint(conf) + logger.info("Load SuperPoint model done.") + + def _forward(self, data): + return self.net(data, self.conf) diff --git a/imcui/hloc/extractors/xfeat.py b/imcui/hloc/extractors/xfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..f29e115dca54db10bc2b58369eb1ff28dc6e3b2c --- /dev/null +++ b/imcui/hloc/extractors/xfeat.py @@ -0,0 +1,33 @@ +import torch + +from .. 
import logger + +from ..utils.base_model import BaseModel + + +class XFeat(BaseModel): + default_conf = { + "keypoint_threshold": 0.005, + "max_keypoints": -1, + } + required_inputs = ["image"] + + def _init(self, conf): + self.net = torch.hub.load( + "verlab/accelerated_features", + "XFeat", + pretrained=True, + top_k=self.conf["max_keypoints"], + ) + logger.info("Load XFeat(sparse) model done.") + + def _forward(self, data): + pred = self.net.detectAndCompute( + data["image"], top_k=self.conf["max_keypoints"] + )[0] + pred = { + "keypoints": pred["keypoints"][None], + "scores": pred["scores"][None], + "descriptors": pred["descriptors"].T[None], + } + return pred diff --git a/imcui/hloc/localize_inloc.py b/imcui/hloc/localize_inloc.py new file mode 100644 index 0000000000000000000000000000000000000000..acda7520012c53f468b1603d6a26a34855ebbffb --- /dev/null +++ b/imcui/hloc/localize_inloc.py @@ -0,0 +1,179 @@ +import argparse +import pickle +from pathlib import Path + +import cv2 +import h5py +import numpy as np +import pycolmap +import torch +from scipy.io import loadmat +from tqdm import tqdm + +from . import logger +from .utils.parsers import names_to_pair, parse_retrieval + + +def interpolate_scan(scan, kp): + h, w, c = scan.shape + kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1 + assert np.all(kp > -1) and np.all(kp < 1) + scan = torch.from_numpy(scan).permute(2, 0, 1)[None] + kp = torch.from_numpy(kp)[None, None] + grid_sample = torch.nn.functional.grid_sample + + # To maximize the number of points that have depth: + # do bilinear interpolation first and then nearest for the remaining points + interp_lin = grid_sample(scan, kp, align_corners=True, mode="bilinear")[0, :, 0] + interp_nn = torch.nn.functional.grid_sample( + scan, kp, align_corners=True, mode="nearest" + )[0, :, 0] + interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin) + valid = ~torch.any(torch.isnan(interp), 0) + + kp3d = interp.T.numpy() + valid = valid.numpy() + return kp3d, valid + + +def get_scan_pose(dataset_dir, rpath): + split_image_rpath = rpath.split("/") + floor_name = split_image_rpath[-3] + scan_id = split_image_rpath[-2] + image_name = split_image_rpath[-1] + building_name = image_name[:3] + + path = Path( + dataset_dir, + "database/alignments", + floor_name, + f"transformations/{building_name}_trans_{scan_id}.txt", + ) + with open(path) as f: + raw_lines = f.readlines() + + P_after_GICP = np.array( + [ + np.fromstring(raw_lines[7], sep=" "), + np.fromstring(raw_lines[8], sep=" "), + np.fromstring(raw_lines[9], sep=" "), + np.fromstring(raw_lines[10], sep=" "), + ] + ) + + return P_after_GICP + + +def pose_from_cluster(dataset_dir, q, retrieved, feature_file, match_file, skip=None): + height, width = cv2.imread(str(dataset_dir / q)).shape[:2] + cx = 0.5 * width + cy = 0.5 * height + focal_length = 4032.0 * 28.0 / 36.0 + + all_mkpq = [] + all_mkpr = [] + all_mkp3d = [] + all_indices = [] + kpq = feature_file[q]["keypoints"].__array__() + num_matches = 0 + + for i, r in enumerate(retrieved): + kpr = feature_file[r]["keypoints"].__array__() + pair = names_to_pair(q, r) + m = match_file[pair]["matches0"].__array__() + v = m > -1 + + if skip and (np.count_nonzero(v) < skip): + continue + + mkpq, mkpr = kpq[v], kpr[m[v]] + num_matches += len(mkpq) + + scan_r = loadmat(Path(dataset_dir, r + ".mat"))["XYZcut"] + mkp3d, valid = interpolate_scan(scan_r, mkpr) + Tr = get_scan_pose(dataset_dir, r) + mkp3d = (Tr[:3, :3] @ mkp3d.T + Tr[:3, -1:]).T + + all_mkpq.append(mkpq[valid]) + 
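+        # (`valid` masks out matches whose interpolated scan point contained
+        # NaNs, so only keypoints with usable depth are accumulated.)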
all_mkpr.append(mkpr[valid]) + all_mkp3d.append(mkp3d[valid]) + all_indices.append(np.full(np.count_nonzero(valid), i)) + + all_mkpq = np.concatenate(all_mkpq, 0) + all_mkpr = np.concatenate(all_mkpr, 0) + all_mkp3d = np.concatenate(all_mkp3d, 0) + all_indices = np.concatenate(all_indices, 0) + + cfg = { + "model": "SIMPLE_PINHOLE", + "width": width, + "height": height, + "params": [focal_length, cx, cy], + } + ret = pycolmap.absolute_pose_estimation(all_mkpq, all_mkp3d, cfg, 48.00) + ret["cfg"] = cfg + return ret, all_mkpq, all_mkpr, all_mkp3d, all_indices, num_matches + + +def main(dataset_dir, retrieval, features, matches, results, skip_matches=None): + assert retrieval.exists(), retrieval + assert features.exists(), features + assert matches.exists(), matches + + retrieval_dict = parse_retrieval(retrieval) + queries = list(retrieval_dict.keys()) + + feature_file = h5py.File(features, "r", libver="latest") + match_file = h5py.File(matches, "r", libver="latest") + + poses = {} + logs = { + "features": features, + "matches": matches, + "retrieval": retrieval, + "loc": {}, + } + logger.info("Starting localization...") + for q in tqdm(queries): + db = retrieval_dict[q] + ret, mkpq, mkpr, mkp3d, indices, num_matches = pose_from_cluster( + dataset_dir, q, db, feature_file, match_file, skip_matches + ) + + poses[q] = (ret["qvec"], ret["tvec"]) + logs["loc"][q] = { + "db": db, + "PnP_ret": ret, + "keypoints_query": mkpq, + "keypoints_db": mkpr, + "3d_points": mkp3d, + "indices_db": indices, + "num_matches": num_matches, + } + + logger.info(f"Writing poses to {results}...") + with open(results, "w") as f: + for q in queries: + qvec, tvec = poses[q] + qvec = " ".join(map(str, qvec)) + tvec = " ".join(map(str, tvec)) + name = q.split("/")[-1] + f.write(f"{name} {qvec} {tvec}\n") + + logs_path = f"{results}_logs.pkl" + logger.info(f"Writing logs to {logs_path}...") + with open(logs_path, "wb") as f: + pickle.dump(logs, f) + logger.info("Done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_dir", type=Path, required=True) + parser.add_argument("--retrieval", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + parser.add_argument("--results", type=Path, required=True) + parser.add_argument("--skip_matches", type=int) + args = parser.parse_args() + main(**args.__dict__) diff --git a/imcui/hloc/localize_sfm.py b/imcui/hloc/localize_sfm.py new file mode 100644 index 0000000000000000000000000000000000000000..8122e2aca1aa057c424f0e39204c193f01cd57e7 --- /dev/null +++ b/imcui/hloc/localize_sfm.py @@ -0,0 +1,243 @@ +import argparse +import pickle +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Union + +import numpy as np +import pycolmap +from tqdm import tqdm + +from . 
import logger +from .utils.io import get_keypoints, get_matches +from .utils.parsers import parse_image_lists, parse_retrieval + + +def do_covisibility_clustering( + frame_ids: List[int], reconstruction: pycolmap.Reconstruction +): + clusters = [] + visited = set() + for frame_id in frame_ids: + # Check if already labeled + if frame_id in visited: + continue + + # New component + clusters.append([]) + queue = {frame_id} + while len(queue): + exploration_frame = queue.pop() + + # Already part of the component + if exploration_frame in visited: + continue + visited.add(exploration_frame) + clusters[-1].append(exploration_frame) + + observed = reconstruction.images[exploration_frame].points2D + connected_frames = { + obs.image_id + for p2D in observed + if p2D.has_point3D() + for obs in reconstruction.points3D[p2D.point3D_id].track.elements + } + connected_frames &= set(frame_ids) + connected_frames -= visited + queue |= connected_frames + + clusters = sorted(clusters, key=len, reverse=True) + return clusters + + +class QueryLocalizer: + def __init__(self, reconstruction, config=None): + self.reconstruction = reconstruction + self.config = config or {} + + def localize(self, points2D_all, points2D_idxs, points3D_id, query_camera): + points2D = points2D_all[points2D_idxs] + points3D = [self.reconstruction.points3D[j].xyz for j in points3D_id] + ret = pycolmap.absolute_pose_estimation( + points2D, + points3D, + query_camera, + estimation_options=self.config.get("estimation", {}), + refinement_options=self.config.get("refinement", {}), + ) + return ret + + +def pose_from_cluster( + localizer: QueryLocalizer, + qname: str, + query_camera: pycolmap.Camera, + db_ids: List[int], + features_path: Path, + matches_path: Path, + **kwargs, +): + kpq = get_keypoints(features_path, qname) + kpq += 0.5 # COLMAP coordinates + + kp_idx_to_3D = defaultdict(list) + kp_idx_to_3D_to_db = defaultdict(lambda: defaultdict(list)) + num_matches = 0 + for i, db_id in enumerate(db_ids): + image = localizer.reconstruction.images[db_id] + if image.num_points3D == 0: + logger.debug(f"No 3D points found for {image.name}.") + continue + points3D_ids = np.array( + [p.point3D_id if p.has_point3D() else -1 for p in image.points2D] + ) + + matches, _ = get_matches(matches_path, qname, image.name) + matches = matches[points3D_ids[matches[:, 1]] != -1] + num_matches += len(matches) + for idx, m in matches: + id_3D = points3D_ids[m] + kp_idx_to_3D_to_db[idx][id_3D].append(i) + # avoid duplicate observations + if id_3D not in kp_idx_to_3D[idx]: + kp_idx_to_3D[idx].append(id_3D) + + idxs = list(kp_idx_to_3D.keys()) + mkp_idxs = [i for i in idxs for _ in kp_idx_to_3D[i]] + mp3d_ids = [j for i in idxs for j in kp_idx_to_3D[i]] + ret = localizer.localize(kpq, mkp_idxs, mp3d_ids, query_camera, **kwargs) + if ret is not None: + ret["camera"] = query_camera + + # mostly for logging and post-processing + mkp_to_3D_to_db = [ + (j, kp_idx_to_3D_to_db[i][j]) for i in idxs for j in kp_idx_to_3D[i] + ] + log = { + "db": db_ids, + "PnP_ret": ret, + "keypoints_query": kpq[mkp_idxs], + "points3D_ids": mp3d_ids, + "points3D_xyz": None, # we don't log xyz anymore because of file size + "num_matches": num_matches, + "keypoint_index_to_db": (mkp_idxs, mkp_to_3D_to_db), + } + return ret, log + + +def main( + reference_sfm: Union[Path, pycolmap.Reconstruction], + queries: Path, + retrieval: Path, + features: Path, + matches: Path, + results: Path, + ransac_thresh: int = 12, + covisibility_clustering: bool = False, + prepend_camera_name: bool = False, + 
config: Dict = None, +): + assert retrieval.exists(), retrieval + assert features.exists(), features + assert matches.exists(), matches + + queries = parse_image_lists(queries, with_intrinsics=True) + retrieval_dict = parse_retrieval(retrieval) + + logger.info("Reading the 3D model...") + if not isinstance(reference_sfm, pycolmap.Reconstruction): + reference_sfm = pycolmap.Reconstruction(reference_sfm) + db_name_to_id = {img.name: i for i, img in reference_sfm.images.items()} + + config = { + "estimation": {"ransac": {"max_error": ransac_thresh}}, + **(config or {}), + } + localizer = QueryLocalizer(reference_sfm, config) + + cam_from_world = {} + logs = { + "features": features, + "matches": matches, + "retrieval": retrieval, + "loc": {}, + } + logger.info("Starting localization...") + for qname, qcam in tqdm(queries): + if qname not in retrieval_dict: + logger.warning(f"No images retrieved for query image {qname}. Skipping...") + continue + db_names = retrieval_dict[qname] + db_ids = [] + for n in db_names: + if n not in db_name_to_id: + logger.warning(f"Image {n} was retrieved but not in database") + continue + db_ids.append(db_name_to_id[n]) + + if covisibility_clustering: + clusters = do_covisibility_clustering(db_ids, reference_sfm) + best_inliers = 0 + best_cluster = None + logs_clusters = [] + for i, cluster_ids in enumerate(clusters): + ret, log = pose_from_cluster( + localizer, qname, qcam, cluster_ids, features, matches + ) + if ret is not None and ret["num_inliers"] > best_inliers: + best_cluster = i + best_inliers = ret["num_inliers"] + logs_clusters.append(log) + if best_cluster is not None: + ret = logs_clusters[best_cluster]["PnP_ret"] + cam_from_world[qname] = ret["cam_from_world"] + logs["loc"][qname] = { + "db": db_ids, + "best_cluster": best_cluster, + "log_clusters": logs_clusters, + "covisibility_clustering": covisibility_clustering, + } + else: + ret, log = pose_from_cluster( + localizer, qname, qcam, db_ids, features, matches + ) + if ret is not None: + cam_from_world[qname] = ret["cam_from_world"] + else: + closest = reference_sfm.images[db_ids[0]] + cam_from_world[qname] = closest.cam_from_world + log["covisibility_clustering"] = covisibility_clustering + logs["loc"][qname] = log + + logger.info(f"Localized {len(cam_from_world)} / {len(queries)} images.") + logger.info(f"Writing poses to {results}...") + with open(results, "w") as f: + for query, t in cam_from_world.items(): + qvec = " ".join(map(str, t.rotation.quat[[3, 0, 1, 2]])) + tvec = " ".join(map(str, t.translation)) + name = query.split("/")[-1] + if prepend_camera_name: + name = query.split("/")[-2] + "/" + name + f.write(f"{name} {qvec} {tvec}\n") + + logs_path = f"{results}_logs.pkl" + logger.info(f"Writing logs to {logs_path}...") + # TODO: Resolve pickling issue with pycolmap objects. 
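+    # One way to address the TODO above (illustrative sketch only, not
+    # applied here): convert the pycolmap pose objects kept in the logs
+    # into plain arrays before pickling, reusing the same quaternion
+    # reordering as the pose file written above, e.g.:
+    #
+    #   for log in logs["loc"].values():
+    #       ret = log.get("PnP_ret")
+    #       if ret and "cam_from_world" in ret:
+    #           t = ret["cam_from_world"]
+    #           ret["cam_from_world"] = {
+    #               "qvec_wxyz": np.asarray(t.rotation.quat)[[3, 0, 1, 2]],
+    #               "tvec": np.asarray(t.translation),
+    #           }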
+ with open(logs_path, "wb") as f: + pickle.dump(logs, f) + logger.info("Done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--reference_sfm", type=Path, required=True) + parser.add_argument("--queries", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + parser.add_argument("--retrieval", type=Path, required=True) + parser.add_argument("--results", type=Path, required=True) + parser.add_argument("--ransac_thresh", type=float, default=12.0) + parser.add_argument("--covisibility_clustering", action="store_true") + parser.add_argument("--prepend_camera_name", action="store_true") + args = parser.parse_args() + main(**args.__dict__) diff --git a/imcui/hloc/match_dense.py b/imcui/hloc/match_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..978db58b2438b337549a810196fdf017e0195e85 --- /dev/null +++ b/imcui/hloc/match_dense.py @@ -0,0 +1,1158 @@ +import argparse +import pprint +from collections import Counter, defaultdict +from itertools import chain +from pathlib import Path +from types import SimpleNamespace +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union + +import cv2 +import h5py +import numpy as np +import torch +import torchvision.transforms.functional as F +from scipy.spatial import KDTree +from tqdm import tqdm + +from . import logger, matchers +from .extract_features import read_image, resize_image +from .match_features import find_unique_new_pairs +from .utils.base_model import dynamic_load +from .utils.io import list_h5_names +from .utils.parsers import names_to_pair, parse_retrieval + +device = "cuda" if torch.cuda.is_available() else "cpu" + +confs = { + # Best quality but loads of points. 
Only use for small scenes + "loftr": { + "output": "matches-loftr", + "model": { + "name": "loftr", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 1, # max error for assigned keypoints (in px) + "cell_size": 1, # size of quantization patch (max 1 kp/patch) + }, + "minima_loftr": { + "output": "matches-minima_loftr", + "model": { + "name": "loftr", + "weights": "outdoor", + "model_name": "minima_loftr.ckpt", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": False, + }, + "max_error": 1, # max error for assigned keypoints (in px) + "cell_size": 1, # size of quantization patch (max 1 kp/patch) + }, + "eloftr": { + "output": "matches-eloftr", + "model": { + "name": "eloftr", + "model_name": "eloftr_outdoor.ckpt", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 32, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 1, # max error for assigned keypoints (in px) + "cell_size": 1, # size of quantization patch (max 1 kp/patch) + }, + "xoftr": { + "output": "matches-xoftr", + "model": { + "name": "xoftr", + "weights": "weights_xoftr_640.ckpt", + "max_keypoints": 2000, + "match_threshold": 0.3, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 1, # max error for assigned keypoints (in px) + "cell_size": 1, # size of quantization patch (max 1 kp/patch) + }, + # "loftr_quadtree": { + # "output": "matches-loftr-quadtree", + # "model": { + # "name": "quadtree", + # "weights": "outdoor", + # "max_keypoints": 2000, + # "match_threshold": 0.2, + # }, + # "preprocessing": { + # "grayscale": True, + # "resize_max": 1024, + # "dfactor": 8, + # "width": 640, + # "height": 480, + # "force_resize": True, + # }, + # "max_error": 1, # max error for assigned keypoints (in px) + # "cell_size": 1, # size of quantization patch (max 1 kp/patch) + # }, + "cotr": { + "output": "matches-cotr", + "model": { + "name": "cotr", + "weights": "out/default", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 1, # max error for assigned keypoints (in px) + "cell_size": 1, # size of quantization patch (max 1 kp/patch) + }, + # Semi-scalable loftr which limits detected keypoints + "loftr_aachen": { + "output": "matches-loftr_aachen", + "model": { + "name": "loftr", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": {"grayscale": True, "resize_max": 1024, "dfactor": 8}, + "max_error": 2, # max error for assigned keypoints (in px) + "cell_size": 8, # size of quantization patch (max 1 kp/patch) + }, + # Use for matching superpoint feats with loftr + "loftr_superpoint": { + "output": "matches-loftr_aachen", + "model": { + "name": "loftr", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 4, # max error for 
assigned keypoints (in px) + "cell_size": 4, # size of quantization patch (max 1 kp/patch) + }, + # Use topicfm for matching feats + "topicfm": { + "output": "matches-topicfm", + "model": { + "name": "topicfm", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + }, + }, + # Use aspanformer for matching feats + "aspanformer": { + "output": "matches-aspanformer", + "model": { + "name": "aspanformer", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "duster": { + "output": "matches-duster", + "model": { + "name": "duster", + "weights": "vit_large", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 512, + "dfactor": 16, + }, + }, + "mast3r": { + "output": "matches-mast3r", + "model": { + "name": "mast3r", + "weights": "vit_large", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "resize_max": 512, + "dfactor": 16, + }, + }, + "xfeat_lightglue": { + "output": "matches-xfeat_lightglue", + "model": { + "name": "xfeat_lightglue", + "max_keypoints": 8000, + }, + "preprocessing": { + "grayscale": False, + "force_resize": False, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "xfeat_dense": { + "output": "matches-xfeat_dense", + "model": { + "name": "xfeat_dense", + "max_keypoints": 8000, + }, + "preprocessing": { + "grayscale": False, + "force_resize": False, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "dkm": { + "output": "matches-dkm", + "model": { + "name": "dkm", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 80, + "height": 60, + "dfactor": 8, + }, + }, + "roma": { + "output": "matches-roma", + "model": { + "name": "roma", + "weights": "outdoor", + "model_name": "roma_outdoor.pth", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 320, + "height": 240, + "dfactor": 8, + }, + }, + "minima_roma": { + "output": "matches-minima_roma", + "model": { + "name": "roma", + "weights": "outdoor", + "model_name": "minima_roma.pth", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "force_resize": False, + "resize_max": 1024, + "width": 320, + "height": 240, + "dfactor": 8, + }, + }, + "gim(dkm)": { + "output": "matches-gim", + "model": { + "name": "gim", + "model_name": "gim_dkm_100h.ckpt", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": False, + "force_resize": True, + "resize_max": 1024, + "width": 320, + "height": 240, + "dfactor": 8, + }, + }, + "omniglue": { + "output": "matches-omniglue", + "model": { + "name": "omniglue", + "match_threshold": 0.2, + "max_keypoints": 2000, + "features": "null", + }, + "preprocessing": { + "grayscale": False, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + "width": 640, + "height": 480, + }, + }, + "sold2": { + "output": "matches-sold2", + "model": { + "name": "sold2", + "max_keypoints": 2000, + 
"match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "gluestick": { + "output": "matches-gluestick", + "model": { + "name": "gluestick", + "use_lines": True, + "max_keypoints": 1000, + "max_lines": 300, + "force_num_keypoints": False, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1024, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, +} + + +def to_cpts(kpts, ps): + if ps > 0.0: + kpts = np.round(np.round((kpts + 0.5) / ps) * ps - 0.5, 2) + return [tuple(cpt) for cpt in kpts] + + +def assign_keypoints( + kpts: np.ndarray, + other_cpts: Union[List[Tuple], np.ndarray], + max_error: float, + update: bool = False, + ref_bins: Optional[List[Counter]] = None, + scores: Optional[np.ndarray] = None, + cell_size: Optional[int] = None, +): + if not update: + # Without update this is just a NN search + if len(other_cpts) == 0 or len(kpts) == 0: + return np.full(len(kpts), -1) + dist, kpt_ids = KDTree(np.array(other_cpts)).query(kpts) + valid = dist <= max_error + kpt_ids[~valid] = -1 + return kpt_ids + else: + ps = cell_size if cell_size is not None else max_error + ps = max(ps, max_error) + # With update we quantize and bin (optionally) + assert isinstance(other_cpts, list) + kpt_ids = [] + cpts = to_cpts(kpts, ps) + bpts = to_cpts(kpts, int(max_error)) + cp_to_id = {val: i for i, val in enumerate(other_cpts)} + for i, (cpt, bpt) in enumerate(zip(cpts, bpts)): + try: + kid = cp_to_id[cpt] + except KeyError: + kid = len(cp_to_id) + cp_to_id[cpt] = kid + other_cpts.append(cpt) + if ref_bins is not None: + ref_bins.append(Counter()) + if ref_bins is not None: + score = scores[i] if scores is not None else 1 + ref_bins[cp_to_id[cpt]][bpt] += score + kpt_ids.append(kid) + return np.array(kpt_ids) + + +def get_grouped_ids(array): + # Group array indices based on its values + # all duplicates are grouped as a set + idx_sort = np.argsort(array) + sorted_array = array[idx_sort] + _, ids, _ = np.unique(sorted_array, return_counts=True, return_index=True) + res = np.split(idx_sort, ids[1:]) + return res + + +def get_unique_matches(match_ids, scores): + if len(match_ids.shape) == 1: + return [0] + + isets1 = get_grouped_ids(match_ids[:, 0]) + isets2 = get_grouped_ids(match_ids[:, 1]) + uid1s = [ids[scores[ids].argmax()] for ids in isets1 if len(ids) > 0] + uid2s = [ids[scores[ids].argmax()] for ids in isets2 if len(ids) > 0] + uids = list(set(uid1s).intersection(uid2s)) + return match_ids[uids], scores[uids] + + +def matches_to_matches0(matches, scores): + if len(matches) == 0: + return np.zeros(0, dtype=np.int32), np.zeros(0, dtype=np.float16) + n_kps0 = np.max(matches[:, 0]) + 1 + matches0 = -np.ones((n_kps0,)) + scores0 = np.zeros((n_kps0,)) + matches0[matches[:, 0]] = matches[:, 1] + scores0[matches[:, 0]] = scores + return matches0.astype(np.int32), scores0.astype(np.float16) + + +def kpids_to_matches0(kpt_ids0, kpt_ids1, scores): + valid = (kpt_ids0 != -1) & (kpt_ids1 != -1) + matches = np.dstack([kpt_ids0[valid], kpt_ids1[valid]]) + matches = matches.reshape(-1, 2) + scores = scores[valid] + + # Remove n-to-1 matches + matches, scores = get_unique_matches(matches, scores) + return matches_to_matches0(matches, scores) + + +def scale_keypoints(kpts, scale): + if np.any(scale != 1.0): + kpts *= kpts.new_tensor(scale) + return kpts + + +class ImagePairDataset(torch.utils.data.Dataset): + default_conf = { + "grayscale": True, + 
"resize_max": 1024, + "dfactor": 8, + "cache_images": False, + } + + def __init__(self, image_dir, conf, pairs): + self.image_dir = image_dir + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + self.pairs = pairs + if self.conf.cache_images: + image_names = set(sum(pairs, ())) # unique image names in pairs + logger.info(f"Loading and caching {len(image_names)} unique images.") + self.images = {} + self.scales = {} + for name in tqdm(image_names): + image = read_image(self.image_dir / name, self.conf.grayscale) + self.images[name], self.scales[name] = self.preprocess(image) + + def preprocess(self, image: np.ndarray): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + + if self.conf.resize_max: + scale = self.conf.resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, size_new, "cv2_area") + scale = np.array(size) / np.array(size_new) + + if self.conf.grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + + # assure that the size is divisible by dfactor + size_new = tuple( + map( + lambda x: int(x // self.conf.dfactor * self.conf.dfactor), + image.shape[-2:], + ) + ) + image = F.resize(image, size=size_new) + scale = np.array(size) / np.array(size_new)[::-1] + return image, scale + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + name0, name1 = self.pairs[idx] + if self.conf.cache_images: + image0, scale0 = self.images[name0], self.scales[name0] + image1, scale1 = self.images[name1], self.scales[name1] + else: + image0 = read_image(self.image_dir / name0, self.conf.grayscale) + image1 = read_image(self.image_dir / name1, self.conf.grayscale) + image0, scale0 = self.preprocess(image0) + image1, scale1 = self.preprocess(image1) + return image0, image1, scale0, scale1, name0, name1 + + +@torch.no_grad() +def match_dense( + conf: Dict, + pairs: List[Tuple[str, str]], + image_dir: Path, + match_path: Path, # out + existing_refs: Optional[List] = [], +): + device = "cuda" if torch.cuda.is_available() else "cpu" + Model = dynamic_load(matchers, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(device) + + dataset = ImagePairDataset(image_dir, conf["preprocessing"], pairs) + loader = torch.utils.data.DataLoader( + dataset, num_workers=16, batch_size=1, shuffle=False + ) + + logger.info("Performing dense matching...") + with h5py.File(str(match_path), "a") as fd: + for data in tqdm(loader, smoothing=0.1): + # load image-pair data + image0, image1, scale0, scale1, (name0,), (name1,) = data + scale0, scale1 = scale0[0].numpy(), scale1[0].numpy() + image0, image1 = image0.to(device), image1.to(device) + + # match semi-dense + # for consistency with pairs_from_*: refine kpts of image0 + if name0 in existing_refs: + # special case: flip to enable refinement in query image + pred = model({"image0": image1, "image1": image0}) + pred = { + **pred, + "keypoints0": pred["keypoints1"], + "keypoints1": pred["keypoints0"], + } + else: + # usual case + pred = model({"image0": image0, "image1": image1}) + + # Rescale keypoints and move to cpu + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5 + kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5 + kpts0 = kpts0.cpu().numpy() + kpts1 = kpts1.cpu().numpy() + scores = 
pred["scores"].cpu().numpy() + + # Write matches and matching scores in hloc format + pair = names_to_pair(name0, name1) + if pair in fd: + del fd[pair] + grp = fd.create_group(pair) + + # Write dense matching output + grp.create_dataset("keypoints0", data=kpts0) + grp.create_dataset("keypoints1", data=kpts1) + grp.create_dataset("scores", data=scores) + del model, loader + + +# default: quantize all! +def load_keypoints( + conf: Dict, feature_paths_refs: List[Path], quantize: Optional[set] = None +): + name2ref = { + n: i for i, p in enumerate(feature_paths_refs) for n in list_h5_names(p) + } + + existing_refs = set(name2ref.keys()) + if quantize is None: + quantize = existing_refs # quantize all + if len(existing_refs) > 0: + logger.info(f"Loading keypoints from {len(existing_refs)} images.") + + # Load query keypoints + cpdict = defaultdict(list) + bindict = defaultdict(list) + for name in existing_refs: + with h5py.File(str(feature_paths_refs[name2ref[name]]), "r") as fd: + kps = fd[name]["keypoints"].__array__() + if name not in quantize: + cpdict[name] = kps + else: + if "scores" in fd[name].keys(): + kp_scores = fd[name]["scores"].__array__() + else: + # we set the score to 1.0 if not provided + # increase for more weight on reference keypoints for + # stronger anchoring + kp_scores = [1.0 for _ in range(kps.shape[0])] + # bin existing keypoints of reference images for association + assign_keypoints( + kps, + cpdict[name], + conf["max_error"], + True, + bindict[name], + kp_scores, + conf["cell_size"], + ) + return cpdict, bindict + + +def aggregate_matches( + conf: Dict, + pairs: List[Tuple[str, str]], + match_path: Path, + feature_path: Path, + required_queries: Optional[Set[str]] = None, + max_kps: Optional[int] = None, + cpdict: Dict[str, Iterable] = defaultdict(list), + bindict: Dict[str, List[Counter]] = defaultdict(list), +): + if required_queries is None: + required_queries = set(sum(pairs, ())) + # default: do not overwrite existing features in feature_path! + required_queries -= set(list_h5_names(feature_path)) + + # if an entry in cpdict is provided as np.ndarray we assume it is fixed + required_queries -= set([k for k, v in cpdict.items() if isinstance(v, np.ndarray)]) + + # sort pairs for reduced RAM + pairs_per_q = Counter(list(chain(*pairs))) + pairs_score = [min(pairs_per_q[i], pairs_per_q[j]) for i, j in pairs] + pairs = [p for _, p in sorted(zip(pairs_score, pairs))] + + if len(required_queries) > 0: + logger.info(f"Aggregating keypoints for {len(required_queries)} images.") + n_kps = 0 + with h5py.File(str(match_path), "a") as fd: + for name0, name1 in tqdm(pairs, smoothing=0.1): + pair = names_to_pair(name0, name1) + grp = fd[pair] + kpts0 = grp["keypoints0"].__array__() + kpts1 = grp["keypoints1"].__array__() + scores = grp["scores"].__array__() + + # Aggregate local features + update0 = name0 in required_queries + update1 = name1 in required_queries + + # in localization we do not want to bin the query kp + # assumes that the query is name0! 
+ if update0 and not update1 and max_kps is None: + max_error0 = cell_size0 = 0.0 + else: + max_error0 = conf["max_error"] + cell_size0 = conf["cell_size"] + + # Get match ids and extend query keypoints (cpdict) + mkp_ids0 = assign_keypoints( + kpts0, + cpdict[name0], + max_error0, + update0, + bindict[name0], + scores, + cell_size0, + ) + mkp_ids1 = assign_keypoints( + kpts1, + cpdict[name1], + conf["max_error"], + update1, + bindict[name1], + scores, + conf["cell_size"], + ) + + # Build matches from assignments + matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1, scores) + + assert kpts0.shape[0] == scores.shape[0] + grp.create_dataset("matches0", data=matches0) + grp.create_dataset("matching_scores0", data=scores0) + + # Convert bins to kps if finished, and store them + for name in (name0, name1): + pairs_per_q[name] -= 1 + if pairs_per_q[name] > 0 or name not in required_queries: + continue + kp_score = [c.most_common(1)[0][1] for c in bindict[name]] + cpdict[name] = [c.most_common(1)[0][0] for c in bindict[name]] + cpdict[name] = np.array(cpdict[name], dtype=np.float32) + + # Select top-k query kps by score (reassign matches later) + if max_kps: + top_k = min(max_kps, cpdict[name].shape[0]) + top_k = np.argsort(kp_score)[::-1][:top_k] + cpdict[name] = cpdict[name][top_k] + kp_score = np.array(kp_score)[top_k] + + # Write query keypoints + with h5py.File(feature_path, "a") as kfd: + if name in kfd: + del kfd[name] + kgrp = kfd.create_group(name) + kgrp.create_dataset("keypoints", data=cpdict[name]) + kgrp.create_dataset("score", data=kp_score) + n_kps += cpdict[name].shape[0] + del bindict[name] + + if len(required_queries) > 0: + avg_kp_per_image = round(n_kps / len(required_queries), 1) + logger.info( + f"Finished assignment, found {avg_kp_per_image} " + f"keypoints/image (avg.), total {n_kps}." 
+ ) + return cpdict + + +def assign_matches( + pairs: List[Tuple[str, str]], + match_path: Path, + keypoints: Union[List[Path], Dict[str, np.array]], + max_error: float, +): + if isinstance(keypoints, list): + keypoints = load_keypoints({}, keypoints, kpts_as_bin=set([])) + assert len(set(sum(pairs, ())) - set(keypoints.keys())) == 0 + with h5py.File(str(match_path), "a") as fd: + for name0, name1 in tqdm(pairs): + pair = names_to_pair(name0, name1) + grp = fd[pair] + kpts0 = grp["keypoints0"].__array__() + kpts1 = grp["keypoints1"].__array__() + scores = grp["scores"].__array__() + + # NN search across cell boundaries + mkp_ids0 = assign_keypoints(kpts0, keypoints[name0], max_error) + mkp_ids1 = assign_keypoints(kpts1, keypoints[name1], max_error) + + matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1, scores) + + # overwrite matches0 and matching_scores0 + del grp["matches0"], grp["matching_scores0"] + grp.create_dataset("matches0", data=matches0) + grp.create_dataset("matching_scores0", data=scores0) + + +@torch.no_grad() +def match_and_assign( + conf: Dict, + pairs_path: Path, + image_dir: Path, + match_path: Path, # out + feature_path_q: Path, # out + feature_paths_refs: Optional[List[Path]] = [], + max_kps: Optional[int] = 8192, + overwrite: bool = False, +) -> Path: + for path in feature_paths_refs: + if not path.exists(): + raise FileNotFoundError(f"Reference feature file {path}.") + pairs = parse_retrieval(pairs_path) + pairs = [(q, r) for q, rs in pairs.items() for r in rs] + pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) + required_queries = set(sum(pairs, ())) + + name2ref = { + n: i for i, p in enumerate(feature_paths_refs) for n in list_h5_names(p) + } + existing_refs = required_queries.intersection(set(name2ref.keys())) + + # images which require feature extraction + required_queries = required_queries - existing_refs + + if feature_path_q.exists(): + existing_queries = set(list_h5_names(feature_path_q)) + feature_paths_refs.append(feature_path_q) + existing_refs = set.union(existing_refs, existing_queries) + if not overwrite: + required_queries = required_queries - existing_queries + + if len(pairs) == 0 and len(required_queries) == 0: + logger.info("All pairs exist. 
Skipping dense matching.") + return + + # extract semi-dense matches + match_dense(conf, pairs, image_dir, match_path, existing_refs=existing_refs) + + logger.info("Assigning matches...") + + # Pre-load existing keypoints + cpdict, bindict = load_keypoints( + conf, feature_paths_refs, quantize=required_queries + ) + + # Reassign matches by aggregation + cpdict = aggregate_matches( + conf, + pairs, + match_path, + feature_path=feature_path_q, + required_queries=required_queries, + max_kps=max_kps, + cpdict=cpdict, + bindict=bindict, + ) + + # Invalidate matches that are far from selected bin by reassignment + if max_kps is not None: + logger.info(f'Reassign matches with max_error={conf["max_error"]}.') + assign_matches(pairs, match_path, cpdict, max_error=conf["max_error"]) + + +def scale_lines(lines, scale): + if np.any(scale != 1.0): + lines *= lines.new_tensor(scale) + return lines + + +def match(model, path_0, path_1, conf): + default_conf = { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "cache_images": False, + "force_resize": False, + "width": 320, + "height": 240, + } + + def preprocess(image: np.ndarray): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + if conf.resize_max: + scale = conf.resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, size_new, "cv2_area") + scale = np.array(size) / np.array(size_new) + if conf.force_resize: + size = image.shape[:2][::-1] + image = resize_image(image, (conf.width, conf.height), "cv2_area") + size_new = (conf.width, conf.height) + scale = np.array(size) / np.array(size_new) + if conf.grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + # assure that the size is divisible by dfactor + size_new = tuple( + map( + lambda x: int(x // conf.dfactor * conf.dfactor), + image.shape[-2:], + ) + ) + image = F.resize(image, size=size_new, antialias=True) + scale = np.array(size) / np.array(size_new)[::-1] + return image, scale + + conf = SimpleNamespace(**{**default_conf, **conf}) + image0 = read_image(path_0, conf.grayscale) + image1 = read_image(path_1, conf.grayscale) + image0, scale0 = preprocess(image0) + image1, scale1 = preprocess(image1) + image0 = image0.to(device)[None] + image1 = image1.to(device)[None] + pred = model({"image0": image0, "image1": image1}) + + # Rescale keypoints and move to cpu + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5 + kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5 + + ret = { + "image0": image0.squeeze().cpu().numpy(), + "image1": image1.squeeze().cpu().numpy(), + "keypoints0": kpts0.cpu().numpy(), + "keypoints1": kpts1.cpu().numpy(), + } + if "mconf" in pred.keys(): + ret["mconf"] = pred["mconf"].cpu().numpy() + return ret + + +@torch.no_grad() +def match_images(model, image_0, image_1, conf, device="cpu"): + default_conf = { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "cache_images": False, + "force_resize": False, + "width": 320, + "height": 240, + } + + def preprocess(image: np.ndarray): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + if conf.resize_max: + scale = conf.resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x * scale)) for x in size) + image = resize_image(image, 
size_new, "cv2_area") + scale = np.array(size) / np.array(size_new) + if conf.force_resize: + size = image.shape[:2][::-1] + image = resize_image(image, (conf.width, conf.height), "cv2_area") + size_new = (conf.width, conf.height) + scale = np.array(size) / np.array(size_new) + if conf.grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + + # assure that the size is divisible by dfactor + size_new = tuple( + map( + lambda x: int(x // conf.dfactor * conf.dfactor), + image.shape[-2:], + ) + ) + image = F.resize(image, size=size_new) + scale = np.array(size) / np.array(size_new)[::-1] + return image, scale + + conf = SimpleNamespace(**{**default_conf, **conf}) + + if len(image_0.shape) == 3 and conf.grayscale: + image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY) + else: + image0 = image_0 + if len(image_0.shape) == 3 and conf.grayscale: + image1 = cv2.cvtColor(image_1, cv2.COLOR_RGB2GRAY) + else: + image1 = image_1 + + # comment following lines, image is always RGB mode + # if not conf.grayscale and len(image0.shape) == 3: + # image0 = image0[:, :, ::-1] # BGR to RGB + # if not conf.grayscale and len(image1.shape) == 3: + # image1 = image1[:, :, ::-1] # BGR to RGB + + image0, scale0 = preprocess(image0) + image1, scale1 = preprocess(image1) + image0 = image0.to(device)[None] + image1 = image1.to(device)[None] + pred = model({"image0": image0, "image1": image1}) + + s0 = np.array(image_0.shape[:2][::-1]) / np.array(image0.shape[-2:][::-1]) + s1 = np.array(image_1.shape[:2][::-1]) / np.array(image1.shape[-2:][::-1]) + + # Rescale keypoints and move to cpu + if "keypoints0" in pred.keys() and "keypoints1" in pred.keys(): + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5 + kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5 + + ret = { + "image0": image0.squeeze().cpu().numpy(), + "image1": image1.squeeze().cpu().numpy(), + "image0_orig": image_0, + "image1_orig": image_1, + "keypoints0": kpts0.cpu().numpy(), + "keypoints1": kpts1.cpu().numpy(), + "keypoints0_orig": kpts0_origin.cpu().numpy(), + "keypoints1_orig": kpts1_origin.cpu().numpy(), + "mkeypoints0": kpts0.cpu().numpy(), + "mkeypoints1": kpts1.cpu().numpy(), + "mkeypoints0_orig": kpts0_origin.cpu().numpy(), + "mkeypoints1_orig": kpts1_origin.cpu().numpy(), + "original_size0": np.array(image_0.shape[:2][::-1]), + "original_size1": np.array(image_1.shape[:2][::-1]), + "new_size0": np.array(image0.shape[-2:][::-1]), + "new_size1": np.array(image1.shape[-2:][::-1]), + "scale0": s0, + "scale1": s1, + } + if "mconf" in pred.keys(): + ret["mconf"] = pred["mconf"].cpu().numpy() + elif "scores" in pred.keys(): # adapting loftr + ret["mconf"] = pred["scores"].cpu().numpy() + else: + ret["mconf"] = np.ones_like(kpts0.cpu().numpy()[:, 0]) + if "lines0" in pred.keys() and "lines1" in pred.keys(): + if "keypoints0" in pred.keys() and "keypoints1" in pred.keys(): + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5 + kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5 + kpts0_origin = kpts0_origin.cpu().numpy() + kpts1_origin = kpts1_origin.cpu().numpy() + else: + kpts0_origin, kpts1_origin = ( + None, + None, + ) # np.zeros([0]), np.zeros([0]) + lines0, lines1 = pred["lines0"], pred["lines1"] + lines0_raw, lines1_raw = pred["raw_lines0"], pred["raw_lines1"] + + lines0_raw = 
torch.from_numpy(lines0_raw.copy()) + lines1_raw = torch.from_numpy(lines1_raw.copy()) + lines0_raw = scale_lines(lines0_raw + 0.5, s0) - 0.5 + lines1_raw = scale_lines(lines1_raw + 0.5, s1) - 0.5 + + lines0 = torch.from_numpy(lines0.copy()) + lines1 = torch.from_numpy(lines1.copy()) + lines0 = scale_lines(lines0 + 0.5, s0) - 0.5 + lines1 = scale_lines(lines1 + 0.5, s1) - 0.5 + + ret = { + "image0_orig": image_0, + "image1_orig": image_1, + "line0": lines0_raw.cpu().numpy(), + "line1": lines1_raw.cpu().numpy(), + "line0_orig": lines0.cpu().numpy(), + "line1_orig": lines1.cpu().numpy(), + "line_keypoints0_orig": kpts0_origin, + "line_keypoints1_orig": kpts1_origin, + } + del pred + torch.cuda.empty_cache() + return ret + + +@torch.no_grad() +def main( + conf: Dict, + pairs: Path, + image_dir: Path, + export_dir: Optional[Path] = None, + matches: Optional[Path] = None, # out + features: Optional[Path] = None, # out + features_ref: Optional[Path] = None, + max_kps: Optional[int] = 8192, + overwrite: bool = False, +) -> Path: + logger.info( + "Extracting semi-dense features with configuration:" f"\n{pprint.pformat(conf)}" + ) + + if features is None: + features = "feats_" + + if isinstance(features, Path): + features_q = features + if matches is None: + raise ValueError( + "Either provide both features and matches as Path" " or both as names." + ) + else: + if export_dir is None: + raise ValueError( + "Provide an export_dir if features and matches" + f" are not file paths: {features}, {matches}." + ) + features_q = Path(export_dir, f'{features}{conf["output"]}.h5') + if matches is None: + matches = Path(export_dir, f'{conf["output"]}_{pairs.stem}.h5') + + if features_ref is None: + features_ref = [] + elif isinstance(features_ref, list): + features_ref = list(features_ref) + elif isinstance(features_ref, Path): + features_ref = [features_ref] + else: + raise TypeError(str(features_ref)) + + match_and_assign( + conf, + pairs, + image_dir, + matches, + features_q, + features_ref, + max_kps, + overwrite, + ) + + return features_q, matches + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--image_dir", type=Path, required=True) + parser.add_argument("--export_dir", type=Path, required=True) + parser.add_argument("--matches", type=Path, default=confs["loftr"]["output"]) + parser.add_argument( + "--features", type=str, default="feats_" + confs["loftr"]["output"] + ) + parser.add_argument("--conf", type=str, default="loftr", choices=list(confs.keys())) + args = parser.parse_args() + main( + confs[args.conf], + args.pairs, + args.image_dir, + args.export_dir, + args.matches, + args.features, + ) diff --git a/imcui/hloc/match_features.py b/imcui/hloc/match_features.py new file mode 100644 index 0000000000000000000000000000000000000000..50917b5a586a172a092a6d00c5ec9235e1b81a36 --- /dev/null +++ b/imcui/hloc/match_features.py @@ -0,0 +1,459 @@ +import argparse +import pprint +from functools import partial +from pathlib import Path +from queue import Queue +from threading import Thread +from typing import Dict, List, Optional, Tuple, Union + +import h5py +import numpy as np +import torch +from tqdm import tqdm + +from . import logger, matchers +from .utils.base_model import dynamic_load +from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval + +""" +A set of standard configurations that can be directly selected from the command +line using their name. 
Each is a dictionary with the following entries: + - output: the name of the match file that will be generated. + - model: the model configuration, as passed to a feature matcher. +""" +confs = { + "superglue": { + "output": "matches-superglue", + "model": { + "name": "superglue", + "weights": "outdoor", + "sinkhorn_iterations": 50, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "superglue-fast": { + "output": "matches-superglue-it5", + "model": { + "name": "superglue", + "weights": "outdoor", + "sinkhorn_iterations": 5, + "match_threshold": 0.2, + }, + }, + "superpoint-lightglue": { + "output": "matches-lightglue", + "model": { + "name": "lightglue", + "match_threshold": 0.2, + "width_confidence": 0.99, # for point pruning + "depth_confidence": 0.95, # for early stopping, + "features": "superpoint", + "model_name": "superpoint_lightglue.pth", + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "disk-lightglue": { + "output": "matches-disk-lightglue", + "model": { + "name": "lightglue", + "match_threshold": 0.2, + "width_confidence": 0.99, # for point pruning + "depth_confidence": 0.95, # for early stopping, + "features": "disk", + "model_name": "disk_lightglue.pth", + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "aliked-lightglue": { + "output": "matches-aliked-lightglue", + "model": { + "name": "lightglue", + "match_threshold": 0.2, + "width_confidence": 0.99, # for point pruning + "depth_confidence": 0.95, # for early stopping, + "features": "aliked", + "model_name": "aliked_lightglue.pth", + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "sift-lightglue": { + "output": "matches-sift-lightglue", + "model": { + "name": "lightglue", + "match_threshold": 0.2, + "width_confidence": 0.99, # for point pruning + "depth_confidence": 0.95, # for early stopping, + "features": "sift", + "add_scale_ori": True, + "model_name": "sift_lightglue.pth", + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "sgmnet": { + "output": "matches-sgmnet", + "model": { + "name": "sgmnet", + "seed_top_k": [256, 256], + "seed_radius_coe": 0.01, + "net_channels": 128, + "layer_num": 9, + "head": 4, + "seedlayer": [0, 6], + "use_mc_seeding": True, + "use_score_encoding": False, + "conf_bar": [1.11, 0.1], + "sink_iter": [10, 100], + "detach_iter": 1000000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "force_resize": False, + }, + }, + "NN-superpoint": { + "output": "matches-NN-mutual-dist.7", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "distance_threshold": 0.7, + "match_threshold": 0.2, + }, + }, + "NN-ratio": { + "output": "matches-NN-mutual-ratio.8", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "ratio_threshold": 0.8, + "match_threshold": 0.2, + }, + }, + "NN-mutual": { + "output": "matches-NN-mutual", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "match_threshold": 0.2, + }, + }, + "Dual-Softmax": { + "output": "matches-Dual-Softmax", + "model": { + "name": "dual_softmax", + "match_threshold": 0.01, + "inv_temperature": 20, + }, + }, + "adalam": { + "output": "matches-adalam", + "model": 
{ + "name": "adalam", + "match_threshold": 0.2, + }, + }, + "imp": { + "output": "matches-imp", + "model": { + "name": "imp", + "match_threshold": 0.2, + }, + }, +} + + +class WorkQueue: + def __init__(self, work_fn, num_threads=1): + self.queue = Queue(num_threads) + self.threads = [ + Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads) + ] + for thread in self.threads: + thread.start() + + def join(self): + for thread in self.threads: + self.queue.put(None) + for thread in self.threads: + thread.join() + + def thread_fn(self, work_fn): + item = self.queue.get() + while item is not None: + work_fn(item) + item = self.queue.get() + + def put(self, data): + self.queue.put(data) + + +class FeaturePairsDataset(torch.utils.data.Dataset): + def __init__(self, pairs, feature_path_q, feature_path_r): + self.pairs = pairs + self.feature_path_q = feature_path_q + self.feature_path_r = feature_path_r + + def __getitem__(self, idx): + name0, name1 = self.pairs[idx] + data = {} + with h5py.File(self.feature_path_q, "r") as fd: + grp = fd[name0] + for k, v in grp.items(): + data[k + "0"] = torch.from_numpy(v.__array__()).float() + # some matchers might expect an image but only use its size + data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + with h5py.File(self.feature_path_r, "r") as fd: + grp = fd[name1] + for k, v in grp.items(): + data[k + "1"] = torch.from_numpy(v.__array__()).float() + data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + return data + + def __len__(self): + return len(self.pairs) + + +def writer_fn(inp, match_path): + pair, pred = inp + with h5py.File(str(match_path), "a", libver="latest") as fd: + if pair in fd: + del fd[pair] + grp = fd.create_group(pair) + matches = pred["matches0"][0].cpu().short().numpy() + grp.create_dataset("matches0", data=matches) + if "matching_scores0" in pred: + scores = pred["matching_scores0"][0].cpu().half().numpy() + grp.create_dataset("matching_scores0", data=scores) + + +def main( + conf: Dict, + pairs: Path, + features: Union[Path, str], + export_dir: Optional[Path] = None, + matches: Optional[Path] = None, + features_ref: Optional[Path] = None, + overwrite: bool = False, +) -> Path: + if isinstance(features, Path) or Path(features).exists(): + features_q = features + if matches is None: + raise ValueError( + "Either provide both features and matches as Path" " or both as names." + ) + else: + if export_dir is None: + raise ValueError( + "Provide an export_dir if features is not" f" a file path: {features}." 
+ ) + features_q = Path(export_dir, features + ".h5") + if matches is None: + matches = Path(export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5') + + if features_ref is None: + features_ref = features_q + match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite) + + return matches + + +def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None): + """Avoid to recompute duplicates to save time.""" + pairs = set() + for i, j in pairs_all: + if (j, i) not in pairs: + pairs.add((i, j)) + pairs = list(pairs) + if match_path is not None and match_path.exists(): + with h5py.File(str(match_path), "r", libver="latest") as fd: + pairs_filtered = [] + for i, j in pairs: + if ( + names_to_pair(i, j) in fd + or names_to_pair(j, i) in fd + or names_to_pair_old(i, j) in fd + or names_to_pair_old(j, i) in fd + ): + continue + pairs_filtered.append((i, j)) + return pairs_filtered + return pairs + + +@torch.no_grad() +def match_from_paths( + conf: Dict, + pairs_path: Path, + match_path: Path, + feature_path_q: Path, + feature_path_ref: Path, + overwrite: bool = False, +) -> Path: + logger.info( + "Matching local features with configuration:" f"\n{pprint.pformat(conf)}" + ) + + if not feature_path_q.exists(): + raise FileNotFoundError(f"Query feature file {feature_path_q}.") + if not feature_path_ref.exists(): + raise FileNotFoundError(f"Reference feature file {feature_path_ref}.") + match_path.parent.mkdir(exist_ok=True, parents=True) + + assert pairs_path.exists(), pairs_path + pairs = parse_retrieval(pairs_path) + pairs = [(q, r) for q, rs in pairs.items() for r in rs] + pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) + if len(pairs) == 0: + logger.info("Skipping the matching.") + return + + device = "cuda" if torch.cuda.is_available() else "cpu" + Model = dynamic_load(matchers, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(device) + + dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref) + loader = torch.utils.data.DataLoader( + dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True + ) + writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5) + + for idx, data in enumerate(tqdm(loader, smoothing=0.1)): + data = { + k: v if k.startswith("image") else v.to(device, non_blocking=True) + for k, v in data.items() + } + pred = model(data) + pair = names_to_pair(*pairs[idx]) + writer_queue.put((pair, pred)) + writer_queue.join() + logger.info("Finished exporting matches.") + + +def scale_keypoints(kpts, scale): + if ( + isinstance(scale, (list, tuple, np.ndarray)) + and len(scale) == 2 + and np.any(scale != np.array([1.0, 1.0])) + ): + if isinstance(kpts, torch.Tensor): + kpts[:, 0] *= scale[0] # scale x-dimension + kpts[:, 1] *= scale[1] # scale y-dimension + elif isinstance(kpts, np.ndarray): + kpts[:, 0] *= scale[0] # scale x-dimension + kpts[:, 1] *= scale[1] # scale y-dimension + return kpts + + +@torch.no_grad() +def match_images(model, feat0, feat1): + # forward pass to match keypoints + desc0 = feat0["descriptors"][0] + desc1 = feat1["descriptors"][0] + if len(desc0.shape) == 2: + desc0 = desc0.unsqueeze(0) + if len(desc1.shape) == 2: + desc1 = desc1.unsqueeze(0) + if isinstance(feat0["keypoints"], list): + feat0["keypoints"] = feat0["keypoints"][0][None] + if isinstance(feat1["keypoints"], list): + feat1["keypoints"] = feat1["keypoints"][0][None] + input_dict = { + "image0": feat0["image"], + "keypoints0": feat0["keypoints"], + "scores0": feat0["scores"][0].unsqueeze(0), 
+ "descriptors0": desc0, + "image1": feat1["image"], + "keypoints1": feat1["keypoints"], + "scores1": feat1["scores"][0].unsqueeze(0), + "descriptors1": desc1, + } + if "scales" in feat0: + input_dict = {**input_dict, "scales0": feat0["scales"]} + if "scales" in feat1: + input_dict = {**input_dict, "scales1": feat1["scales"]} + if "oris" in feat0: + input_dict = {**input_dict, "oris0": feat0["oris"]} + if "oris" in feat1: + input_dict = {**input_dict, "oris1": feat1["oris"]} + pred = model(input_dict) + pred = { + k: v.cpu().detach()[0] if isinstance(v, torch.Tensor) else v + for k, v in pred.items() + } + kpts0, kpts1 = ( + feat0["keypoints"][0].cpu().numpy(), + feat1["keypoints"][0].cpu().numpy(), + ) + matches, confid = pred["matches0"], pred["matching_scores0"] + # Keep the matching keypoints. + valid = matches > -1 + mkpts0 = kpts0[valid] + mkpts1 = kpts1[matches[valid]] + mconfid = confid[valid] + # rescale the keypoints to their original size + s0 = feat0["original_size"] / feat0["size"] + s1 = feat1["original_size"] / feat1["size"] + kpts0_origin = scale_keypoints(torch.from_numpy(kpts0 + 0.5), s0) - 0.5 + kpts1_origin = scale_keypoints(torch.from_numpy(kpts1 + 0.5), s1) - 0.5 + + mkpts0_origin = scale_keypoints(torch.from_numpy(mkpts0 + 0.5), s0) - 0.5 + mkpts1_origin = scale_keypoints(torch.from_numpy(mkpts1 + 0.5), s1) - 0.5 + + ret = { + "image0_orig": feat0["image_orig"], + "image1_orig": feat1["image_orig"], + "keypoints0": kpts0, + "keypoints1": kpts1, + "keypoints0_orig": kpts0_origin.numpy(), + "keypoints1_orig": kpts1_origin.numpy(), + "mkeypoints0": mkpts0, + "mkeypoints1": mkpts1, + "mkeypoints0_orig": mkpts0_origin.numpy(), + "mkeypoints1_orig": mkpts1_origin.numpy(), + "mconf": mconfid.numpy(), + } + del feat0, feat1, desc0, desc1, kpts0, kpts1, kpts0_origin, kpts1_origin + torch.cuda.empty_cache() + + return ret + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--export_dir", type=Path) + parser.add_argument("--features", type=str, default="feats-superpoint-n4096-r1024") + parser.add_argument("--matches", type=Path) + parser.add_argument( + "--conf", type=str, default="superglue", choices=list(confs.keys()) + ) + args = parser.parse_args() + main(confs[args.conf], args.pairs, args.features, args.export_dir) diff --git a/imcui/hloc/matchers/__init__.py b/imcui/hloc/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9fd381eb391604db3d3c3c03d278c0e08de2531 --- /dev/null +++ b/imcui/hloc/matchers/__init__.py @@ -0,0 +1,3 @@ +def get_matcher(matcher): + mod = __import__(f"{__name__}.{matcher}", fromlist=[""]) + return getattr(mod, "Model") diff --git a/imcui/hloc/matchers/adalam.py b/imcui/hloc/matchers/adalam.py new file mode 100644 index 0000000000000000000000000000000000000000..7820428a5a087d0b5d6855de15c0230327ce7dc1 --- /dev/null +++ b/imcui/hloc/matchers/adalam.py @@ -0,0 +1,68 @@ +import torch +from kornia.feature.adalam import AdalamFilter +from kornia.utils.helpers import get_cuda_device_if_available + +from ..utils.base_model import BaseModel + + +class AdaLAM(BaseModel): + # See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html. 
+ default_conf = { + "area_ratio": 100, + "search_expansion": 4, + "ransac_iters": 128, + "min_inliers": 6, + "min_confidence": 200, + "orientation_difference_threshold": 30, + "scale_rate_threshold": 1.5, + "detected_scale_rate_threshold": 5, + "refit": True, + "force_seed_mnn": True, + "device": get_cuda_device_if_available(), + } + required_inputs = [ + "image0", + "image1", + "descriptors0", + "descriptors1", + "keypoints0", + "keypoints1", + "scales0", + "scales1", + "oris0", + "oris1", + ] + + def _init(self, conf): + self.adalam = AdalamFilter(conf) + + def _forward(self, data): + assert data["keypoints0"].size(0) == 1 + if data["keypoints0"].size(1) < 2 or data["keypoints1"].size(1) < 2: + matches = torch.zeros( + (0, 2), dtype=torch.int64, device=data["keypoints0"].device + ) + else: + matches = self.adalam.match_and_filter( + data["keypoints0"][0], + data["keypoints1"][0], + data["descriptors0"][0].T, + data["descriptors1"][0].T, + data["image0"].shape[2:], + data["image1"].shape[2:], + data["oris0"][0], + data["oris1"][0], + data["scales0"][0], + data["scales1"][0], + ) + matches_new = torch.full( + (data["keypoints0"].size(1),), + -1, + dtype=torch.int64, + device=data["keypoints0"].device, + ) + matches_new[matches[:, 0]] = matches[:, 1] + return { + "matches0": matches_new.unsqueeze(0), + "matching_scores0": torch.zeros(matches_new.size(0)).unsqueeze(0), + } diff --git a/imcui/hloc/matchers/aspanformer.py b/imcui/hloc/matchers/aspanformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0636ff95a60edc992a8ded22590a7ab1baad4210 --- /dev/null +++ b/imcui/hloc/matchers/aspanformer.py @@ -0,0 +1,66 @@ +import sys +from pathlib import Path + +import torch + +from .. import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer +from ASpanFormer.src.config.default import get_cfg_defaults +from ASpanFormer.src.utils.misc import lower_config + +aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer" + + +class ASpanFormer(BaseModel): + default_conf = { + "model_name": "outdoor.ckpt", + "match_threshold": 0.2, + "sinkhorn_iterations": 20, + "max_keypoints": 2048, + "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py", + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + config = get_cfg_defaults() + config.merge_from_file(conf["config_path"]) + _config = lower_config(config) + + # update: match threshold + _config["aspan"]["match_coarse"]["thr"] = conf["match_threshold"] + _config["aspan"]["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"] + + self.net = _ASpanFormer(config=_config["aspan"]) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + state_dict = torch.load(str(model_path), map_location="cpu")["state_dict"] + self.net.load_state_dict(state_dict, strict=False) + logger.info("Loaded Aspanformer model") + + def _forward(self, data): + data_ = { + "image0": data["image0"], + "image1": data["image1"], + } + self.net(data_, online_resize=True) + pred = { + "keypoints0": data_["mkpts0_f"], + "keypoints1": data_["mkpts1_f"], + "mconf": data_["mconf"], + } + scores = data_["mconf"] + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + scores = scores[keep] + 
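+            # apply the same top-k index set to both keypoint arrays and
+            # store the trimmed confidences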
pred["keypoints0"], pred["keypoints1"], pred["mconf"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + scores, + ) + return pred diff --git a/imcui/hloc/matchers/cotr.py b/imcui/hloc/matchers/cotr.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec0234b2917ad0e3da9fbff76da9bcf83a19c04 --- /dev/null +++ b/imcui/hloc/matchers/cotr.py @@ -0,0 +1,77 @@ +import argparse +import sys +from pathlib import Path + +import numpy as np +import torch +from torchvision.transforms import ToPILImage + +from .. import DEVICE, MODEL_REPO_ID + +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party/COTR")) +from COTR.inference.sparse_engine import SparseEngine +from COTR.models import build_model +from COTR.options.options import * # noqa: F403 +from COTR.options.options_utils import * # noqa: F403 +from COTR.utils import utils as utils_cotr + +utils_cotr.fix_randomness(0) +torch.set_grad_enabled(False) + + +class COTR(BaseModel): + default_conf = { + "weights": "out/default", + "match_threshold": 0.2, + "max_keypoints": -1, + "model_name": "checkpoint.pth.tar", + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + parser = argparse.ArgumentParser() + set_COTR_arguments(parser) # noqa: F405 + opt = parser.parse_args() + opt.command = " ".join(sys.argv) + opt.load_weights_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + + layer_2_channels = { + "layer1": 256, + "layer2": 512, + "layer3": 1024, + "layer4": 2048, + } + opt.dim_feedforward = layer_2_channels[opt.layer] + + model = build_model(opt) + model = model.to(DEVICE) + weights = torch.load(opt.load_weights_path, map_location="cpu")[ + "model_state_dict" + ] + utils_cotr.safe_load_weights(model, weights) + self.net = model.eval() + self.to_pil_func = ToPILImage(mode="RGB") + + def _forward(self, data): + img_a = np.array(self.to_pil_func(data["image0"][0].cpu())) + img_b = np.array(self.to_pil_func(data["image1"][0].cpu())) + corrs = SparseEngine( + self.net, 32, mode="tile" + ).cotr_corr_multiscale_with_cycle_consistency( + img_a, + img_b, + np.linspace(0.5, 0.0625, 4), + 1, + max_corrs=self.conf["max_keypoints"], + queries_a=None, + ) + pred = { + "keypoints0": torch.from_numpy(corrs[:, :2]), + "keypoints1": torch.from_numpy(corrs[:, 2:]), + } + return pred diff --git a/imcui/hloc/matchers/dkm.py b/imcui/hloc/matchers/dkm.py new file mode 100644 index 0000000000000000000000000000000000000000..2deca95ca987dbd4d7e1fbb5c65e587222d3dd4c --- /dev/null +++ b/imcui/hloc/matchers/dkm.py @@ -0,0 +1,53 @@ +import sys +from pathlib import Path + +from PIL import Image + +from .. 
import DEVICE, MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from DKM.dkm import DKMv3_outdoor + + +class DKMv3(BaseModel): + default_conf = { + "model_name": "DKMv3_outdoor.pth", + "match_threshold": 0.2, + "max_keypoints": -1, + } + required_inputs = [ + "image0", + "image1", + ] + + def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + + self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=DEVICE) + logger.info("Loading DKMv3 model done") + + def _forward(self, data): + img0 = data["image0"].cpu().numpy().squeeze() * 255 + img1 = data["image1"].cpu().numpy().squeeze() * 255 + img0 = img0.transpose(1, 2, 0) + img1 = img1.transpose(1, 2, 0) + img0 = Image.fromarray(img0.astype("uint8")) + img1 = Image.fromarray(img1.astype("uint8")) + W_A, H_A = img0.size + W_B, H_B = img1.size + + warp, certainty = self.net.match(img0, img1, device=DEVICE) + matches, certainty = self.net.sample( + warp, certainty, num=self.conf["max_keypoints"] + ) + kpts1, kpts2 = self.net.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) + pred = { + "keypoints0": kpts1, + "keypoints1": kpts2, + "mconf": certainty, + } + return pred diff --git a/imcui/hloc/matchers/dual_softmax.py b/imcui/hloc/matchers/dual_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..1cef54473a0483ce2bca3b158c507c9e9480b641 --- /dev/null +++ b/imcui/hloc/matchers/dual_softmax.py @@ -0,0 +1,71 @@ +import numpy as np +import torch + +from ..utils.base_model import BaseModel + + +# borrow from dedode +def dual_softmax_matcher( + desc_A: tuple["B", "C", "N"], # noqa: F821 + desc_B: tuple["B", "C", "M"], # noqa: F821 + threshold=0.1, + inv_temperature=20, + normalize=True, +): + B, C, N = desc_A.shape + if len(desc_A.shape) < 3: + desc_A, desc_B = desc_A[None], desc_B[None] + if normalize: + desc_A = desc_A / desc_A.norm(dim=1, keepdim=True) + desc_B = desc_B / desc_B.norm(dim=1, keepdim=True) + sim = torch.einsum("b c n, b c m -> b n m", desc_A, desc_B) * inv_temperature + P = sim.softmax(dim=-2) * sim.softmax(dim=-1) + mask = torch.nonzero( + (P == P.max(dim=-1, keepdim=True).values) + * (P == P.max(dim=-2, keepdim=True).values) + * (P > threshold) + ) + mask = mask.cpu().numpy() + matches0 = np.ones((B, P.shape[-2]), dtype=int) * (-1) + scores0 = np.zeros((B, P.shape[-2]), dtype=float) + matches0[:, mask[:, 1]] = mask[:, 2] + tmp_P = P.cpu().numpy() + scores0[:, mask[:, 1]] = tmp_P[mask[:, 0], mask[:, 1], mask[:, 2]] + matches0 = torch.from_numpy(matches0).to(P.device) + scores0 = torch.from_numpy(scores0).to(P.device) + return matches0, scores0 + + +class DualSoftMax(BaseModel): + default_conf = { + "match_threshold": 0.2, + "inv_temperature": 20, + } + # shape: B x DIM x M + required_inputs = ["descriptors0", "descriptors1"] + + def _init(self, conf): + pass + + def _forward(self, data): + if data["descriptors0"].size(-1) == 0 or data["descriptors1"].size(-1) == 0: + matches0 = torch.full( + data["descriptors0"].shape[:2], + -1, + device=data["descriptors0"].device, + ) + return { + "matches0": matches0, + "matching_scores0": torch.zeros_like(matches0), + } + + matches0, scores0 = dual_softmax_matcher( + data["descriptors0"], + data["descriptors1"], + threshold=self.conf["match_threshold"], + inv_temperature=self.conf["inv_temperature"], + ) + return { + "matches0": matches0, # 1 x M + "matching_scores0": 
scores0, + } diff --git a/imcui/hloc/matchers/duster.py b/imcui/hloc/matchers/duster.py new file mode 100644 index 0000000000000000000000000000000000000000..36fa34bc6d433295800e0223db9dc97fec93f9f9 --- /dev/null +++ b/imcui/hloc/matchers/duster.py @@ -0,0 +1,109 @@ +import sys +from pathlib import Path + +import numpy as np +import torch +import torchvision.transforms as tfm + +from .. import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +duster_path = Path(__file__).parent / "../../third_party/dust3r" +sys.path.append(str(duster_path)) + +from dust3r.cloud_opt import GlobalAlignerMode, global_aligner +from dust3r.image_pairs import make_pairs +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.utils.geometry import find_reciprocal_matches, xy_grid + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Duster(BaseModel): + default_conf = { + "name": "Duster3r", + "model_name": "duster_vit_large.pth", + "max_keypoints": 3000, + "vit_patch_size": 16, + } + + def _init(self, conf): + self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + self.net = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(device) + logger.info("Loaded Dust3r model") + + def preprocess(self, img): + # the super-class already makes sure that img0,img1 have + # same resolution and that h == w + _, h, _ = img.shape + imsize = h + if not ((h % self.vit_patch_size) == 0): + imsize = int(self.vit_patch_size * round(h / self.vit_patch_size, 0)) + img = tfm.functional.resize(img, imsize, antialias=True) + + _, new_h, new_w = img.shape + if not ((new_w % self.vit_patch_size) == 0): + safe_w = int(self.vit_patch_size * round(new_w / self.vit_patch_size, 0)) + img = tfm.functional.resize(img, (new_h, safe_w), antialias=True) + + img = self.normalize(img).unsqueeze(0) + + return img + + def _forward(self, data): + img0, img1 = data["image0"], data["image1"] + mean = torch.tensor([0.5, 0.5, 0.5]).to(device) + std = torch.tensor([0.5, 0.5, 0.5]).to(device) + + img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) + img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) + + images = [ + {"img": img0, "idx": 0, "instance": 0}, + {"img": img1, "idx": 1, "instance": 1}, + ] + pairs = make_pairs( + images, scene_graph="complete", prefilter=None, symmetrize=True + ) + output = inference(pairs, self.net, device, batch_size=1) + scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PairViewer) + # retrieve useful values from scene: + imgs = scene.imgs + confidence_masks = scene.get_masks() + pts3d = scene.get_pts3d() + pts2d_list, pts3d_list = [], [] + for i in range(2): + conf_i = confidence_masks[i].cpu().numpy() + pts2d_list.append( + xy_grid(*imgs[i].shape[:2][::-1])[conf_i] + ) # imgs[i].shape[:2] = (H, W) + pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i]) + + if len(pts3d_list[1]) == 0: + pred = { + "keypoints0": torch.zeros([0, 2]), + "keypoints1": torch.zeros([0, 2]), + } + logger.warning(f"Matched {0} points") + else: + reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches( + *pts3d_list + ) + logger.info(f"Found {num_matches} matches") + mkpts1 = pts2d_list[1][reciprocal_in_P2] + mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2] + top_k = self.conf["max_keypoints"] + if top_k is not None and len(mkpts0) > top_k: + keep = 
np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int) + mkpts0 = mkpts0[keep] + mkpts1 = mkpts1[keep] + pred = { + "keypoints0": torch.from_numpy(mkpts0), + "keypoints1": torch.from_numpy(mkpts1), + } + return pred diff --git a/imcui/hloc/matchers/eloftr.py b/imcui/hloc/matchers/eloftr.py new file mode 100644 index 0000000000000000000000000000000000000000..7ca352808e7b5a2a8bc7253be2d591c439798491 --- /dev/null +++ b/imcui/hloc/matchers/eloftr.py @@ -0,0 +1,97 @@ +import sys +import warnings +from copy import deepcopy +from pathlib import Path + +import torch + +from .. import MODEL_REPO_ID, logger + +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) + +from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_ +from EfficientLoFTR.src.loftr import ( + full_default_cfg, + opt_default_cfg, + reparameter, +) + + +from ..utils.base_model import BaseModel + + +class ELoFTR(BaseModel): + default_conf = { + "model_name": "eloftr_outdoor.ckpt", + "match_threshold": 0.2, + # "sinkhorn_iterations": 20, + "max_keypoints": -1, + # You can choose model type in ['full', 'opt'] + "model_type": "full", # 'full' for best quality, 'opt' for best efficiency + # You can choose numerical precision in ['fp32', 'mp', 'fp16']. 'fp16' for best efficiency + "precision": "fp32", + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + if self.conf["model_type"] == "full": + _default_cfg = deepcopy(full_default_cfg) + elif self.conf["model_type"] == "opt": + _default_cfg = deepcopy(opt_default_cfg) + + if self.conf["precision"] == "mp": + _default_cfg["mp"] = True + elif self.conf["precision"] == "fp16": + _default_cfg["half"] = True + + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + + cfg = _default_cfg + cfg["match_coarse"]["thr"] = conf["match_threshold"] + # cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"] + state_dict = torch.load(model_path, map_location="cpu")["state_dict"] + matcher = ELoFTR_(config=cfg) + matcher.load_state_dict(state_dict) + self.net = reparameter(matcher) + + if self.conf["precision"] == "fp16": + self.net = self.net.half() + logger.info(f"Loaded Efficient LoFTR with weights {conf['model_name']}") + + def _forward(self, data): + # For consistency with hloc pairs, we refine kpts in image0! + rename = { + "keypoints0": "keypoints1", + "keypoints1": "keypoints0", + "image0": "image1", + "image1": "image0", + "mask0": "mask1", + "mask1": "mask0", + } + data_ = {rename[k]: v for k, v in data.items()} + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pred = self.net(data_) + pred = { + "keypoints0": data_["mkpts0_f"], + "keypoints1": data_["mkpts1_f"], + } + scores = data_["mconf"] + + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + pred["keypoints0"], pred["keypoints1"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + ) + scores = scores[keep] + + # Switch back indices + pred = {(rename[k] if k in rename else k): v for k, v in pred.items()} + pred["scores"] = scores + return pred diff --git a/imcui/hloc/matchers/gim.py b/imcui/hloc/matchers/gim.py new file mode 100644 index 0000000000000000000000000000000000000000..afaf78c2ac832c47d8d0f7210e8188e8a0aa9899 --- /dev/null +++ b/imcui/hloc/matchers/gim.py @@ -0,0 +1,200 @@ +import sys +from pathlib import Path + +import torch + +from .. 
import DEVICE, MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +gim_path = Path(__file__).parents[2] / "third_party/gim" +sys.path.append(str(gim_path)) + + +def load_model(weight_name, checkpoints_path): + # load model + model = None + detector = None + if weight_name == "gim_dkm": + from networks.dkm.models.model_zoo.DKMv3 import DKMv3 + + model = DKMv3(weights=None, h=672, w=896) + elif weight_name == "gim_loftr": + from networks.loftr.config import get_cfg_defaults + from networks.loftr.loftr import LoFTR + from networks.loftr.misc import lower_config + + model = LoFTR(lower_config(get_cfg_defaults())["loftr"]) + elif weight_name == "gim_lightglue": + from networks.lightglue.models.matchers.lightglue import LightGlue + from networks.lightglue.superpoint import SuperPoint + + detector = SuperPoint( + { + "max_num_keypoints": 2048, + "force_num_keypoints": True, + "detection_threshold": 0.0, + "nms_radius": 3, + "trainable": False, + } + ) + model = LightGlue( + { + "filter_threshold": 0.1, + "flash": False, + "checkpointed": True, + } + ) + + # load state dict + if weight_name == "gim_dkm": + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("model."): + state_dict[k.replace("model.", "", 1)] = state_dict.pop(k) + if "encoder.net.fc" in k: + state_dict.pop(k) + model.load_state_dict(state_dict) + + elif weight_name == "gim_loftr": + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict) + + elif weight_name == "gim_lightglue": + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("model."): + state_dict.pop(k) + if k.startswith("superpoint."): + state_dict[k.replace("superpoint.", "", 1)] = state_dict.pop(k) + detector.load_state_dict(state_dict) + + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("superpoint."): + state_dict.pop(k) + if k.startswith("model."): + state_dict[k.replace("model.", "", 1)] = state_dict.pop(k) + model.load_state_dict(state_dict) + + # eval mode + if detector is not None: + detector = detector.eval().to(DEVICE) + model = model.eval().to(DEVICE) + return model + + +class GIM(BaseModel): + default_conf = { + "match_threshold": 0.2, + "checkpoint_dir": gim_path / "weights", + "weights": "gim_dkm", + } + required_inputs = [ + "image0", + "image1", + ] + ckpt_name_dict = { + "gim_dkm": "gim_dkm_100h.ckpt", + "gim_loftr": "gim_loftr_50h.ckpt", + "gim_lightglue": "gim_lightglue_100h.ckpt", + } + + def _init(self, conf): + ckpt_name = self.ckpt_name_dict[conf["weights"]] + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, ckpt_name), + ) + self.aspect_ratio = 896 / 672 + model = load_model(conf["weights"], model_path) + self.net = model + logger.info("Loaded GIM model") + + def pad_image(self, image, aspect_ratio): + new_width = max(image.shape[3], int(image.shape[2] * aspect_ratio)) + new_height = max(image.shape[2], int(image.shape[3] / aspect_ratio)) + pad_width = new_width - image.shape[3] + pad_height = new_height - image.shape[2] + return 
torch.nn.functional.pad( + image, + ( + pad_width // 2, + pad_width - pad_width // 2, + pad_height // 2, + pad_height - pad_height // 2, + ), + ) + + def rescale_kpts(self, sparse_matches, shape0, shape1): + kpts0 = torch.stack( + ( + shape0[1] * (sparse_matches[:, 0] + 1) / 2, + shape0[0] * (sparse_matches[:, 1] + 1) / 2, + ), + dim=-1, + ) + kpts1 = torch.stack( + ( + shape1[1] * (sparse_matches[:, 2] + 1) / 2, + shape1[0] * (sparse_matches[:, 3] + 1) / 2, + ), + dim=-1, + ) + return kpts0, kpts1 + + def compute_mask(self, kpts0, kpts1, orig_shape0, orig_shape1): + mask = ( + (kpts0[:, 0] > 0) + & (kpts0[:, 1] > 0) + & (kpts1[:, 0] > 0) + & (kpts1[:, 1] > 0) + ) + mask &= ( + (kpts0[:, 0] <= (orig_shape0[1] - 1)) + & (kpts1[:, 0] <= (orig_shape1[1] - 1)) + & (kpts0[:, 1] <= (orig_shape0[0] - 1)) + & (kpts1[:, 1] <= (orig_shape1[0] - 1)) + ) + return mask + + def _forward(self, data): + # TODO: only support dkm+gim + image0, image1 = ( + self.pad_image(data["image0"], self.aspect_ratio), + self.pad_image(data["image1"], self.aspect_ratio), + ) + dense_matches, dense_certainty = self.net.match(image0, image1) + sparse_matches, mconf = self.net.sample( + dense_matches, dense_certainty, self.conf["max_keypoints"] + ) + kpts0, kpts1 = self.rescale_kpts( + sparse_matches, image0.shape[-2:], image1.shape[-2:] + ) + mask = self.compute_mask( + kpts0, kpts1, data["image0"].shape[-2:], data["image1"].shape[-2:] + ) + b_ids, i_ids = torch.where(mconf[None]) + pred = { + "keypoints0": kpts0[i_ids], + "keypoints1": kpts1[i_ids], + "confidence": mconf[i_ids], + "batch_indexes": b_ids, + } + scores, b_ids = pred["confidence"], pred["batch_indexes"] + kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"] + pred["confidence"], pred["batch_indexes"] = scores[mask], b_ids[mask] + pred["keypoints0"], pred["keypoints1"] = kpts0[mask], kpts1[mask] + + out = { + "keypoints0": pred["keypoints0"], + "keypoints1": pred["keypoints1"], + } + return out diff --git a/imcui/hloc/matchers/gluestick.py b/imcui/hloc/matchers/gluestick.py new file mode 100644 index 0000000000000000000000000000000000000000..9f775325fde3e39570ab93a7071455a5a2661dda --- /dev/null +++ b/imcui/hloc/matchers/gluestick.py @@ -0,0 +1,99 @@ +import sys +from pathlib import Path + +import torch + +from .. import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +gluestick_path = Path(__file__).parent / "../../third_party/GlueStick" +sys.path.append(str(gluestick_path)) + +from gluestick import batch_to_np +from gluestick.models.two_view_pipeline import TwoViewPipeline + + +class GlueStick(BaseModel): + default_conf = { + "name": "two_view_pipeline", + "model_name": "checkpoint_GlueStick_MD.tar", + "use_lines": True, + "max_keypoints": 1000, + "max_lines": 300, + "force_num_keypoints": False, + } + required_inputs = [ + "image0", + "image1", + ] + + # Initialize the line matcher + def _init(self, conf): + # Download the model. 
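+ # Editor's note: GlueStick matches keypoints and line segments jointly. The
+ # "wireframe" extractor in the pipeline configuration below couples a
+ # SuperPoint keypoint branch (`sp_params`) with a line-segment detector, and
+ # this wrapper's `max_keypoints`, `max_lines` and `force_num_keypoints`
+ # settings are copied into that nested configuration before the network is
+ # built.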
+ model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + logger.info("Loading GlueStick model...") + + gluestick_conf = { + "name": "two_view_pipeline", + "use_lines": True, + "extractor": { + "name": "wireframe", + "sp_params": { + "force_num_keypoints": False, + "max_num_keypoints": 1000, + }, + "wireframe_params": { + "merge_points": True, + "merge_line_endpoints": True, + }, + "max_n_lines": 300, + }, + "matcher": { + "name": "gluestick", + "weights": str(model_path), + "trainable": False, + }, + "ground_truth": { + "from_pose_depth": False, + }, + } + gluestick_conf["extractor"]["sp_params"]["max_num_keypoints"] = conf[ + "max_keypoints" + ] + gluestick_conf["extractor"]["sp_params"]["force_num_keypoints"] = conf[ + "force_num_keypoints" + ] + gluestick_conf["extractor"]["max_n_lines"] = conf["max_lines"] + self.net = TwoViewPipeline(gluestick_conf) + + def _forward(self, data): + pred = self.net(data) + + pred = batch_to_np(pred) + kp0, kp1 = pred["keypoints0"], pred["keypoints1"] + m0 = pred["matches0"] + + line_seg0, line_seg1 = pred["lines0"], pred["lines1"] + line_matches = pred["line_matches0"] + + valid_matches = m0 != -1 + match_indices = m0[valid_matches] + matched_kps0 = kp0[valid_matches] + matched_kps1 = kp1[match_indices] + + valid_matches = line_matches != -1 + match_indices = line_matches[valid_matches] + matched_lines0 = line_seg0[valid_matches] + matched_lines1 = line_seg1[match_indices] + + pred["raw_lines0"], pred["raw_lines1"] = line_seg0, line_seg1 + pred["lines0"], pred["lines1"] = matched_lines0, matched_lines1 + pred["keypoints0"], pred["keypoints1"] = ( + torch.from_numpy(matched_kps0), + torch.from_numpy(matched_kps1), + ) + pred = {**pred, **data} + return pred diff --git a/imcui/hloc/matchers/imp.py b/imcui/hloc/matchers/imp.py new file mode 100644 index 0000000000000000000000000000000000000000..f37d218e7a8e46ea31b834307f48e4e7649f87a0 --- /dev/null +++ b/imcui/hloc/matchers/imp.py @@ -0,0 +1,50 @@ +import sys +from pathlib import Path + +import torch + +from .. 
import DEVICE, MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) +from pram.nets.gml import GML + + +class IMP(BaseModel): + default_conf = { + "match_threshold": 0.2, + "features": "sfd2", + "model_name": "imp_gml.920.pth", + "sinkhorn_iterations": 20, + } + required_inputs = [ + "image0", + "keypoints0", + "scores0", + "descriptors0", + "image1", + "keypoints1", + "scores1", + "descriptors1", + ] + + def _init(self, conf): + self.conf = {**self.default_conf, **conf} + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format("pram", self.conf["model_name"]), + ) + + # self.net = nets.gml(self.conf).eval().to(DEVICE) + self.net = GML(self.conf).eval().to(DEVICE) + self.net.load_state_dict( + torch.load(model_path, map_location="cpu")["model"], strict=True + ) + logger.info("Load IMP model done.") + + def _forward(self, data): + data["descriptors0"] = data["descriptors0"].transpose(2, 1).float() + data["descriptors1"] = data["descriptors1"].transpose(2, 1).float() + + return self.net.produce_matches(data, p=0.2) diff --git a/imcui/hloc/matchers/lightglue.py b/imcui/hloc/matchers/lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..39bd3693813d70545bbfbfc24c4b578e10092759 --- /dev/null +++ b/imcui/hloc/matchers/lightglue.py @@ -0,0 +1,67 @@ +import sys +from pathlib import Path + +from .. import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +lightglue_path = Path(__file__).parent / "../../third_party/LightGlue" +sys.path.append(str(lightglue_path)) +from lightglue import LightGlue as LG + + +class LightGlue(BaseModel): + default_conf = { + "match_threshold": 0.2, + "filter_threshold": 0.2, + "width_confidence": 0.99, # for point pruning + "depth_confidence": 0.95, # for early stopping, + "features": "superpoint", + "model_name": "superpoint_lightglue.pth", + "flash": True, # enable FlashAttention if available. 
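+ # Editor's note: `flash` and `mp` only toggle optional speed-ups
+ # (FlashAttention / mixed precision) rather than the matching logic, and
+ # `_init` below copies `match_threshold` into `filter_threshold`, so callers
+ # of this wrapper normally only need to set `match_threshold`.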
+ "mp": False, # enable mixed precision + "add_scale_ori": False, + } + required_inputs = [ + "image0", + "keypoints0", + "scores0", + "descriptors0", + "image1", + "keypoints1", + "scores1", + "descriptors1", + ] + + def _init(self, conf): + logger.info("Loading lightglue model, {}".format(conf["model_name"])) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + conf["weights"] = str(model_path) + conf["filter_threshold"] = conf["match_threshold"] + self.net = LG(**conf) + logger.info("Load lightglue model done.") + + def _forward(self, data): + input = {} + input["image0"] = { + "image": data["image0"], + "keypoints": data["keypoints0"], + "descriptors": data["descriptors0"].permute(0, 2, 1), + } + if "scales0" in data: + input["image0"] = {**input["image0"], "scales": data["scales0"]} + if "oris0" in data: + input["image0"] = {**input["image0"], "oris": data["oris0"]} + + input["image1"] = { + "image": data["image1"], + "keypoints": data["keypoints1"], + "descriptors": data["descriptors1"].permute(0, 2, 1), + } + if "scales1" in data: + input["image1"] = {**input["image1"], "scales": data["scales1"]} + if "oris1" in data: + input["image1"] = {**input["image1"], "oris": data["oris1"]} + return self.net(input) diff --git a/imcui/hloc/matchers/loftr.py b/imcui/hloc/matchers/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7fe28d87337905069c675452bf0bf2522068c7 --- /dev/null +++ b/imcui/hloc/matchers/loftr.py @@ -0,0 +1,71 @@ +import warnings + +import torch +from kornia.feature import LoFTR as LoFTR_ +from kornia.feature.loftr.loftr import default_cfg +from pathlib import Path +from .. import logger, MODEL_REPO_ID + +from ..utils.base_model import BaseModel + + +class LoFTR(BaseModel): + default_conf = { + "weights": "outdoor", + "match_threshold": 0.2, + "sinkhorn_iterations": 20, + "max_keypoints": -1, + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + cfg = default_cfg + cfg["match_coarse"]["thr"] = conf["match_threshold"] + cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"] + + model_name = conf.get("model_name", None) + if model_name is not None and "minima" in model_name: + cfg["coarse"]["temp_bug_fix"] = True + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + state_dict = torch.load(model_path, map_location="cpu")["state_dict"] + self.net = LoFTR_(pretrained=conf["weights"], config=cfg) + self.net.load_state_dict(state_dict) + logger.info(f"ReLoaded LoFTR(minima) with weights {conf['model_name']}") + else: + self.net = LoFTR_(pretrained=conf["weights"], config=cfg) + logger.info(f"Loaded LoFTR with weights {conf['weights']}") + + def _forward(self, data): + # For consistency with hloc pairs, we refine kpts in image0! 
+ rename = { + "keypoints0": "keypoints1", + "keypoints1": "keypoints0", + "image0": "image1", + "image1": "image0", + "mask0": "mask1", + "mask1": "mask0", + } + data_ = {rename[k]: v for k, v in data.items()} + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pred = self.net(data_) + + scores = pred["confidence"] + + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + pred["keypoints0"], pred["keypoints1"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + ) + scores = scores[keep] + + # Switch back indices + pred = {(rename[k] if k in rename else k): v for k, v in pred.items()} + pred["scores"] = scores + del pred["confidence"] + return pred diff --git a/imcui/hloc/matchers/mast3r.py b/imcui/hloc/matchers/mast3r.py new file mode 100644 index 0000000000000000000000000000000000000000..47a5a3ffdd6855332f61012ef97037a9f6fe469e --- /dev/null +++ b/imcui/hloc/matchers/mast3r.py @@ -0,0 +1,96 @@ +import sys +from pathlib import Path + +import numpy as np +import torch +import torchvision.transforms as tfm + +from .. import DEVICE, MODEL_REPO_ID, logger + +mast3r_path = Path(__file__).parent / "../../third_party/mast3r" +sys.path.append(str(mast3r_path)) + +dust3r_path = Path(__file__).parent / "../../third_party/dust3r" +sys.path.append(str(dust3r_path)) + +from dust3r.image_pairs import make_pairs +from dust3r.inference import inference +from mast3r.fast_nn import fast_reciprocal_NNs +from mast3r.model import AsymmetricMASt3R + +from .duster import Duster + + +class Mast3r(Duster): + default_conf = { + "name": "Mast3r", + "model_name": "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth", + "max_keypoints": 2000, + "vit_patch_size": 16, + } + + def _init(self, conf): + self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + self.net = AsymmetricMASt3R.from_pretrained(model_path).to(DEVICE) + logger.info("Loaded Mast3r model") + + def _forward(self, data): + img0, img1 = data["image0"], data["image1"] + mean = torch.tensor([0.5, 0.5, 0.5]).to(DEVICE) + std = torch.tensor([0.5, 0.5, 0.5]).to(DEVICE) + + img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) + img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) + + images = [ + {"img": img0, "idx": 0, "instance": 0}, + {"img": img1, "idx": 1, "instance": 1}, + ] + pairs = make_pairs( + images, scene_graph="complete", prefilter=None, symmetrize=True + ) + output = inference(pairs, self.net, DEVICE, batch_size=1) + + # at this stage, you have the raw dust3r predictions + _, pred1 = output["view1"], output["pred1"] + _, pred2 = output["view2"], output["pred2"] + + desc1, desc2 = ( + pred1["desc"][1].squeeze(0).detach(), + pred2["desc"][1].squeeze(0).detach(), + ) + + # find 2D-2D matches between the two images + matches_im0, matches_im1 = fast_reciprocal_NNs( + desc1, + desc2, + subsample_or_initxy1=2, + device=DEVICE, + dist="dot", + block_size=2**13, + ) + + mkpts0 = matches_im0.copy() + mkpts1 = matches_im1.copy() + + if len(mkpts0) == 0: + pred = { + "keypoints0": torch.zeros([0, 2]), + "keypoints1": torch.zeros([0, 2]), + } + logger.warning(f"Matched {0} points") + else: + top_k = self.conf["max_keypoints"] + if top_k is not None and len(mkpts0) > top_k: + keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int) + mkpts0 = mkpts0[keep] + mkpts1 = 
mkpts1[keep] + pred = { + "keypoints0": torch.from_numpy(mkpts0), + "keypoints1": torch.from_numpy(mkpts1), + } + return pred diff --git a/imcui/hloc/matchers/mickey.py b/imcui/hloc/matchers/mickey.py new file mode 100644 index 0000000000000000000000000000000000000000..d18e908ee64b01ab394b8533b1e5257791424f4e --- /dev/null +++ b/imcui/hloc/matchers/mickey.py @@ -0,0 +1,50 @@ +import sys +from pathlib import Path + +from .. import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +mickey_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(mickey_path)) + +from mickey.config.default import cfg +from mickey.lib.models.builder import build_model + + +class Mickey(BaseModel): + default_conf = { + "config_path": "config.yaml", + "model_name": "mickey.ckpt", + "max_keypoints": 3000, + } + required_inputs = [ + "image0", + "image1", + ] + + # Initialize the line matcher + def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + # TODO: config path of mickey + config_path = model_path.parent / self.conf["config_path"] + logger.info("Loading mickey model...") + cfg.merge_from_file(config_path) + self.net = build_model(cfg, checkpoint=model_path) + logger.info("Load Mickey model done.") + + def _forward(self, data): + pred = self.net(data) + pred = { + **pred, + **data, + } + inliers = data["inliers_list"] + pred = { + "keypoints0": inliers[:, :2], + "keypoints1": inliers[:, 2:4], + } + + return pred diff --git a/imcui/hloc/matchers/nearest_neighbor.py b/imcui/hloc/matchers/nearest_neighbor.py new file mode 100644 index 0000000000000000000000000000000000000000..fab96e780a5c1a1672cdaf5b624ecdb310db23d3 --- /dev/null +++ b/imcui/hloc/matchers/nearest_neighbor.py @@ -0,0 +1,66 @@ +import torch + +from ..utils.base_model import BaseModel + + +def find_nn(sim, ratio_thresh, distance_thresh): + sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True) + dist_nn = 2 * (1 - sim_nn) + mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device) + if ratio_thresh: + mask = mask & (dist_nn[..., 0] <= (ratio_thresh**2) * dist_nn[..., 1]) + if distance_thresh: + mask = mask & (dist_nn[..., 0] <= distance_thresh**2) + matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1)) + scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0)) + return matches, scores + + +def mutual_check(m0, m1): + inds0 = torch.arange(m0.shape[-1], device=m0.device) + loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0))) + ok = (m0 > -1) & (inds0 == loop) + m0_new = torch.where(ok, m0, m0.new_tensor(-1)) + return m0_new + + +class NearestNeighbor(BaseModel): + default_conf = { + "ratio_threshold": None, + "distance_threshold": None, + "do_mutual_check": True, + } + required_inputs = ["descriptors0", "descriptors1"] + + def _init(self, conf): + pass + + def _forward(self, data): + if data["descriptors0"].size(-1) == 0 or data["descriptors1"].size(-1) == 0: + matches0 = torch.full( + data["descriptors0"].shape[:2], + -1, + device=data["descriptors0"].device, + ) + return { + "matches0": matches0, + "matching_scores0": torch.zeros_like(matches0), + } + ratio_threshold = self.conf["ratio_threshold"] + if data["descriptors0"].size(-1) == 1 or data["descriptors1"].size(-1) == 1: + ratio_threshold = None + sim = torch.einsum("bdn,bdm->bnm", data["descriptors0"], data["descriptors1"]) + matches0, scores0 = find_nn( + sim, 
ratio_threshold, self.conf["distance_threshold"] + ) + if self.conf["do_mutual_check"]: + matches1, scores1 = find_nn( + sim.transpose(1, 2), + ratio_threshold, + self.conf["distance_threshold"], + ) + matches0 = mutual_check(matches0, matches1) + return { + "matches0": matches0, + "matching_scores0": scores0, + } diff --git a/imcui/hloc/matchers/omniglue.py b/imcui/hloc/matchers/omniglue.py new file mode 100644 index 0000000000000000000000000000000000000000..07539535ff61ca9a3bdc075926995d2319a70fee --- /dev/null +++ b/imcui/hloc/matchers/omniglue.py @@ -0,0 +1,80 @@ +import sys +from pathlib import Path + +import numpy as np +import torch + +from .. import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +thirdparty_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(thirdparty_path)) +from omniglue.src import omniglue + +omniglue_path = thirdparty_path / "omniglue" + + +class OmniGlue(BaseModel): + default_conf = { + "match_threshold": 0.02, + "max_keypoints": 2048, + } + required_inputs = ["image0", "image1"] + dino_v2_link_dict = { + "dinov2_vitb14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" + } + + def _init(self, conf): + logger.info("Loading OmniGlue model") + og_model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, "omniglue.onnx"), + ) + sp_model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, "sp_v6.onnx"), + ) + dino_model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, "dinov2_vitb14_pretrain.pth"), + ) + + self.net = omniglue.OmniGlue( + og_export=str(og_model_path), + sp_export=str(sp_model_path), + dino_export=str(dino_model_path), + max_keypoints=self.conf["max_keypoints"], + ) + logger.info("Loaded OmniGlue model done!") + + def _forward(self, data): + image0_rgb_np = data["image0"][0].permute(1, 2, 0).cpu().numpy() * 255 + image1_rgb_np = data["image1"][0].permute(1, 2, 0).cpu().numpy() * 255 + image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255 + image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255 + match_kp0, match_kp1, match_confidences = self.net.FindMatches( + image0_rgb_np, image1_rgb_np, self.conf["max_keypoints"] + ) + # filter matches + match_threshold = self.conf["match_threshold"] + keep_idx = [] + for i in range(match_kp0.shape[0]): + if match_confidences[i] > match_threshold: + keep_idx.append(i) + scores = torch.from_numpy(match_confidences[keep_idx]).reshape(-1, 1) + pred = { + "keypoints0": torch.from_numpy(match_kp0[keep_idx]), + "keypoints1": torch.from_numpy(match_kp1[keep_idx]), + "mconf": scores, + } + + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + scores = scores[keep] + pred["keypoints0"], pred["keypoints1"], pred["mconf"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + scores, + ) + return pred diff --git a/imcui/hloc/matchers/roma.py b/imcui/hloc/matchers/roma.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb83e1c22a8d22d85fa39a989ed4a8897989354 --- /dev/null +++ b/imcui/hloc/matchers/roma.py @@ -0,0 +1,80 @@ +import sys +from pathlib import Path + +import torch +from PIL import Image + +from .. 
import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +roma_path = Path(__file__).parent / "../../third_party/RoMa" +sys.path.append(str(roma_path)) +from romatch.models.model_zoo import roma_model + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Roma(BaseModel): + default_conf = { + "name": "two_view_pipeline", + "model_name": "roma_outdoor.pth", + "model_utils_name": "dinov2_vitl14_pretrain.pth", + "max_keypoints": 3000, + } + required_inputs = [ + "image0", + "image1", + ] + + # Initialize the line matcher + def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + + dinov2_weights = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_utils_name"]), + ) + + logger.info("Loading Roma model") + # load the model + weights = torch.load(model_path, map_location="cpu") + dinov2_weights = torch.load(dinov2_weights, map_location="cpu") + + self.net = roma_model( + resolution=(14 * 8 * 6, 14 * 8 * 6), + upsample_preds=False, + weights=weights, + dinov2_weights=dinov2_weights, + device=device, + # temp fix issue: https://github.com/Parskatt/RoMa/issues/26 + amp_dtype=torch.float32, + ) + logger.info("Load Roma model done.") + + def _forward(self, data): + img0 = data["image0"].cpu().numpy().squeeze() * 255 + img1 = data["image1"].cpu().numpy().squeeze() * 255 + img0 = img0.transpose(1, 2, 0) + img1 = img1.transpose(1, 2, 0) + img0 = Image.fromarray(img0.astype("uint8")) + img1 = Image.fromarray(img1.astype("uint8")) + W_A, H_A = img0.size + W_B, H_B = img1.size + + # Match + warp, certainty = self.net.match(img0, img1, device=device) + # Sample matches for estimation + matches, certainty = self.net.sample( + warp, certainty, num=self.conf["max_keypoints"] + ) + kpts1, kpts2 = self.net.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) + pred = { + "keypoints0": kpts1, + "keypoints1": kpts2, + "mconf": certainty, + } + + return pred diff --git a/imcui/hloc/matchers/sgmnet.py b/imcui/hloc/matchers/sgmnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b1141e4be0b74a5dc74f4cf1b5189ef4893a8cef --- /dev/null +++ b/imcui/hloc/matchers/sgmnet.py @@ -0,0 +1,106 @@ +import sys +from collections import OrderedDict, namedtuple +from pathlib import Path + +import torch + +from .. 
import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet" +sys.path.append(str(sgmnet_path)) + +from sgmnet import matcher as SGM_Model + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class SGMNet(BaseModel): + default_conf = { + "name": "SGM", + "model_name": "weights/sgm/root/model_best.pth", + "seed_top_k": [256, 256], + "seed_radius_coe": 0.01, + "net_channels": 128, + "layer_num": 9, + "head": 4, + "seedlayer": [0, 6], + "use_mc_seeding": True, + "use_score_encoding": False, + "conf_bar": [1.11, 0.1], + "sink_iter": [10, 100], + "detach_iter": 1000000, + "match_threshold": 0.2, + } + required_inputs = [ + "image0", + "image1", + ] + + # Initialize the line matcher + def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + + # config + config = namedtuple("config", conf.keys())(*conf.values()) + self.net = SGM_Model(config) + checkpoint = torch.load(model_path, map_location="cpu") + # for ddp model + if list(checkpoint["state_dict"].items())[0][0].split(".")[0] == "module": + new_stat_dict = OrderedDict() + for key, value in checkpoint["state_dict"].items(): + new_stat_dict[key[7:]] = value + checkpoint["state_dict"] = new_stat_dict + self.net.load_state_dict(checkpoint["state_dict"]) + logger.info("Load SGMNet model done.") + + def _forward(self, data): + x1 = data["keypoints0"].squeeze() # N x 2 + x2 = data["keypoints1"].squeeze() + score1 = data["scores0"].reshape(-1, 1) # N x 1 + score2 = data["scores1"].reshape(-1, 1) + desc1 = data["descriptors0"].permute(0, 2, 1) # 1 x N x 128 + desc2 = data["descriptors1"].permute(0, 2, 1) + size1 = ( + torch.tensor(data["image0"].shape[2:]).flip(0).to(x1.device) + ) # W x H -> x & y + size2 = torch.tensor(data["image1"].shape[2:]).flip(0).to(x2.device) # W x H + norm_x1 = self.normalize_size(x1, size1) + norm_x2 = self.normalize_size(x2, size2) + + x1 = torch.cat((norm_x1, score1), dim=-1) # N x 3 + x2 = torch.cat((norm_x2, score2), dim=-1) + input = {"x1": x1[None], "x2": x2[None], "desc1": desc1, "desc2": desc2} + input = { + k: v.to(device).float() if isinstance(v, torch.Tensor) else v + for k, v in input.items() + } + pred = self.net(input, test_mode=True) + + p = pred["p"] # shape: N * M + indices0 = self.match_p(p[0, :-1, :-1]) + pred = { + "matches0": indices0.unsqueeze(0), + "matching_scores0": torch.zeros(indices0.size(0)).unsqueeze(0), + } + return pred + + def match_p(self, p): + score, index = torch.topk(p, k=1, dim=-1) + _, index2 = torch.topk(p, k=1, dim=-2) + mask_th, index, index2 = ( + score[:, 0] > self.conf["match_threshold"], + index[:, 0], + index2.squeeze(0), + ) + mask_mc = index2[index] == torch.arange(len(p)).to(device) + mask = mask_th & mask_mc + indices0 = torch.where(mask, index, index.new_tensor(-1)) + return indices0 + + def normalize_size(self, x, size, scale=1): + norm_fac = size.max() + return (x - size / 2 + 0.5) / (norm_fac * scale) diff --git a/imcui/hloc/matchers/sold2.py b/imcui/hloc/matchers/sold2.py new file mode 100644 index 0000000000000000000000000000000000000000..daed4f029f4fcd23771ffe4a848ed12bc0b81478 --- /dev/null +++ b/imcui/hloc/matchers/sold2.py @@ -0,0 +1,144 @@ +import sys +from pathlib import Path + +import torch + +from .. 
import MODEL_REPO_ID, logger +from ..utils.base_model import BaseModel + +sold2_path = Path(__file__).parent / "../../third_party/SOLD2" +sys.path.append(str(sold2_path)) + +from sold2.model.line_matcher import LineMatcher + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class SOLD2(BaseModel): + default_conf = { + "model_name": "sold2_wireframe.tar", + "match_threshold": 0.2, + "checkpoint_dir": sold2_path / "pretrained", + "detect_thresh": 0.25, + "multiscale": False, + "valid_thresh": 1e-3, + "num_blocks": 20, + "overlap_ratio": 0.5, + } + required_inputs = [ + "image0", + "image1", + ] + + # Initialize the line matcher + def _init(self, conf): + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + logger.info("Loading SOLD2 model: {}".format(model_path)) + + mode = "dynamic" # 'dynamic' or 'static' + match_config = { + "model_cfg": { + "model_name": "lcnn_simple", + "model_architecture": "simple", + # Backbone related config + "backbone": "lcnn", + "backbone_cfg": { + "input_channel": 1, # Use RGB images or grayscale images. + "depth": 4, + "num_stacks": 2, + "num_blocks": 1, + "num_classes": 5, + }, + # Junction decoder related config + "junction_decoder": "superpoint_decoder", + "junc_decoder_cfg": {}, + # Heatmap decoder related config + "heatmap_decoder": "pixel_shuffle", + "heatmap_decoder_cfg": {}, + # Descriptor decoder related config + "descriptor_decoder": "superpoint_descriptor", + "descriptor_decoder_cfg": {}, + # Shared configurations + "grid_size": 8, + "keep_border_valid": True, + # Threshold of junction detection + "detection_thresh": 0.0153846, # 1/65 + "max_num_junctions": 300, + # Threshold of heatmap detection + "prob_thresh": 0.5, + # Weighting related parameters + "weighting_policy": mode, + # [Heatmap loss] + "w_heatmap": 0.0, + "w_heatmap_class": 1, + "heatmap_loss_func": "cross_entropy", + "heatmap_loss_cfg": {"policy": mode}, + # [Heatmap consistency loss] + # [Junction loss] + "w_junc": 0.0, + "junction_loss_func": "superpoint", + "junction_loss_cfg": {"policy": mode}, + # [Descriptor loss] + "w_desc": 0.0, + "descriptor_loss_func": "regular_sampling", + "descriptor_loss_cfg": { + "dist_threshold": 8, + "grid_size": 4, + "margin": 1, + "policy": mode, + }, + }, + "line_detector_cfg": { + "detect_thresh": 0.25, # depending on your images, you might need to tune this parameter + "num_samples": 64, + "sampling_method": "local_max", + "inlier_thresh": 0.9, + "use_candidate_suppression": True, + "nms_dist_tolerance": 3.0, + "use_heatmap_refinement": True, + "heatmap_refine_cfg": { + "mode": "local", + "ratio": 0.2, + "valid_thresh": 1e-3, + "num_blocks": 20, + "overlap_ratio": 0.5, + }, + }, + "multiscale": False, + "line_matcher_cfg": { + "cross_check": True, + "num_samples": 5, + "min_dist_pts": 8, + "top_k_candidates": 10, + "grid_size": 4, + }, + } + self.net = LineMatcher( + match_config["model_cfg"], + model_path, + device, + match_config["line_detector_cfg"], + match_config["line_matcher_cfg"], + match_config["multiscale"], + ) + + def _forward(self, data): + img0 = data["image0"] + img1 = data["image1"] + pred = self.net([img0, img1]) + line_seg1 = pred["line_segments"][0] + line_seg2 = pred["line_segments"][1] + matches = pred["matches"] + + valid_matches = matches != -1 + match_indices = matches[valid_matches] + matched_lines1 = line_seg1[valid_matches][:, :, ::-1] + matched_lines2 = line_seg2[match_indices][:, :, ::-1] + + 
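+ # Editor's note: `matches[i]` is the index of the line in image1 matched to
+ # line i in image0, or -1 if unmatched. The trailing `[:, :, ::-1]` reverses
+ # the last axis of the (num_lines, 2, 2) endpoint arrays, which appears to
+ # convert SOLD2's (row, col) endpoint convention into the (x, y) order used
+ # elsewhere in this code base.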
pred["raw_lines0"], pred["raw_lines1"] = line_seg1, line_seg2 + pred["lines0"], pred["lines1"] = matched_lines1, matched_lines2 + pred = {**pred, **data} + return pred diff --git a/imcui/hloc/matchers/superglue.py b/imcui/hloc/matchers/superglue.py new file mode 100644 index 0000000000000000000000000000000000000000..6fae344e8dfe0fd46f090c6915036e5f9c09a635 --- /dev/null +++ b/imcui/hloc/matchers/superglue.py @@ -0,0 +1,33 @@ +import sys +from pathlib import Path + +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from SuperGluePretrainedNetwork.models.superglue import ( # noqa: E402 + SuperGlue as SG, +) + + +class SuperGlue(BaseModel): + default_conf = { + "weights": "outdoor", + "sinkhorn_iterations": 100, + "match_threshold": 0.2, + } + required_inputs = [ + "image0", + "keypoints0", + "scores0", + "descriptors0", + "image1", + "keypoints1", + "scores1", + "descriptors1", + ] + + def _init(self, conf): + self.net = SG(conf) + + def _forward(self, data): + return self.net(data) diff --git a/imcui/hloc/matchers/topicfm.py b/imcui/hloc/matchers/topicfm.py new file mode 100644 index 0000000000000000000000000000000000000000..5c99adc740e82cdcd644e18fff450a4efeaaf9bc --- /dev/null +++ b/imcui/hloc/matchers/topicfm.py @@ -0,0 +1,60 @@ +import sys +from pathlib import Path + +import torch + +from .. import MODEL_REPO_ID + +from ..utils.base_model import BaseModel + +sys.path.append(str(Path(__file__).parent / "../../third_party")) +from TopicFM.src import get_model_cfg +from TopicFM.src.models.topic_fm import TopicFM as _TopicFM + +topicfm_path = Path(__file__).parent / "../../third_party/TopicFM" + + +class TopicFM(BaseModel): + default_conf = { + "weights": "outdoor", + "model_name": "model_best.ckpt", + "match_threshold": 0.2, + "n_sampling_topics": 4, + "max_keypoints": -1, + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + _conf = dict(get_model_cfg()) + _conf["match_coarse"]["thr"] = conf["match_threshold"] + _conf["coarse"]["n_samples"] = conf["n_sampling_topics"] + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + self.net = _TopicFM(config=_conf) + ckpt_dict = torch.load(model_path, map_location="cpu") + self.net.load_state_dict(ckpt_dict["state_dict"]) + + def _forward(self, data): + data_ = { + "image0": data["image0"], + "image1": data["image1"], + } + self.net(data_) + pred = { + "keypoints0": data_["mkpts0_f"], + "keypoints1": data_["mkpts1_f"], + "mconf": data_["mconf"], + } + scores = data_["mconf"] + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + scores = scores[keep] + pred["keypoints0"], pred["keypoints1"], pred["mconf"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + scores, + ) + return pred diff --git a/imcui/hloc/matchers/xfeat_dense.py b/imcui/hloc/matchers/xfeat_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..7b6956636c126d1d482c1dc3ff18e23978a9163b --- /dev/null +++ b/imcui/hloc/matchers/xfeat_dense.py @@ -0,0 +1,54 @@ +import torch + +from .. 
import logger + +from ..utils.base_model import BaseModel + + +class XFeatDense(BaseModel): + default_conf = { + "keypoint_threshold": 0.005, + "max_keypoints": 8000, + } + required_inputs = [ + "image0", + "image1", + ] + + def _init(self, conf): + self.net = torch.hub.load( + "verlab/accelerated_features", + "XFeat", + pretrained=True, + top_k=self.conf["max_keypoints"], + ) + logger.info("Load XFeat(dense) model done.") + + def _forward(self, data): + # Compute coarse feats + out0 = self.net.detectAndComputeDense( + data["image0"], top_k=self.conf["max_keypoints"] + ) + out1 = self.net.detectAndComputeDense( + data["image1"], top_k=self.conf["max_keypoints"] + ) + + # Match batches of pairs + idxs_list = self.net.batch_match(out0["descriptors"], out1["descriptors"]) + B = len(data["image0"]) + + # Refine coarse matches + # this part is harder to batch, currently iterate + matches = [] + for b in range(B): + matches.append( + self.net.refine_matches(out0, out1, matches=idxs_list, batch_idx=b) + ) + # we use results from one batch + matches = matches[0] + pred = { + "keypoints0": matches[:, :2], + "keypoints1": matches[:, 2:], + "mconf": torch.ones_like(matches[:, 0]), + } + return pred diff --git a/imcui/hloc/matchers/xfeat_lightglue.py b/imcui/hloc/matchers/xfeat_lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..6041d29e1473cc2b0cebc3ca64f2b01df2753f06 --- /dev/null +++ b/imcui/hloc/matchers/xfeat_lightglue.py @@ -0,0 +1,48 @@ +import torch + +from .. import logger + +from ..utils.base_model import BaseModel + + +class XFeatLightGlue(BaseModel): + default_conf = { + "keypoint_threshold": 0.005, + "max_keypoints": 8000, + } + required_inputs = [ + "image0", + "image1", + ] + + def _init(self, conf): + self.net = torch.hub.load( + "verlab/accelerated_features", + "XFeat", + pretrained=True, + top_k=self.conf["max_keypoints"], + ) + logger.info("Load XFeat(dense) model done.") + + def _forward(self, data): + # we use results from one batch + im0 = data["image0"] + im1 = data["image1"] + # Compute coarse feats + out0 = self.net.detectAndCompute(im0, top_k=self.conf["max_keypoints"])[0] + out1 = self.net.detectAndCompute(im1, top_k=self.conf["max_keypoints"])[0] + out0.update({"image_size": (im0.shape[-1], im0.shape[-2])}) # W H + out1.update({"image_size": (im1.shape[-1], im1.shape[-2])}) # W H + pred = self.net.match_lighterglue(out0, out1) + if len(pred) == 3: + mkpts_0, mkpts_1, _ = pred + else: + mkpts_0, mkpts_1 = pred + mkpts_0 = torch.from_numpy(mkpts_0) # n x 2 + mkpts_1 = torch.from_numpy(mkpts_1) # n x 2 + pred = { + "keypoints0": mkpts_0, + "keypoints1": mkpts_1, + "mconf": torch.ones_like(mkpts_0[:, 0]), + } + return pred diff --git a/imcui/hloc/matchers/xoftr.py b/imcui/hloc/matchers/xoftr.py new file mode 100644 index 0000000000000000000000000000000000000000..135f67f811468a13fb172bf06115aafb3084ccfb --- /dev/null +++ b/imcui/hloc/matchers/xoftr.py @@ -0,0 +1,90 @@ +import sys +import warnings +from pathlib import Path + +import torch + +from .. 
import DEVICE, MODEL_REPO_ID, logger + +tp_path = Path(__file__).parent / "../../third_party" +sys.path.append(str(tp_path)) + +from XoFTR.src.config.default import get_cfg_defaults +from XoFTR.src.utils.misc import lower_config +from XoFTR.src.xoftr import XoFTR as XoFTR_ + + +from ..utils.base_model import BaseModel + + +class XoFTR(BaseModel): + default_conf = { + "model_name": "weights_xoftr_640.ckpt", + "match_threshold": 0.3, + "max_keypoints": -1, + } + required_inputs = ["image0", "image1"] + + def _init(self, conf): + # Get default configurations + config_ = get_cfg_defaults(inference=True) + config_ = lower_config(config_) + + # Coarse level threshold + config_["xoftr"]["match_coarse"]["thr"] = self.conf["match_threshold"] + + # Fine level threshold + config_["xoftr"]["fine"]["thr"] = 0.1 # Default 0.1 + + # It is posseble to get denser matches + # If True, xoftr returns all fine-level matches for each fine-level window (at 1/2 resolution) + config_["xoftr"]["fine"]["denser"] = False # Default False + + # XoFTR model + matcher = XoFTR_(config=config_["xoftr"]) + + model_path = self._download_model( + repo_id=MODEL_REPO_ID, + filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]), + ) + + # Load model + state_dict = torch.load(model_path, map_location="cpu")["state_dict"] + matcher.load_state_dict(state_dict, strict=True) + matcher = matcher.eval().to(DEVICE) + self.net = matcher + logger.info(f"Loaded XoFTR with weights {conf['model_name']}") + + def _forward(self, data): + # For consistency with hloc pairs, we refine kpts in image0! + rename = { + "keypoints0": "keypoints1", + "keypoints1": "keypoints0", + "image0": "image1", + "image1": "image0", + "mask0": "mask1", + "mask1": "mask0", + } + data_ = {rename[k]: v for k, v in data.items()} + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pred = self.net(data_) + pred = { + "keypoints0": data_["mkpts0_f"], + "keypoints1": data_["mkpts1_f"], + } + scores = data_["mconf_f"] + + top_k = self.conf["max_keypoints"] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, descending=True)[:top_k] + pred["keypoints0"], pred["keypoints1"] = ( + pred["keypoints0"][keep], + pred["keypoints1"][keep], + ) + scores = scores[keep] + + # Switch back indices + pred = {(rename[k] if k in rename else k): v for k, v in pred.items()} + pred["scores"] = scores + return pred diff --git a/imcui/hloc/pairs_from_covisibility.py b/imcui/hloc/pairs_from_covisibility.py new file mode 100644 index 0000000000000000000000000000000000000000..49f3e57f2bd1aec20e12ecca6df8f94a68b7fd4e --- /dev/null +++ b/imcui/hloc/pairs_from_covisibility.py @@ -0,0 +1,60 @@ +import argparse +from collections import defaultdict +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +from . 
import logger +from .utils.read_write_model import read_model + + +def main(model, output, num_matched): + logger.info("Reading the COLMAP model...") + cameras, images, points3D = read_model(model) + + logger.info("Extracting image pairs from covisibility info...") + pairs = [] + for image_id, image in tqdm(images.items()): + matched = image.point3D_ids != -1 + points3D_covis = image.point3D_ids[matched] + + covis = defaultdict(int) + for point_id in points3D_covis: + for image_covis_id in points3D[point_id].image_ids: + if image_covis_id != image_id: + covis[image_covis_id] += 1 + + if len(covis) == 0: + logger.info(f"Image {image_id} does not have any covisibility.") + continue + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + + if len(covis_ids) <= num_matched: + top_covis_ids = covis_ids[np.argsort(-covis_num)] + else: + # get covisible image ids with top k number of common matches + ind_top = np.argpartition(covis_num, -num_matched) + ind_top = ind_top[-num_matched:] # unsorted top k + ind_top = ind_top[np.argsort(-covis_num[ind_top])] + top_covis_ids = [covis_ids[i] for i in ind_top] + assert covis_num[ind_top[0]] == np.max(covis_num) + + for i in top_covis_ids: + pair = (image.name, images[i].name) + pairs.append(pair) + + logger.info(f"Found {len(pairs)} pairs.") + with open(output, "w") as f: + f.write("\n".join(" ".join([i, j]) for i, j in pairs)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, type=Path) + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--num_matched", required=True, type=int) + args = parser.parse_args() + main(**args.__dict__) diff --git a/imcui/hloc/pairs_from_exhaustive.py b/imcui/hloc/pairs_from_exhaustive.py new file mode 100644 index 0000000000000000000000000000000000000000..0d54ed1dcdbb16d490fcadf9ac2577fd064c3828 --- /dev/null +++ b/imcui/hloc/pairs_from_exhaustive.py @@ -0,0 +1,64 @@ +import argparse +import collections.abc as collections +from pathlib import Path +from typing import List, Optional, Union + +from . 
import logger +from .utils.io import list_h5_names +from .utils.parsers import parse_image_lists + + +def main( + output: Path, + image_list: Optional[Union[Path, List[str]]] = None, + features: Optional[Path] = None, + ref_list: Optional[Union[Path, List[str]]] = None, + ref_features: Optional[Path] = None, +): + if image_list is not None: + if isinstance(image_list, (str, Path)): + names_q = parse_image_lists(image_list) + elif isinstance(image_list, collections.Iterable): + names_q = list(image_list) + else: + raise ValueError(f"Unknown type for image list: {image_list}") + elif features is not None: + names_q = list_h5_names(features) + else: + raise ValueError("Provide either a list of images or a feature file.") + + self_matching = False + if ref_list is not None: + if isinstance(ref_list, (str, Path)): + names_ref = parse_image_lists(ref_list) + elif isinstance(ref_list, collections.Iterable): + names_ref = list(ref_list) + else: + raise ValueError(f"Unknown type for reference image list: {ref_list}") + elif ref_features is not None: + names_ref = list_h5_names(ref_features) + else: + self_matching = True + names_ref = names_q + + pairs = [] + for i, n1 in enumerate(names_q): + for j, n2 in enumerate(names_ref): + if self_matching and j <= i: + continue + pairs.append((n1, n2)) + + logger.info(f"Found {len(pairs)} pairs.") + with open(output, "w") as f: + f.write("\n".join(" ".join([i, j]) for i, j in pairs)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--image_list", type=Path) + parser.add_argument("--features", type=Path) + parser.add_argument("--ref_list", type=Path) + parser.add_argument("--ref_features", type=Path) + args = parser.parse_args() + main(**args.__dict__) diff --git a/imcui/hloc/pairs_from_poses.py b/imcui/hloc/pairs_from_poses.py new file mode 100644 index 0000000000000000000000000000000000000000..83ee1b8cce2b680fac9a4de35d68c5f234092361 --- /dev/null +++ b/imcui/hloc/pairs_from_poses.py @@ -0,0 +1,68 @@ +import argparse +from pathlib import Path + +import numpy as np +import scipy.spatial + +from . import logger +from .pairs_from_retrieval import pairs_from_score_matrix +from .utils.read_write_model import read_images_binary + +DEFAULT_ROT_THRESH = 30 # in degrees + + +def get_pairwise_distances(images): + ids = np.array(list(images.keys())) + Rs = [] + ts = [] + for id_ in ids: + image = images[id_] + R = image.qvec2rotmat() + t = image.tvec + Rs.append(R) + ts.append(t) + Rs = np.stack(Rs, 0) + ts = np.stack(ts, 0) + + # Invert the poses from world-to-camera to camera-to-world. + Rs = Rs.transpose(0, 2, 1) + ts = -(Rs @ ts[:, :, None])[:, :, 0] + + dist = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(ts)) + + # Instead of computing the angle between two camera orientations, + # we compute the angle between the principal axes, as two images rotated + # around their principal axis still observe the same scene. 
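+ # (The principal axis is the last column of the camera-to-world rotation, i.e. + # the viewing direction; dR below is then the angle in degrees between these axes, + # the arccos of their pairwise dot products.)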
+ axes = Rs[:, :, -1] + dots = np.einsum("mi,ni->mn", axes, axes, optimize=True) + dR = np.rad2deg(np.arccos(np.clip(dots, -1.0, 1.0))) + + return ids, dist, dR + + +def main(model, output, num_matched, rotation_threshold=DEFAULT_ROT_THRESH): + logger.info("Reading the COLMAP model...") + images = read_images_binary(model / "images.bin") + + logger.info(f"Obtaining pairwise distances between {len(images)} images...") + ids, dist, dR = get_pairwise_distances(images) + scores = -dist + + invalid = dR >= rotation_threshold + np.fill_diagonal(invalid, True) + pairs = pairs_from_score_matrix(scores, invalid, num_matched) + pairs = [(images[ids[i]].name, images[ids[j]].name) for i, j in pairs] + + logger.info(f"Found {len(pairs)} pairs.") + with open(output, "w") as f: + f.write("\n".join(" ".join(p) for p in pairs)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, type=Path) + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--num_matched", required=True, type=int) + parser.add_argument("--rotation_threshold", default=DEFAULT_ROT_THRESH, type=float) + args = parser.parse_args() + main(**args.__dict__) diff --git a/imcui/hloc/pairs_from_retrieval.py b/imcui/hloc/pairs_from_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..323368011086b10065aba177a360284f558904e8 --- /dev/null +++ b/imcui/hloc/pairs_from_retrieval.py @@ -0,0 +1,133 @@ +import argparse +import collections.abc as collections +from pathlib import Path +from typing import Optional + +import h5py +import numpy as np +import torch + +from . import logger +from .utils.io import list_h5_names +from .utils.parsers import parse_image_lists +from .utils.read_write_model import read_images_binary + + +def parse_names(prefix, names, names_all): + if prefix is not None: + if not isinstance(prefix, str): + prefix = tuple(prefix) + names = [n for n in names_all if n.startswith(prefix)] + if len(names) == 0: + raise ValueError(f"Could not find any image with the prefix `{prefix}`.") + elif names is not None: + if isinstance(names, (str, Path)): + names = parse_image_lists(names) + elif isinstance(names, collections.Iterable): + names = list(names) + else: + raise ValueError( + f"Unknown type of image list: {names}." + "Provide either a list or a path to a list file." 
+ ) + else: + names = names_all + return names + + +def get_descriptors(names, path, name2idx=None, key="global_descriptor"): + if name2idx is None: + with h5py.File(str(path), "r", libver="latest") as fd: + desc = [fd[n][key].__array__() for n in names] + else: + desc = [] + for n in names: + with h5py.File(str(path[name2idx[n]]), "r", libver="latest") as fd: + desc.append(fd[n][key].__array__()) + return torch.from_numpy(np.stack(desc, 0)).float() + + +def pairs_from_score_matrix( + scores: torch.Tensor, + invalid: np.array, + num_select: int, + min_score: Optional[float] = None, +): + assert scores.shape == invalid.shape + if isinstance(scores, np.ndarray): + scores = torch.from_numpy(scores) + invalid = torch.from_numpy(invalid).to(scores.device) + if min_score is not None: + invalid |= scores < min_score + scores.masked_fill_(invalid, float("-inf")) + + topk = torch.topk(scores, num_select, dim=1) + indices = topk.indices.cpu().numpy() + valid = topk.values.isfinite().cpu().numpy() + + pairs = [] + for i, j in zip(*np.where(valid)): + pairs.append((i, indices[i, j])) + return pairs + + +def main( + descriptors, + output, + num_matched, + query_prefix=None, + query_list=None, + db_prefix=None, + db_list=None, + db_model=None, + db_descriptors=None, +): + logger.info("Extracting image pairs from a retrieval database.") + + # We handle multiple reference feature files. + # We only assume that names are unique among them and map names to files. + if db_descriptors is None: + db_descriptors = descriptors + if isinstance(db_descriptors, (Path, str)): + db_descriptors = [db_descriptors] + name2db = {n: i for i, p in enumerate(db_descriptors) for n in list_h5_names(p)} + db_names_h5 = list(name2db.keys()) + query_names_h5 = list_h5_names(descriptors) + + if db_model: + images = read_images_binary(db_model / "images.bin") + db_names = [i.name for i in images.values()] + else: + db_names = parse_names(db_prefix, db_list, db_names_h5) + if len(db_names) == 0: + raise ValueError("Could not find any database image.") + query_names = parse_names(query_prefix, query_list, query_names_h5) + + device = "cuda" if torch.cuda.is_available() else "cpu" + db_desc = get_descriptors(db_names, db_descriptors, name2db) + query_desc = get_descriptors(query_names, descriptors) + sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device)) + + # Avoid self-matching + self = np.array(query_names)[:, None] == np.array(db_names)[None] + pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0) + pairs = [(query_names[i], db_names[j]) for i, j in pairs] + + logger.info(f"Found {len(pairs)} pairs.") + with open(output, "w") as f: + f.write("\n".join(" ".join([i, j]) for i, j in pairs)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--descriptors", type=Path, required=True) + parser.add_argument("--output", type=Path, required=True) + parser.add_argument("--num_matched", type=int, required=True) + parser.add_argument("--query_prefix", type=str, nargs="+") + parser.add_argument("--query_list", type=Path) + parser.add_argument("--db_prefix", type=str, nargs="+") + parser.add_argument("--db_list", type=Path) + parser.add_argument("--db_model", type=Path) + parser.add_argument("--db_descriptors", type=Path) + args = parser.parse_args() + main(**args.__dict__) diff --git a/imcui/hloc/pipelines/4Seasons/README.md b/imcui/hloc/pipelines/4Seasons/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..ad23ac8348ae9f0963611bc9a342240d5ae97255 --- /dev/null +++ b/imcui/hloc/pipelines/4Seasons/README.md @@ -0,0 +1,43 @@ +# 4Seasons dataset + +This pipeline localizes sequences from the [4Seasons dataset](https://arxiv.org/abs/2009.06364) and can reproduce our winning submission to the challenge of the [ECCV 2020 Workshop on Map-based Localization for Autonomous Driving](https://sites.google.com/view/mlad-eccv2020/home). + +## Installation + +Download the sequences from the [challenge webpage](https://sites.google.com/view/mlad-eccv2020/challenge) and run: +```bash +unzip recording_2020-04-07_10-20-32.zip -d datasets/4Seasons/reference +unzip recording_2020-03-24_17-36-22.zip -d datasets/4Seasons/training +unzip recording_2020-03-03_12-03-23.zip -d datasets/4Seasons/validation +unzip recording_2020-03-24_17-45-31.zip -d datasets/4Seasons/test0 +unzip recording_2020-04-23_19-37-00.zip -d datasets/4Seasons/test1 +``` +Note that the provided scripts might modify the dataset files by deleting unused images to speed up the feature extraction + +## Pipeline + +The process is presented in our workshop talk, whose recording can be found [here](https://youtu.be/M-X6HX1JxYk?t=5245). + +We first triangulate a 3D model from the given poses of the reference sequence: +```bash +python3 -m hloc.pipelines.4Seasons.prepare_reference +``` + +We then relocalize a given sequence: +```bash +python3 -m hloc.pipelines.4Seasons.localize --sequence [training|validation|test0|test1] +``` + +The final submission files can be found in `outputs/4Seasons/submission_hloc+superglue/`. The script will also evaluate these results if the training or validation sequences are selected. + +## Results + +We evaluate the localization recall at distance thresholds 0.1m, 0.2m, and 0.5m. + +| Methods | test0 | test1 | +| -------------------- | ---------------------- | ---------------------- | +| **hloc + SuperGlue** | **91.8 / 97.7 / 99.2** | **67.3 / 93.5 / 98.7** | +| Baseline SuperGlue | 21.2 / 33.9 / 60.0 | 12.4 / 26.5 / 54.4 | +| Baseline R2D2 | 21.5 / 33.1 / 53.0 | 12.3 / 23.7 / 42.0 | +| Baseline D2Net | 12.5 / 29.3 / 56.7 | 7.5 / 21.4 / 47.7 | +| Baseline SuperPoint | 15.5 / 27.5 / 47.5 | 9.0 / 19.4 / 36.4 | diff --git a/imcui/hloc/pipelines/4Seasons/__init__.py b/imcui/hloc/pipelines/4Seasons/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/hloc/pipelines/4Seasons/localize.py b/imcui/hloc/pipelines/4Seasons/localize.py new file mode 100644 index 0000000000000000000000000000000000000000..50ed957bcf159915d0be98fa9b54c8bce0059b56 --- /dev/null +++ b/imcui/hloc/pipelines/4Seasons/localize.py @@ -0,0 +1,89 @@ +import argparse +from pathlib import Path + +from ... 
import extract_features, localize_sfm, logger, match_features +from .utils import ( + delete_unused_images, + evaluate_submission, + generate_localization_pairs, + generate_query_lists, + get_timestamps, + prepare_submission, +) + +relocalization_files = { + "training": "RelocalizationFilesTrain//relocalizationFile_recording_2020-03-24_17-36-22.txt", # noqa: E501 + "validation": "RelocalizationFilesVal/relocalizationFile_recording_2020-03-03_12-03-23.txt", # noqa: E501 + "test0": "RelocalizationFilesTest/relocalizationFile_recording_2020-03-24_17-45-31_*.txt", # noqa: E501 + "test1": "RelocalizationFilesTest/relocalizationFile_recording_2020-04-23_19-37-00_*.txt", # noqa: E501 +} + +parser = argparse.ArgumentParser() +parser.add_argument( + "--sequence", + type=str, + required=True, + choices=["training", "validation", "test0", "test1"], + help="Sequence to be relocalized.", +) +parser.add_argument( + "--dataset", + type=Path, + default="datasets/4Seasons", + help="Path to the dataset, default: %(default)s", +) +parser.add_argument( + "--outputs", + type=Path, + default="outputs/4Seasons", + help="Path to the output directory, default: %(default)s", +) +args = parser.parse_args() +sequence = args.sequence + +data_dir = args.dataset +ref_dir = data_dir / "reference" +assert ref_dir.exists(), f"{ref_dir} does not exist" +seq_dir = data_dir / sequence +assert seq_dir.exists(), f"{seq_dir} does not exist" +seq_images = seq_dir / "undistorted_images" +reloc = ref_dir / relocalization_files[sequence] + +output_dir = args.outputs +output_dir.mkdir(exist_ok=True, parents=True) +query_list = output_dir / f"{sequence}_queries_with_intrinsics.txt" +ref_pairs = output_dir / "pairs-db-dist20.txt" +ref_sfm = output_dir / "sfm_superpoint+superglue" +results_path = output_dir / f"localization_{sequence}_hloc+superglue.txt" +submission_dir = output_dir / "submission_hloc+superglue" + +num_loc_pairs = 10 +loc_pairs = output_dir / f"pairs-query-{sequence}-dist{num_loc_pairs}.txt" + +fconf = extract_features.confs["superpoint_max"] +mconf = match_features.confs["superglue"] + +# Not all query images are used for the evaluation. +# To save time in feature extraction, we delete unused images. +timestamps = get_timestamps(reloc, 1) +delete_unused_images(seq_images, timestamps) + +# Generate a list of query images with their intrinsics. +generate_query_lists(timestamps, seq_dir, query_list) + +# Generate the localization pairs from the given reference frames. +generate_localization_pairs(sequence, reloc, num_loc_pairs, ref_pairs, loc_pairs) + +# Extract, match, and localize. +ffile = extract_features.main(fconf, seq_images, output_dir) +mfile = match_features.main(mconf, loc_pairs, fconf["output"], output_dir) +localize_sfm.main(ref_sfm, query_list, loc_pairs, ffile, mfile, results_path) + +# Convert the absolute poses to relative poses with the reference frames. 
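+# (For each query frame, the submission stores its pose relative to the matched +# reference frame, as computed by prepare_submission.)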
+submission_dir.mkdir(exist_ok=True) +prepare_submission(results_path, reloc, ref_dir / "poses.txt", submission_dir) + +# If this is not a test sequence, evaluate the localization accuracy. +if "test" not in sequence: + logger.info("Evaluating the relocalization submission...") + evaluate_submission(submission_dir, reloc) diff --git a/imcui/hloc/pipelines/4Seasons/prepare_reference.py b/imcui/hloc/pipelines/4Seasons/prepare_reference.py new file mode 100644 index 0000000000000000000000000000000000000000..f47aee778ba24ef89a1cc4418f5db9cfab209b9d --- /dev/null +++ b/imcui/hloc/pipelines/4Seasons/prepare_reference.py @@ -0,0 +1,51 @@ +import argparse +from pathlib import Path + +from ... import extract_features, match_features, pairs_from_poses, triangulation +from .utils import build_empty_colmap_model, delete_unused_images, get_timestamps + +parser = argparse.ArgumentParser() +parser.add_argument( + "--dataset", + type=Path, + default="datasets/4Seasons", + help="Path to the dataset, default: %(default)s", +) +parser.add_argument( + "--outputs", + type=Path, + default="outputs/4Seasons", + help="Path to the output directory, default: %(default)s", +) +args = parser.parse_args() + +ref_dir = args.dataset / "reference" +assert ref_dir.exists(), f"{ref_dir} does not exist" +ref_images = ref_dir / "undistorted_images" + +output_dir = args.outputs +output_dir.mkdir(exist_ok=True, parents=True) +ref_sfm_empty = output_dir / "sfm_reference_empty" +ref_sfm = output_dir / "sfm_superpoint+superglue" + +num_ref_pairs = 20 +ref_pairs = output_dir / f"pairs-db-dist{num_ref_pairs}.txt" + +fconf = extract_features.confs["superpoint_max"] +mconf = match_features.confs["superglue"] + +# Only reference images that have a pose are used in the pipeline. +# To save time in feature extraction, we delete unused images. +delete_unused_images(ref_images, get_timestamps(ref_dir / "poses.txt", 0)) + +# Build an empty COLMAP model containing only cameras and images +# from the provided poses and intrinsics. +build_empty_colmap_model(ref_dir, ref_sfm_empty) + +# Match reference images that are spatially close. +pairs_from_poses.main(ref_sfm_empty, ref_pairs, num_ref_pairs) + +# Extract, match, and triangulate the reference SfM model. 
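+# (The triangulation keeps the camera poses of the empty reference model fixed +# and only estimates the 3D points from the matched features.)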
+ffile = extract_features.main(fconf, ref_images, output_dir) +mfile = match_features.main(mconf, ref_pairs, fconf["output"], output_dir) +triangulation.main(ref_sfm, ref_sfm_empty, ref_images, ref_pairs, ffile, mfile) diff --git a/imcui/hloc/pipelines/4Seasons/utils.py b/imcui/hloc/pipelines/4Seasons/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e5aace9dd9a31b9c39a58691c0c23795313bc462 --- /dev/null +++ b/imcui/hloc/pipelines/4Seasons/utils.py @@ -0,0 +1,231 @@ +import glob +import logging +import os +from pathlib import Path + +import numpy as np + +from ...utils.parsers import parse_retrieval +from ...utils.read_write_model import ( + Camera, + Image, + qvec2rotmat, + rotmat2qvec, + write_model, +) + +logger = logging.getLogger(__name__) + + +def get_timestamps(files, idx): + """Extract timestamps from a pose or relocalization file.""" + lines = [] + for p in files.parent.glob(files.name): + with open(p) as f: + lines += f.readlines() + timestamps = set() + for line in lines: + line = line.rstrip("\n") + if line[0] == "#" or line == "": + continue + ts = line.replace(",", " ").split()[idx] + timestamps.add(ts) + return timestamps + + +def delete_unused_images(root, timestamps): + """Delete all images in root if they are not contained in timestamps.""" + images = glob.glob((root / "**/*.png").as_posix(), recursive=True) + deleted = 0 + for image in images: + ts = Path(image).stem + if ts not in timestamps: + os.remove(image) + deleted += 1 + logger.info(f"Deleted {deleted} images in {root}.") + + +def camera_from_calibration_file(id_, path): + """Create a COLMAP camera from an MLAD calibration file.""" + with open(path, "r") as f: + data = f.readlines() + model, fx, fy, cx, cy = data[0].split()[:5] + width, height = data[1].split() + assert model == "Pinhole" + model_name = "PINHOLE" + params = [float(i) for i in [fx, fy, cx, cy]] + camera = Camera( + id=id_, model=model_name, width=int(width), height=int(height), params=params + ) + return camera + + +def parse_poses(path, colmap=False): + """Parse a list of poses in COLMAP or MLAD quaternion convention.""" + poses = [] + with open(path) as f: + for line in f.readlines(): + line = line.rstrip("\n") + if line[0] == "#" or line == "": + continue + data = line.replace(",", " ").split() + ts, p = data[0], np.array(data[1:], float) + if colmap: + q, t = np.split(p, [4]) + else: + t, q = np.split(p, [3]) + q = q[[3, 0, 1, 2]] # xyzw to wxyz + R = qvec2rotmat(q) + poses.append((ts, R, t)) + return poses + + +def parse_relocalization(path, has_poses=False): + """Parse a relocalization file, possibly with poses.""" + reloc = [] + with open(path) as f: + for line in f.readlines(): + line = line.rstrip("\n") + if line[0] == "#" or line == "": + continue + data = line.replace(",", " ").split() + out = data[:2] # ref_ts, q_ts + if has_poses: + assert len(data) == 9 + t, q = np.split(np.array(data[2:], float), [3]) + q = q[[3, 0, 1, 2]] # xyzw to wxyz + R = qvec2rotmat(q) + out += [R, t] + reloc.append(out) + return reloc + + +def build_empty_colmap_model(root, sfm_dir): + """Build a COLMAP model with images and cameras only.""" + calibration = "Calibration/undistorted_calib_{}.txt" + cam0 = camera_from_calibration_file(0, root / calibration.format(0)) + cam1 = camera_from_calibration_file(1, root / calibration.format(1)) + cameras = {0: cam0, 1: cam1} + + T_0to1 = np.loadtxt(root / "Calibration/undistorted_calib_stereo.txt") + poses = parse_poses(root / "poses.txt") + images = {} + id_ = 0 + for ts, R_cam0_to_w, 
t_cam0_to_w in poses: + R_w_to_cam0 = R_cam0_to_w.T + t_w_to_cam0 = -(R_w_to_cam0 @ t_cam0_to_w) + + R_w_to_cam1 = T_0to1[:3, :3] @ R_w_to_cam0 + t_w_to_cam1 = T_0to1[:3, :3] @ t_w_to_cam0 + T_0to1[:3, 3] + + for idx, (R_w_to_cam, t_w_to_cam) in enumerate( + zip([R_w_to_cam0, R_w_to_cam1], [t_w_to_cam0, t_w_to_cam1]) + ): + image = Image( + id=id_, + qvec=rotmat2qvec(R_w_to_cam), + tvec=t_w_to_cam, + camera_id=idx, + name=f"cam{idx}/{ts}.png", + xys=np.zeros((0, 2), float), + point3D_ids=np.full(0, -1, int), + ) + images[id_] = image + id_ += 1 + + sfm_dir.mkdir(exist_ok=True, parents=True) + write_model(cameras, images, {}, path=str(sfm_dir), ext=".bin") + + +def generate_query_lists(timestamps, seq_dir, out_path): + """Create a list of query images with intrinsics from timestamps.""" + cam0 = camera_from_calibration_file( + 0, seq_dir / "Calibration/undistorted_calib_0.txt" + ) + intrinsics = [cam0.model, cam0.width, cam0.height] + cam0.params + intrinsics = [str(p) for p in intrinsics] + data = map(lambda ts: " ".join([f"cam0/{ts}.png"] + intrinsics), timestamps) + with open(out_path, "w") as f: + f.write("\n".join(data)) + + +def generate_localization_pairs(sequence, reloc, num, ref_pairs, out_path): + """Create the matching pairs for the localization. + We simply lookup the corresponding reference frame + and extract its `num` closest frames from the existing pair list. + """ + if "test" in sequence: + # hard pairs will be overwritten by easy ones if available + relocs = [str(reloc).replace("*", d) for d in ["hard", "moderate", "easy"]] + else: + relocs = [reloc] + query_to_ref_ts = {} + for reloc in relocs: + with open(reloc, "r") as f: + for line in f.readlines(): + line = line.rstrip("\n") + if line[0] == "#" or line == "": + continue + ref_ts, q_ts = line.split()[:2] + query_to_ref_ts[q_ts] = ref_ts + + ts_to_name = "cam0/{}.png".format + ref_pairs = parse_retrieval(ref_pairs) + loc_pairs = [] + for q_ts, ref_ts in query_to_ref_ts.items(): + ref_name = ts_to_name(ref_ts) + selected = [ref_name] + ref_pairs[ref_name][: num - 1] + loc_pairs.extend([" ".join((ts_to_name(q_ts), s)) for s in selected]) + with open(out_path, "w") as f: + f.write("\n".join(loc_pairs)) + + +def prepare_submission(results, relocs, poses_path, out_dir): + """Obtain relative poses from estimated absolute and reference poses.""" + gt_poses = parse_poses(poses_path) + all_T_ref0_to_w = {ts: (R, t) for ts, R, t in gt_poses} + + pred_poses = parse_poses(results, colmap=True) + all_T_w_to_q0 = {Path(name).stem: (R, t) for name, R, t in pred_poses} + + for reloc in relocs.parent.glob(relocs.name): + relative_poses = [] + reloc_ts = parse_relocalization(reloc) + for ref_ts, q_ts in reloc_ts: + R_w_to_q0, t_w_to_q0 = all_T_w_to_q0[q_ts] + R_ref0_to_w, t_ref0_to_w = all_T_ref0_to_w[ref_ts] + + R_ref0_to_q0 = R_w_to_q0 @ R_ref0_to_w + t_ref0_to_q0 = R_w_to_q0 @ t_ref0_to_w + t_w_to_q0 + + tvec = t_ref0_to_q0.tolist() + qvec = rotmat2qvec(R_ref0_to_q0)[[1, 2, 3, 0]] # wxyz to xyzw + + out = [ref_ts, q_ts] + list(map(str, tvec)) + list(map(str, qvec)) + relative_poses.append(" ".join(out)) + + out_path = out_dir / reloc.name + with open(out_path, "w") as f: + f.write("\n".join(relative_poses)) + logger.info(f"Submission file written to {out_path}.") + + +def evaluate_submission(submission_dir, relocs, ths=[0.1, 0.2, 0.5]): + """Compute the relocalization recall from predicted and ground truth poses.""" + for reloc in relocs.parent.glob(relocs.name): + poses_gt = parse_relocalization(reloc, has_poses=True) + poses_pred = 
parse_relocalization(submission_dir / reloc.name, has_poses=True) + poses_pred = {(ref_ts, q_ts): (R, t) for ref_ts, q_ts, R, t in poses_pred} + + error = [] + for ref_ts, q_ts, R_gt, t_gt in poses_gt: + R, t = poses_pred[(ref_ts, q_ts)] + e = np.linalg.norm(t - t_gt) + error.append(e) + + error = np.array(error) + recall = [np.mean(error <= th) for th in ths] + s = f"Relocalization evaluation {submission_dir.name}/{reloc.name}\n" + s += " / ".join([f"{th:>7}m" for th in ths]) + "\n" + s += " / ".join([f"{100*r:>7.3f}%" for r in recall]) + logger.info(s) diff --git a/imcui/hloc/pipelines/7Scenes/README.md b/imcui/hloc/pipelines/7Scenes/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2124779c43ec8d1ffc552e07790d39c3578526a9 --- /dev/null +++ b/imcui/hloc/pipelines/7Scenes/README.md @@ -0,0 +1,65 @@ +# 7Scenes dataset + +## Installation + +Download the images from the [7Scenes project page](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/): +```bash +export dataset=datasets/7scenes +for scene in chess fire heads office pumpkin redkitchen stairs; \ +do wget http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/$scene.zip -P $dataset \ +&& unzip $dataset/$scene.zip -d $dataset && unzip $dataset/$scene/'*.zip' -d $dataset/$scene; done +``` + +Download the SIFT SfM models and DenseVLAD image pairs, courtesy of Torsten Sattler: +```bash +function download { +wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$1" -O $2 && rm -rf /tmp/cookies.txt +unzip $2 -d $dataset && rm $2; +} +download 1cu6KUR7WHO7G4EO49Qi3HEKU6n_yYDjb $dataset/7scenes_sfm_triangulated.zip +download 1IbS2vLmxr1N0f3CEnd_wsYlgclwTyvB1 $dataset/7scenes_densevlad_retrieval_top_10.zip +``` + +Download the rendered depth maps, courtesy of Eric Brachmann for [DSAC\*](https://github.com/vislearn/dsacstar): +```bash +wget https://heidata.uni-heidelberg.de/api/access/datafile/4037 -O $dataset/7scenes_rendered_depth.tar.gz +mkdir $dataset/depth/ +tar xzf $dataset/7scenes_rendered_depth.tar.gz -C $dataset/depth/ && rm $dataset/7scenes_rendered_depth.tar.gz +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.7Scenes.pipeline [--use_dense_depth] +``` +By default, hloc triangulates a sparse point cloud that can be noisy in indoor environements due to image noise and lack of texture. With the flag `--use_dense_depth`, the pipeline improves the accuracy of the sparse point cloud using dense depth maps provided by the dataset. The original depth maps captured by the RGBD sensor are miscalibrated, so we use depth maps rendered from the mesh obtained by fusing the RGBD data. 
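+ +For example, to run the pipeline on a subset of scenes with the dense-depth refinement (a usage sketch based on the `--scenes` and `--use_dense_depth` flags defined in `pipeline.py` below): +```bash +python3 -m hloc.pipelines.7Scenes.pipeline --scenes chess heads --use_dense_depth +```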
+ +## Results +We report the median error in translation/rotation in cm/deg over all scenes: +| Method \ Scene | Chess | Fire | Heads | Office | Pumpkin | Kitchen | Stairs | +| ------------------------------- | -------------- | -------------- | -------------- | -------------- | -------------- | -------------- | ---------- | +| Active Search | 3/0.87 | **2**/1.01 | **1**/0.82 | 4/1.15 | 7/1.69 | 5/1.72 | 4/**1.01** | +| DSAC* | **2**/1.10 | **2**/1.24 | **1**/1.82 | **3**/1.15 | **4**/1.34 | 4/1.68 | **3**/1.16 | +| **SuperPoint+SuperGlue** (sfm) | **2**/0.84 | **2**/0.93 | **1**/**0.74** | **3**/0.92 | 5/1.27 | 4/1.40 | 5/1.47 | +| **SuperPoint+SuperGlue** (RGBD) | **2**/**0.80** | **2**/**0.77** | **1**/0.79 | **3**/**0.80** | **4**/**1.07** | **3**/**1.13** | 4/1.15 | + +## Citation +Please cite the following paper if you use the 7Scenes dataset: +``` +@inproceedings{shotton2013scene, + title={Scene coordinate regression forests for camera relocalization in {RGB-D} images}, + author={Shotton, Jamie and Glocker, Ben and Zach, Christopher and Izadi, Shahram and Criminisi, Antonio and Fitzgibbon, Andrew}, + booktitle={CVPR}, + year={2013} +} +``` + +Also cite DSAC* if you use dense depth maps with the flag `--use_dense_depth`: +``` +@article{brachmann2020dsacstar, + title={Visual Camera Re-Localization from {RGB} and {RGB-D} Images Using {DSAC}}, + author={Brachmann, Eric and Rother, Carsten}, + journal={TPAMI}, + year={2021} +} +``` diff --git a/imcui/hloc/pipelines/7Scenes/__init__.py b/imcui/hloc/pipelines/7Scenes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/hloc/pipelines/7Scenes/create_gt_sfm.py b/imcui/hloc/pipelines/7Scenes/create_gt_sfm.py new file mode 100644 index 0000000000000000000000000000000000000000..95dfa461e17de99e0bdde0c52c5f02568c4fbab3 --- /dev/null +++ b/imcui/hloc/pipelines/7Scenes/create_gt_sfm.py @@ -0,0 +1,134 @@ +from pathlib import Path + +import numpy as np +import PIL.Image +import pycolmap +import torch +from tqdm import tqdm + +from ...utils.read_write_model import read_model, write_model + + +def scene_coordinates(p2D, R_w2c, t_w2c, depth, camera): + assert len(depth) == len(p2D) + p2D_norm = np.stack(pycolmap.Camera(camera._asdict()).image_to_world(p2D)) + p2D_h = np.concatenate([p2D_norm, np.ones_like(p2D_norm[:, :1])], 1) + p3D_c = p2D_h * depth[:, None] + p3D_w = (p3D_c - t_w2c) @ R_w2c + return p3D_w + + +def interpolate_depth(depth, kp): + h, w = depth.shape + kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1 + assert np.all(kp > -1) and np.all(kp < 1) + depth = torch.from_numpy(depth)[None, None] + kp = torch.from_numpy(kp)[None, None] + grid_sample = torch.nn.functional.grid_sample + + # To maximize the number of points that have depth: + # do bilinear interpolation first and then nearest for the remaining points + interp_lin = grid_sample(depth, kp, align_corners=True, mode="bilinear")[0, :, 0] + interp_nn = torch.nn.functional.grid_sample( + depth, kp, align_corners=True, mode="nearest" + )[0, :, 0] + interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin) + valid = ~torch.any(torch.isnan(interp), 0) + + interp_depth = interp.T.numpy().flatten() + valid = valid.numpy() + return interp_depth, valid + + +def image_path_to_rendered_depth_path(image_name): + parts = image_name.split("/") + name = "_".join(["".join(parts[0].split("-")), parts[1]]) + name = name.replace("color", "pose") + name = name.replace("png", "depth.tiff") + return name + + +def 
project_to_image(p3D, R, t, camera, eps: float = 1e-4, pad: int = 1): + p3D = (p3D @ R.T) + t + visible = p3D[:, -1] >= eps # keep points in front of the camera + p2D_norm = p3D[:, :-1] / p3D[:, -1:].clip(min=eps) + p2D = np.stack(pycolmap.Camera(camera._asdict()).world_to_image(p2D_norm)) + size = np.array([camera.width - pad - 1, camera.height - pad - 1]) + valid = np.all((p2D >= pad) & (p2D <= size), -1) + valid &= visible + return p2D[valid], valid + + +def correct_sfm_with_gt_depth(sfm_path, depth_folder_path, output_path): + cameras, images, points3D = read_model(sfm_path) + for imgid, img in tqdm(images.items()): + image_name = img.name + depth_name = image_path_to_rendered_depth_path(image_name) + + depth = PIL.Image.open(Path(depth_folder_path) / depth_name) + depth = np.array(depth).astype("float64") + depth = depth / 1000.0 # mm to meter + depth[(depth == 0.0) | (depth > 1000.0)] = np.nan + + R_w2c, t_w2c = img.qvec2rotmat(), img.tvec + camera = cameras[img.camera_id] + p3D_ids = img.point3D_ids + p3Ds = np.stack([points3D[i].xyz for i in p3D_ids[p3D_ids != -1]], 0) + + p2Ds, valids_projected = project_to_image(p3Ds, R_w2c, t_w2c, camera) + invalid_p3D_ids = p3D_ids[p3D_ids != -1][~valids_projected] + interp_depth, valids_backprojected = interpolate_depth(depth, p2Ds) + scs = scene_coordinates( + p2Ds[valids_backprojected], + R_w2c, + t_w2c, + interp_depth[valids_backprojected], + camera, + ) + invalid_p3D_ids = np.append( + invalid_p3D_ids, + p3D_ids[p3D_ids != -1][valids_projected][~valids_backprojected], + ) + for p3did in invalid_p3D_ids: + if p3did == -1: + continue + else: + obs_imgids = points3D[p3did].image_ids + invalid_imgids = list(np.where(obs_imgids == img.id)[0]) + points3D[p3did] = points3D[p3did]._replace( + image_ids=np.delete(obs_imgids, invalid_imgids), + point2D_idxs=np.delete( + points3D[p3did].point2D_idxs, invalid_imgids + ), + ) + + new_p3D_ids = p3D_ids.copy() + sub_p3D_ids = new_p3D_ids[new_p3D_ids != -1] + valids = np.ones(np.count_nonzero(new_p3D_ids != -1), dtype=bool) + valids[~valids_projected] = False + valids[valids_projected] = valids_backprojected + sub_p3D_ids[~valids] = -1 + new_p3D_ids[new_p3D_ids != -1] = sub_p3D_ids + img = img._replace(point3D_ids=new_p3D_ids) + + assert len(img.point3D_ids[img.point3D_ids != -1]) == len( + scs + ), f"{len(scs)}, {len(img.point3D_ids[img.point3D_ids != -1])}" + for i, p3did in enumerate(img.point3D_ids[img.point3D_ids != -1]): + points3D[p3did] = points3D[p3did]._replace(xyz=scs[i]) + images[imgid] = img + + output_path.mkdir(parents=True, exist_ok=True) + write_model(cameras, images, points3D, output_path) + + +if __name__ == "__main__": + dataset = Path("datasets/7scenes") + outputs = Path("outputs/7Scenes") + + SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"] + for scene in SCENES: + sfm_path = outputs / scene / "sfm_superpoint+superglue" + depth_path = dataset / f"depth/7scenes_{scene}/train/depth" + output_path = outputs / scene / "sfm_superpoint+superglue+depth" + correct_sfm_with_gt_depth(sfm_path, depth_path, output_path) diff --git a/imcui/hloc/pipelines/7Scenes/pipeline.py b/imcui/hloc/pipelines/7Scenes/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc28c6d29828ddbe4b9efaf1678be624a998819 --- /dev/null +++ b/imcui/hloc/pipelines/7Scenes/pipeline.py @@ -0,0 +1,139 @@ +import argparse +from pathlib import Path + +from ... 
import ( + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + triangulation, +) +from ..Cambridge.utils import create_query_list_with_intrinsics, evaluate +from .create_gt_sfm import correct_sfm_with_gt_depth +from .utils import create_reference_sfm + +SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"] + + +def run_scene( + images, + gt_dir, + retrieval, + outputs, + results, + num_covis, + use_dense_depth, + depth_dir=None, +): + outputs.mkdir(exist_ok=True, parents=True) + ref_sfm_sift = outputs / "sfm_sift" + ref_sfm = outputs / "sfm_superpoint+superglue" + query_list = outputs / "query_list_with_intrinsics.txt" + + feature_conf = { + "output": "feats-superpoint-n4096-r1024", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + }, + "preprocessing": { + "globs": ["*.color.png"], + "grayscale": True, + "resize_max": 1024, + }, + } + matcher_conf = match_features.confs["superglue"] + matcher_conf["model"]["sinkhorn_iterations"] = 5 + + test_list = gt_dir / "list_test.txt" + create_reference_sfm(gt_dir, ref_sfm_sift, test_list) + create_query_list_with_intrinsics(gt_dir, query_list, test_list) + + features = extract_features.main(feature_conf, images, outputs, as_half=True) + + sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt" + pairs_from_covisibility.main(ref_sfm_sift, sfm_pairs, num_matched=num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + if not (use_dense_depth and ref_sfm.exists()): + triangulation.main( + ref_sfm, ref_sfm_sift, images, sfm_pairs, features, sfm_matches + ) + if use_dense_depth: + assert depth_dir is not None + ref_sfm_fix = outputs / "sfm_superpoint+superglue+depth" + correct_sfm_with_gt_depth(ref_sfm, depth_dir, ref_sfm_fix) + ref_sfm = ref_sfm_fix + + loc_matches = match_features.main( + matcher_conf, retrieval, feature_conf["output"], outputs + ) + + localize_sfm.main( + ref_sfm, + query_list, + retrieval, + features, + loc_matches, + results, + covisibility_clustering=False, + prepend_camera_name=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--scenes", default=SCENES, choices=SCENES, nargs="+") + parser.add_argument("--overwrite", action="store_true") + parser.add_argument( + "--dataset", + type=Path, + default="datasets/7scenes", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/7scenes", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument("--use_dense_depth", action="store_true") + parser.add_argument( + "--num_covis", + type=int, + default=30, + help="Number of image pairs for SfM, default: %(default)s", + ) + args = parser.parse_args() + + gt_dirs = args.dataset / "7scenes_sfm_triangulated/{scene}/triangulated" + retrieval_dirs = args.dataset / "7scenes_densevlad_retrieval_top_10" + + all_results = {} + for scene in args.scenes: + logger.info(f'Working on scene "{scene}".') + results = ( + args.outputs + / scene + / "results_{}.txt".format("dense" if args.use_dense_depth else "sparse") + ) + if args.overwrite or not results.exists(): + run_scene( + args.dataset / scene, + Path(str(gt_dirs).format(scene=scene)), + retrieval_dirs / f"{scene}_top10.txt", + args.outputs / scene, + results, + args.num_covis, + args.use_dense_depth, + depth_dir=args.dataset / f"depth/7scenes_{scene}/train/depth", + ) + all_results[scene] = results + + 
for scene in args.scenes: + logger.info(f'Evaluate scene "{scene}".') + gt_dir = Path(str(gt_dirs).format(scene=scene)) + evaluate(gt_dir, all_results[scene], gt_dir / "list_test.txt") diff --git a/imcui/hloc/pipelines/7Scenes/utils.py b/imcui/hloc/pipelines/7Scenes/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb021286e550de6fe89e03370c11e3f7d567c5f --- /dev/null +++ b/imcui/hloc/pipelines/7Scenes/utils.py @@ -0,0 +1,34 @@ +import logging + +import numpy as np + +from hloc.utils.read_write_model import read_model, write_model + +logger = logging.getLogger(__name__) + + +def create_reference_sfm(full_model, ref_model, blacklist=None, ext=".bin"): + """Create a new COLMAP model with only training images.""" + logger.info("Creating the reference model.") + ref_model.mkdir(exist_ok=True) + cameras, images, points3D = read_model(full_model, ext) + + if blacklist is not None: + with open(blacklist, "r") as f: + blacklist = f.read().rstrip().split("\n") + + images_ref = dict() + for id_, image in images.items(): + if blacklist and image.name in blacklist: + continue + images_ref[id_] = image + + points3D_ref = dict() + for id_, point3D in points3D.items(): + ref_ids = [i for i in point3D.image_ids if i in images_ref] + if len(ref_ids) == 0: + continue + points3D_ref[id_] = point3D._replace(image_ids=np.array(ref_ids)) + + write_model(cameras, images_ref, points3D_ref, ref_model, ".bin") + logger.info(f"Kept {len(images_ref)} images out of {len(images)}.") diff --git a/imcui/hloc/pipelines/Aachen/README.md b/imcui/hloc/pipelines/Aachen/README.md new file mode 100644 index 0000000000000000000000000000000000000000..57b66d6ad1e5cdb3e74c6c1866d394d487c608d1 --- /dev/null +++ b/imcui/hloc/pipelines/Aachen/README.md @@ -0,0 +1,16 @@ +# Aachen-Day-Night dataset + +## Installation + +Download the dataset from [visuallocalization.net](https://www.visuallocalization.net): +```bash +export dataset=datasets/aachen +wget -r -np -nH -R "index.html*,aachen_v1_1.zip" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/ -P $dataset +unzip $dataset/images/database_and_query_images.zip -d $dataset +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.Aachen.pipeline +``` diff --git a/imcui/hloc/pipelines/Aachen/__init__.py b/imcui/hloc/pipelines/Aachen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/hloc/pipelines/Aachen/pipeline.py b/imcui/hloc/pipelines/Aachen/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e31ce7255ce5e178f505a6fb415b5bbf13b76879 --- /dev/null +++ b/imcui/hloc/pipelines/Aachen/pipeline.py @@ -0,0 +1,109 @@ +import argparse +from pathlib import Path +from pprint import pformat + +from ... 
import ( + colmap_from_nvm, + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) + + +def run(args): + # Setup the paths + dataset = args.dataset + images = dataset / "images_upright/" + + outputs = args.outputs # where everything will be saved + sift_sfm = outputs / "sfm_sift" # from which we extract the reference poses + reference_sfm = outputs / "sfm_superpoint+superglue" # the SfM model we will build + sfm_pairs = ( + outputs / f"pairs-db-covis{args.num_covis}.txt" + ) # top-k most covisible in SIFT model + loc_pairs = ( + outputs / f"pairs-query-netvlad{args.num_loc}.txt" + ) # top-k retrieved by NetVLAD + results = outputs / f"Aachen_hloc_superpoint+superglue_netvlad{args.num_loc}.txt" + + # list the standard configurations available + logger.info("Configs for feature extractors:\n%s", pformat(extract_features.confs)) + logger.info("Configs for feature matchers:\n%s", pformat(match_features.confs)) + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + feature_conf = extract_features.confs["superpoint_aachen"] + matcher_conf = match_features.confs["superglue"] + + features = extract_features.main(feature_conf, images, outputs) + + colmap_from_nvm.main( + dataset / "3D-models/aachen_cvpr2018_db.nvm", + dataset / "3D-models/database_intrinsics.txt", + dataset / "aachen.db", + sift_sfm, + ) + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + + triangulation.main( + reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches + ) + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + args.num_loc, + query_prefix="query", + db_model=reference_sfm, + ) + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + reference_sfm, + dataset / "queries/*_time_queries_with_intrinsics.txt", + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + ) # not required with SuperPoint+SuperGlue + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=Path, + default="datasets/aachen", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/aachen", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=50, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() + run(args) diff --git a/imcui/hloc/pipelines/Aachen_v1_1/README.md b/imcui/hloc/pipelines/Aachen_v1_1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c17e751777b56e36b8633c1eec37ff656f2d3979 --- /dev/null +++ b/imcui/hloc/pipelines/Aachen_v1_1/README.md @@ -0,0 +1,17 @@ +# Aachen-Day-Night dataset v1.1 + +## Installation + +Download the dataset from [visuallocalization.net](https://www.visuallocalization.net): +```bash +export dataset=datasets/aachen_v1.1 +wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/ -P $dataset 
+unzip $dataset/images/database_and_query_images.zip -d $dataset +unzip $dataset/aachen_v1_1.zip -d $dataset +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.Aachen_v1_1.pipeline +``` diff --git a/imcui/hloc/pipelines/Aachen_v1_1/__init__.py b/imcui/hloc/pipelines/Aachen_v1_1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/hloc/pipelines/Aachen_v1_1/pipeline.py b/imcui/hloc/pipelines/Aachen_v1_1/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0753604624d31984952942ab5b297a247e4d5123 --- /dev/null +++ b/imcui/hloc/pipelines/Aachen_v1_1/pipeline.py @@ -0,0 +1,104 @@ +import argparse +from pathlib import Path +from pprint import pformat + +from ... import ( + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) + + +def run(args): + # Setup the paths + dataset = args.dataset + images = dataset / "images_upright/" + sift_sfm = dataset / "3D-models/aachen_v_1_1" + + outputs = args.outputs # where everything will be saved + reference_sfm = outputs / "sfm_superpoint+superglue" # the SfM model we will build + sfm_pairs = ( + outputs / f"pairs-db-covis{args.num_covis}.txt" + ) # top-k most covisible in SIFT model + loc_pairs = ( + outputs / f"pairs-query-netvlad{args.num_loc}.txt" + ) # top-k retrieved by NetVLAD + results = ( + outputs / f"Aachen-v1.1_hloc_superpoint+superglue_netvlad{args.num_loc}.txt" + ) + + # list the standard configurations available + logger.info("Configs for feature extractors:\n%s", pformat(extract_features.confs)) + logger.info("Configs for feature matchers:\n%s", pformat(match_features.confs)) + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + feature_conf = extract_features.confs["superpoint_max"] + matcher_conf = match_features.confs["superglue"] + + features = extract_features.main(feature_conf, images, outputs) + + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + + triangulation.main( + reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches + ) + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + args.num_loc, + query_prefix="query", + db_model=reference_sfm, + ) + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + reference_sfm, + dataset / "queries/*_time_queries_with_intrinsics.txt", + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + ) # not required with SuperPoint+SuperGlue + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=Path, + default="datasets/aachen_v1.1", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/aachen_v1.1", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=50, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() + run(args) diff --git 
a/imcui/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py b/imcui/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..9c0a769897604fa9838106c7a38ba585ceeefe5c --- /dev/null +++ b/imcui/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py @@ -0,0 +1,104 @@ +import argparse +from pathlib import Path +from pprint import pformat + +from ... import ( + extract_features, + localize_sfm, + logger, + match_dense, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) + + +def run(args): + # Setup the paths + dataset = args.dataset + images = dataset / "images_upright/" + sift_sfm = dataset / "3D-models/aachen_v_1_1" + + outputs = args.outputs # where everything will be saved + outputs.mkdir() + reference_sfm = outputs / "sfm_loftr" # the SfM model we will build + sfm_pairs = ( + outputs / f"pairs-db-covis{args.num_covis}.txt" + ) # top-k most covisible in SIFT model + loc_pairs = ( + outputs / f"pairs-query-netvlad{args.num_loc}.txt" + ) # top-k retrieved by NetVLAD + results = outputs / f"Aachen-v1.1_hloc_loftr_netvlad{args.num_loc}.txt" + + # list the standard configurations available + logger.info("Configs for dense feature matchers:\n%s", pformat(match_dense.confs)) + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + matcher_conf = match_dense.confs["loftr_aachen"] + + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis) + features, sfm_matches = match_dense.main( + matcher_conf, sfm_pairs, images, outputs, max_kps=8192, overwrite=False + ) + + triangulation.main( + reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches + ) + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + args.num_loc, + query_prefix="query", + db_model=reference_sfm, + ) + features, loc_matches = match_dense.main( + matcher_conf, + loc_pairs, + images, + outputs, + features=features, + max_kps=None, + matches=sfm_matches, + ) + + localize_sfm.main( + reference_sfm, + dataset / "queries/*_time_queries_with_intrinsics.txt", + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + ) # not required with loftr + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=Path, + default="datasets/aachen_v1.1", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/aachen_v1.1", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=50, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() diff --git a/imcui/hloc/pipelines/CMU/README.md b/imcui/hloc/pipelines/CMU/README.md new file mode 100644 index 0000000000000000000000000000000000000000..566ba352c53ada2a13dce21c8ec1041b56969d03 --- /dev/null +++ b/imcui/hloc/pipelines/CMU/README.md @@ -0,0 +1,16 @@ +# Extended CMU Seasons dataset + +## Installation + +Download the dataset from [visuallocalization.net](https://www.visuallocalization.net): +```bash +export dataset=datasets/cmu_extended +wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Extended-CMU-Seasons/ -P $dataset 
+for slice in $dataset/*.tar; do tar -xf $slice -C $dataset && rm $slice; done +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.CMU.pipeline +``` diff --git a/imcui/hloc/pipelines/CMU/__init__.py b/imcui/hloc/pipelines/CMU/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/hloc/pipelines/CMU/pipeline.py b/imcui/hloc/pipelines/CMU/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..4706a05c6aef134dc244419501b22a2ba95ede04 --- /dev/null +++ b/imcui/hloc/pipelines/CMU/pipeline.py @@ -0,0 +1,133 @@ +import argparse +from pathlib import Path + +from ... import ( + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) + +TEST_SLICES = [2, 3, 4, 5, 6, 13, 14, 15, 16, 17, 18, 19, 20, 21] + + +def generate_query_list(dataset, path, slice_): + cameras = {} + with open(dataset / "intrinsics.txt", "r") as f: + for line in f.readlines(): + if line[0] == "#" or line == "\n": + continue + data = line.split() + cameras[data[0]] = data[1:] + assert len(cameras) == 2 + + queries = dataset / f"{slice_}/test-images-{slice_}.txt" + with open(queries, "r") as f: + queries = [q.rstrip("\n") for q in f.readlines()] + + out = [[q] + cameras[q.split("_")[2]] for q in queries] + with open(path, "w") as f: + f.write("\n".join(map(" ".join, out))) + + +def run_slice(slice_, root, outputs, num_covis, num_loc): + dataset = root / slice_ + ref_images = dataset / "database" + query_images = dataset / "query" + sift_sfm = dataset / "sparse" + + outputs = outputs / slice_ + outputs.mkdir(exist_ok=True, parents=True) + query_list = dataset / "queries_with_intrinsics.txt" + sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt" + loc_pairs = outputs / f"pairs-query-netvlad{num_loc}.txt" + ref_sfm = outputs / "sfm_superpoint+superglue" + results = outputs / f"CMU_hloc_superpoint+superglue_netvlad{num_loc}.txt" + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + feature_conf = extract_features.confs["superpoint_aachen"] + matcher_conf = match_features.confs["superglue"] + + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=num_covis) + features = extract_features.main(feature_conf, ref_images, outputs, as_half=True) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + triangulation.main(ref_sfm, sift_sfm, ref_images, sfm_pairs, features, sfm_matches) + + generate_query_list(root, query_list, slice_) + global_descriptors = extract_features.main(retrieval_conf, ref_images, outputs) + global_descriptors = extract_features.main(retrieval_conf, query_images, outputs) + pairs_from_retrieval.main( + global_descriptors, loc_pairs, num_loc, query_list=query_list, db_model=ref_sfm + ) + + features = extract_features.main(feature_conf, query_images, outputs, as_half=True) + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + ref_sfm, + dataset / "queries/*_time_queries_with_intrinsics.txt", + loc_pairs, + features, + loc_matches, + results, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--slices", + type=str, + default="*", + help="a single number, an interval (e.g. 2-6), " + "or a Python-style list or int (e.g. 
[2, 3, 4]", + ) + parser.add_argument( + "--dataset", + type=Path, + default="datasets/cmu_extended", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/aachen_extended", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=10, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() + + if args.slice == "*": + slices = TEST_SLICES + if "-" in args.slices: + min_, max_ = args.slices.split("-") + slices = list(range(int(min_), int(max_) + 1)) + else: + slices = eval(args.slices) + if isinstance(slices, int): + slices = [slices] + + for slice_ in slices: + logger.info("Working on slice %s.", slice_) + run_slice( + f"slice{slice_}", args.dataset, args.outputs, args.num_covis, args.num_loc + ) diff --git a/imcui/hloc/pipelines/Cambridge/README.md b/imcui/hloc/pipelines/Cambridge/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d5ae07b71c48a98fa9235f0dfb0234c3c18c74c6 --- /dev/null +++ b/imcui/hloc/pipelines/Cambridge/README.md @@ -0,0 +1,47 @@ +# Cambridge Landmarks dataset + +## Installation + +Download the dataset from the [PoseNet project page](http://mi.eng.cam.ac.uk/projects/relocalisation/): +```bash +export dataset=datasets/cambridge +export scenes=( "KingsCollege" "OldHospital" "StMarysChurch" "ShopFacade" "GreatCourt" ) +export IDs=( "251342" "251340" "251294" "251336" "251291" ) +for i in "${!scenes[@]}"; do +wget https://www.repository.cam.ac.uk/bitstream/handle/1810/${IDs[i]}/${scenes[i]}.zip -P $dataset \ +&& unzip $dataset/${scenes[i]}.zip -d $dataset && rm $dataset/${scenes[i]}.zip; done +``` + +Download the SIFT SfM models, courtesy of Torsten Sattler: +```bash +export fileid=1esqzZ1zEQlzZVic-H32V6kkZvc4NeS15 +export filename=$dataset/CambridgeLandmarks_Colmap_Retriangulated_1024px.zip +wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$fileid" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$fileid" -O $filename && rm -rf /tmp/cookies.txt +unzip $filename -d $dataset +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.Cambridge.pipeline +``` + +## Results +We report the median error in translation/rotation in cm/deg over all scenes: +| Method \ Scene | Court | King's | Hospital | Shop | St. 
Mary's | +| ------------------------ | --------------- | --------------- | --------------- | -------------- | -------------- | +| Active Search | 24/0.13 | 13/0.22 | 20/0.36 | **4**/0.21 | 8/0.25 | +| DSAC* | 49/0.3 | 15/0.3 | 21/0.4 | 5/0.3 | 13/0.4 | +| **SuperPoint+SuperGlue** | **17**/**0.11** | **12**/**0.21** | **14**/**0.30** | **4**/**0.19** | **7**/**0.22** | + +## Citation + +Please cite the following paper if you use the Cambridge Landmarks dataset: +``` +@inproceedings{kendall2015posenet, + title={{PoseNet}: A convolutional network for real-time {6-DoF} camera relocalization}, + author={Kendall, Alex and Grimes, Matthew and Cipolla, Roberto}, + booktitle={ICCV}, + year={2015} +} +``` diff --git a/imcui/hloc/pipelines/Cambridge/__init__.py b/imcui/hloc/pipelines/Cambridge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/hloc/pipelines/Cambridge/pipeline.py b/imcui/hloc/pipelines/Cambridge/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..3a676e5af411858f2459cb7f58f777b30be67d29 --- /dev/null +++ b/imcui/hloc/pipelines/Cambridge/pipeline.py @@ -0,0 +1,140 @@ +import argparse +from pathlib import Path + +from ... import ( + extract_features, + localize_sfm, + logger, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) +from .utils import create_query_list_with_intrinsics, evaluate, scale_sfm_images + +SCENES = ["KingsCollege", "OldHospital", "ShopFacade", "StMarysChurch", "GreatCourt"] + + +def run_scene(images, gt_dir, outputs, results, num_covis, num_loc): + ref_sfm_sift = gt_dir / "model_train" + test_list = gt_dir / "list_query.txt" + + outputs.mkdir(exist_ok=True, parents=True) + ref_sfm = outputs / "sfm_superpoint+superglue" + ref_sfm_scaled = outputs / "sfm_sift_scaled" + query_list = outputs / "query_list_with_intrinsics.txt" + sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt" + loc_pairs = outputs / f"pairs-query-netvlad{num_loc}.txt" + + feature_conf = { + "output": "feats-superpoint-n4096-r1024", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + }, + } + matcher_conf = match_features.confs["superglue"] + retrieval_conf = extract_features.confs["netvlad"] + + create_query_list_with_intrinsics( + gt_dir / "empty_all", query_list, test_list, ext=".txt", image_dir=images + ) + with open(test_list, "r") as f: + query_seqs = {q.split("/")[0] for q in f.read().rstrip().split("\n")} + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + num_loc, + db_model=ref_sfm_sift, + query_prefix=query_seqs, + ) + + features = extract_features.main(feature_conf, images, outputs, as_half=True) + pairs_from_covisibility.main(ref_sfm_sift, sfm_pairs, num_matched=num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + + scale_sfm_images(ref_sfm_sift, ref_sfm_scaled, images) + triangulation.main( + ref_sfm, ref_sfm_scaled, images, sfm_pairs, features, sfm_matches + ) + + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + ref_sfm, + query_list, + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + prepend_camera_name=True, + ) + + +if __name__ == "__main__": + parser = 
argparse.ArgumentParser() + parser.add_argument("--scenes", default=SCENES, choices=SCENES, nargs="+") + parser.add_argument("--overwrite", action="store_true") + parser.add_argument( + "--dataset", + type=Path, + default="datasets/cambridge", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/cambridge", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=10, + help="Number of image pairs for loc, default: %(default)s", + ) + args = parser.parse_args() + + gt_dirs = args.dataset / "CambridgeLandmarks_Colmap_Retriangulated_1024px" + + all_results = {} + for scene in args.scenes: + logger.info(f'Working on scene "{scene}".') + results = args.outputs / scene / "results.txt" + if args.overwrite or not results.exists(): + run_scene( + args.dataset / scene, + gt_dirs / scene, + args.outputs / scene, + results, + args.num_covis, + args.num_loc, + ) + all_results[scene] = results + + for scene in args.scenes: + logger.info(f'Evaluate scene "{scene}".') + evaluate( + gt_dirs / scene / "empty_all", + all_results[scene], + gt_dirs / scene / "list_query.txt", + ext=".txt", + ) diff --git a/imcui/hloc/pipelines/Cambridge/utils.py b/imcui/hloc/pipelines/Cambridge/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..36460f067369065668837fa317b1c0f7047e9203 --- /dev/null +++ b/imcui/hloc/pipelines/Cambridge/utils.py @@ -0,0 +1,145 @@ +import logging + +import cv2 +import numpy as np + +from hloc.utils.read_write_model import ( + qvec2rotmat, + read_cameras_binary, + read_cameras_text, + read_images_binary, + read_images_text, + read_model, + write_model, +) + +logger = logging.getLogger(__name__) + + +def scale_sfm_images(full_model, scaled_model, image_dir): + """Duplicate the provided model and scale the camera intrinsics so that + they match the original image resolution - makes everything easier. 
+ """ + logger.info("Scaling the COLMAP model to the original image size.") + scaled_model.mkdir(exist_ok=True) + cameras, images, points3D = read_model(full_model) + + scaled_cameras = {} + for id_, image in images.items(): + name = image.name + img = cv2.imread(str(image_dir / name)) + assert img is not None, image_dir / name + h, w = img.shape[:2] + + cam_id = image.camera_id + if cam_id in scaled_cameras: + assert scaled_cameras[cam_id].width == w + assert scaled_cameras[cam_id].height == h + continue + + camera = cameras[cam_id] + assert camera.model == "SIMPLE_RADIAL" + sx = w / camera.width + sy = h / camera.height + assert sx == sy, (sx, sy) + scaled_cameras[cam_id] = camera._replace( + width=w, height=h, params=camera.params * np.array([sx, sx, sy, 1.0]) + ) + + write_model(scaled_cameras, images, points3D, scaled_model) + + +def create_query_list_with_intrinsics( + model, out, list_file=None, ext=".bin", image_dir=None +): + """Create a list of query images with intrinsics from the colmap model.""" + if ext == ".bin": + images = read_images_binary(model / "images.bin") + cameras = read_cameras_binary(model / "cameras.bin") + else: + images = read_images_text(model / "images.txt") + cameras = read_cameras_text(model / "cameras.txt") + + name2id = {image.name: i for i, image in images.items()} + if list_file is None: + names = list(name2id) + else: + with open(list_file, "r") as f: + names = f.read().rstrip().split("\n") + data = [] + for name in names: + image = images[name2id[name]] + camera = cameras[image.camera_id] + w, h, params = camera.width, camera.height, camera.params + + if image_dir is not None: + # Check the original image size and rescale the camera intrinsics + img = cv2.imread(str(image_dir / name)) + assert img is not None, image_dir / name + h_orig, w_orig = img.shape[:2] + assert camera.model == "SIMPLE_RADIAL" + sx = w_orig / w + sy = h_orig / h + assert sx == sy, (sx, sy) + w, h = w_orig, h_orig + params = params * np.array([sx, sx, sy, 1.0]) + + p = [name, camera.model, w, h] + params.tolist() + data.append(" ".join(map(str, p))) + with open(out, "w") as f: + f.write("\n".join(data)) + + +def evaluate(model, results, list_file=None, ext=".bin", only_localized=False): + predictions = {} + with open(results, "r") as f: + for data in f.read().rstrip().split("\n"): + data = data.split() + name = data[0] + q, t = np.split(np.array(data[1:], float), [4]) + predictions[name] = (qvec2rotmat(q), t) + if ext == ".bin": + images = read_images_binary(model / "images.bin") + else: + images = read_images_text(model / "images.txt") + name2id = {image.name: i for i, image in images.items()} + + if list_file is None: + test_names = list(name2id) + else: + with open(list_file, "r") as f: + test_names = f.read().rstrip().split("\n") + + errors_t = [] + errors_R = [] + for name in test_names: + if name not in predictions: + if only_localized: + continue + e_t = np.inf + e_R = 180.0 + else: + image = images[name2id[name]] + R_gt, t_gt = image.qvec2rotmat(), image.tvec + R, t = predictions[name] + e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t, axis=0) + cos = np.clip((np.trace(np.dot(R_gt.T, R)) - 1) / 2, -1.0, 1.0) + e_R = np.rad2deg(np.abs(np.arccos(cos))) + errors_t.append(e_t) + errors_R.append(e_R) + + errors_t = np.array(errors_t) + errors_R = np.array(errors_R) + + med_t = np.median(errors_t) + med_R = np.median(errors_R) + out = f"Results for file {results.name}:" + out += f"\nMedian errors: {med_t:.3f}m, {med_R:.3f}deg" + + out += "\nPercentage of test images localized 
within:" + threshs_t = [0.01, 0.02, 0.03, 0.05, 0.25, 0.5, 5.0] + threshs_R = [1.0, 2.0, 3.0, 5.0, 2.0, 5.0, 10.0] + for th_t, th_R in zip(threshs_t, threshs_R): + ratio = np.mean((errors_t < th_t) & (errors_R < th_R)) + out += f"\n\t{th_t*100:.0f}cm, {th_R:.0f}deg : {ratio*100:.2f}%" + logger.info(out) diff --git a/imcui/hloc/pipelines/RobotCar/README.md b/imcui/hloc/pipelines/RobotCar/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9881d153d4930cf32b5481ecd4fa2c900fa58c8c --- /dev/null +++ b/imcui/hloc/pipelines/RobotCar/README.md @@ -0,0 +1,16 @@ +# RobotCar Seasons dataset + +## Installation + +Download the dataset from [visuallocalization.net](https://www.visuallocalization.net): +```bash +export dataset=datasets/robotcar +wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/RobotCar-Seasons/ -P $dataset +for condition in $dataset/images/*.zip; do unzip condition -d $dataset/images/; done +``` + +## Pipeline + +```bash +python3 -m hloc.pipelines.RobotCar.pipeline +``` diff --git a/imcui/hloc/pipelines/RobotCar/__init__.py b/imcui/hloc/pipelines/RobotCar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/hloc/pipelines/RobotCar/colmap_from_nvm.py b/imcui/hloc/pipelines/RobotCar/colmap_from_nvm.py new file mode 100644 index 0000000000000000000000000000000000000000..e90ed72b5391990d26961b5acfaaada6517ac191 --- /dev/null +++ b/imcui/hloc/pipelines/RobotCar/colmap_from_nvm.py @@ -0,0 +1,176 @@ +import argparse +import logging +import sqlite3 +from collections import defaultdict +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +from ...colmap_from_nvm import ( + camera_center_to_translation, + recover_database_images_and_ids, +) +from ...utils.read_write_model import ( + CAMERA_MODEL_IDS, + Camera, + Image, + Point3D, + write_model, +) + +logger = logging.getLogger(__name__) + + +def read_nvm_model(nvm_path, database_path, image_ids, camera_ids, skip_points=False): + # Extract the intrinsics from the db file instead of the NVM model + db = sqlite3.connect(str(database_path)) + ret = db.execute("SELECT camera_id, model, width, height, params FROM cameras;") + cameras = {} + for camera_id, camera_model, width, height, params in ret: + params = np.fromstring(params, dtype=np.double).reshape(-1) + camera_model = CAMERA_MODEL_IDS[camera_model] + assert len(params) == camera_model.num_params, ( + len(params), + camera_model.num_params, + ) + camera = Camera( + id=camera_id, + model=camera_model.model_name, + width=int(width), + height=int(height), + params=params, + ) + cameras[camera_id] = camera + + nvm_f = open(nvm_path, "r") + line = nvm_f.readline() + while line == "\n" or line.startswith("NVM_V3"): + line = nvm_f.readline() + num_images = int(line) + # assert num_images == len(cameras), (num_images, len(cameras)) + + logger.info(f"Reading {num_images} images...") + image_idx_to_db_image_id = [] + image_data = [] + i = 0 + while i < num_images: + line = nvm_f.readline() + if line == "\n": + continue + data = line.strip("\n").lstrip("./").split(" ") + image_data.append(data) + image_idx_to_db_image_id.append(image_ids[data[0]]) + i += 1 + + line = nvm_f.readline() + while line == "\n": + line = nvm_f.readline() + num_points = int(line) + + if skip_points: + logger.info(f"Skipping {num_points} points.") + num_points = 0 + else: + logger.info(f"Reading {num_points} points...") + points3D = {} + 
image_idx_to_keypoints = defaultdict(list) + i = 0 + pbar = tqdm(total=num_points, unit="pts") + while i < num_points: + line = nvm_f.readline() + if line == "\n": + continue + + data = line.strip("\n").split(" ") + x, y, z, r, g, b, num_observations = data[:7] + obs_image_ids, point2D_idxs = [], [] + for j in range(int(num_observations)): + s = 7 + 4 * j + img_index, kp_index, kx, ky = data[s : s + 4] + image_idx_to_keypoints[int(img_index)].append( + (int(kp_index), float(kx), float(ky), i) + ) + db_image_id = image_idx_to_db_image_id[int(img_index)] + obs_image_ids.append(db_image_id) + point2D_idxs.append(kp_index) + + point = Point3D( + id=i, + xyz=np.array([x, y, z], float), + rgb=np.array([r, g, b], int), + error=1.0, # fake + image_ids=np.array(obs_image_ids, int), + point2D_idxs=np.array(point2D_idxs, int), + ) + points3D[i] = point + + i += 1 + pbar.update(1) + pbar.close() + + logger.info("Parsing image data...") + images = {} + for i, data in enumerate(image_data): + # Skip the focal length. Skip the distortion and terminal 0. + name, _, qw, qx, qy, qz, cx, cy, cz, _, _ = data + qvec = np.array([qw, qx, qy, qz], float) + c = np.array([cx, cy, cz], float) + t = camera_center_to_translation(c, qvec) + + if i in image_idx_to_keypoints: + # NVM only stores triangulated 2D keypoints: add dummy ones + keypoints = image_idx_to_keypoints[i] + point2D_idxs = np.array([d[0] for d in keypoints]) + tri_xys = np.array([[x, y] for _, x, y, _ in keypoints]) + tri_ids = np.array([i for _, _, _, i in keypoints]) + + num_2Dpoints = max(point2D_idxs) + 1 + xys = np.zeros((num_2Dpoints, 2), float) + point3D_ids = np.full(num_2Dpoints, -1, int) + xys[point2D_idxs] = tri_xys + point3D_ids[point2D_idxs] = tri_ids + else: + xys = np.zeros((0, 2), float) + point3D_ids = np.full(0, -1, int) + + image_id = image_ids[name] + image = Image( + id=image_id, + qvec=qvec, + tvec=t, + camera_id=camera_ids[name], + name=name.replace("png", "jpg"), # some hack required for RobotCar + xys=xys, + point3D_ids=point3D_ids, + ) + images[image_id] = image + + return cameras, images, points3D + + +def main(nvm, database, output, skip_points=False): + assert nvm.exists(), nvm + assert database.exists(), database + + image_ids, camera_ids = recover_database_images_and_ids(database) + + logger.info("Reading the NVM model...") + model = read_nvm_model( + nvm, database, image_ids, camera_ids, skip_points=skip_points + ) + + logger.info("Writing the COLMAP model...") + output.mkdir(exist_ok=True, parents=True) + write_model(*model, path=str(output), ext=".bin") + logger.info("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--nvm", required=True, type=Path) + parser.add_argument("--database", required=True, type=Path) + parser.add_argument("--output", required=True, type=Path) + parser.add_argument("--skip_points", action="store_true") + args = parser.parse_args() + main(**args.__dict__) diff --git a/imcui/hloc/pipelines/RobotCar/pipeline.py b/imcui/hloc/pipelines/RobotCar/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..7a7ee314480d09b2200b9ff1992a3217e24bea2f --- /dev/null +++ b/imcui/hloc/pipelines/RobotCar/pipeline.py @@ -0,0 +1,143 @@ +import argparse +import glob +from pathlib import Path + +from ... import ( + extract_features, + localize_sfm, + match_features, + pairs_from_covisibility, + pairs_from_retrieval, + triangulation, +) +from . 
import colmap_from_nvm + +CONDITIONS = [ + "dawn", + "dusk", + "night", + "night-rain", + "overcast-summer", + "overcast-winter", + "rain", + "snow", + "sun", +] + + +def generate_query_list(dataset, image_dir, path): + h, w = 1024, 1024 + intrinsics_filename = "intrinsics/{}_intrinsics.txt" + cameras = {} + for side in ["left", "right", "rear"]: + with open(dataset / intrinsics_filename.format(side), "r") as f: + fx = f.readline().split()[1] + fy = f.readline().split()[1] + cx = f.readline().split()[1] + cy = f.readline().split()[1] + assert fx == fy + params = ["SIMPLE_RADIAL", w, h, fx, cx, cy, 0.0] + cameras[side] = [str(p) for p in params] + + queries = glob.glob((image_dir / "**/*.jpg").as_posix(), recursive=True) + queries = [ + Path(q).relative_to(image_dir.parents[0]).as_posix() for q in sorted(queries) + ] + + out = [[q] + cameras[Path(q).parent.name] for q in queries] + with open(path, "w") as f: + f.write("\n".join(map(" ".join, out))) + + +def run(args): + # Setup the paths + dataset = args.dataset + images = dataset / "images/" + + outputs = args.outputs # where everything will be saved + outputs.mkdir(exist_ok=True, parents=True) + query_list = outputs / "{condition}_queries_with_intrinsics.txt" + sift_sfm = outputs / "sfm_sift" + reference_sfm = outputs / "sfm_superpoint+superglue" + sfm_pairs = outputs / f"pairs-db-covis{args.num_covis}.txt" + loc_pairs = outputs / f"pairs-query-netvlad{args.num_loc}.txt" + results = outputs / f"RobotCar_hloc_superpoint+superglue_netvlad{args.num_loc}.txt" + + # pick one of the configurations for extraction and matching + retrieval_conf = extract_features.confs["netvlad"] + feature_conf = extract_features.confs["superpoint_aachen"] + matcher_conf = match_features.confs["superglue"] + + for condition in CONDITIONS: + generate_query_list( + dataset, images / condition, str(query_list).format(condition=condition) + ) + + features = extract_features.main(feature_conf, images, outputs, as_half=True) + + colmap_from_nvm.main( + dataset / "3D-models/all-merged/all.nvm", + dataset / "3D-models/overcast-reference.db", + sift_sfm, + ) + pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis) + sfm_matches = match_features.main( + matcher_conf, sfm_pairs, feature_conf["output"], outputs + ) + + triangulation.main( + reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches + ) + + global_descriptors = extract_features.main(retrieval_conf, images, outputs) + # TODO: do per location and per camera + pairs_from_retrieval.main( + global_descriptors, + loc_pairs, + args.num_loc, + query_prefix=CONDITIONS, + db_model=reference_sfm, + ) + loc_matches = match_features.main( + matcher_conf, loc_pairs, feature_conf["output"], outputs + ) + + localize_sfm.main( + reference_sfm, + Path(str(query_list).format(condition="*")), + loc_pairs, + features, + loc_matches, + results, + covisibility_clustering=False, + prepend_camera_name=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=Path, + default="datasets/robotcar", + help="Path to the dataset, default: %(default)s", + ) + parser.add_argument( + "--outputs", + type=Path, + default="outputs/robotcar", + help="Path to the output directory, default: %(default)s", + ) + parser.add_argument( + "--num_covis", + type=int, + default=20, + help="Number of image pairs for SfM, default: %(default)s", + ) + parser.add_argument( + "--num_loc", + type=int, + default=20, + help="Number of image pairs for loc, default: 
%(default)s", + ) + args = parser.parse_args() diff --git a/imcui/hloc/pipelines/__init__.py b/imcui/hloc/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/hloc/reconstruction.py b/imcui/hloc/reconstruction.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1e7fc09c52cca2935c217e912bb077fe712e05 --- /dev/null +++ b/imcui/hloc/reconstruction.py @@ -0,0 +1,194 @@ +import argparse +import multiprocessing +import shutil +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pycolmap + +from . import logger +from .triangulation import ( + OutputCapture, + estimation_and_geometric_verification, + import_features, + import_matches, + parse_option_args, +) +from .utils.database import COLMAPDatabase + + +def create_empty_db(database_path: Path): + if database_path.exists(): + logger.warning("The database already exists, deleting it.") + database_path.unlink() + logger.info("Creating an empty database...") + db = COLMAPDatabase.connect(database_path) + db.create_tables() + db.commit() + db.close() + + +def import_images( + image_dir: Path, + database_path: Path, + camera_mode: pycolmap.CameraMode, + image_list: Optional[List[str]] = None, + options: Optional[Dict[str, Any]] = None, +): + logger.info("Importing images into the database...") + if options is None: + options = {} + images = list(image_dir.iterdir()) + if len(images) == 0: + raise IOError(f"No images found in {image_dir}.") + with pycolmap.ostream(): + pycolmap.import_images( + database_path, + image_dir, + camera_mode, + image_list=image_list or [], + options=options, + ) + + +def get_image_ids(database_path: Path) -> Dict[str, int]: + db = COLMAPDatabase.connect(database_path) + images = {} + for name, image_id in db.execute("SELECT name, image_id FROM images;"): + images[name] = image_id + db.close() + return images + + +def run_reconstruction( + sfm_dir: Path, + database_path: Path, + image_dir: Path, + verbose: bool = False, + options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + models_path = sfm_dir / "models" + models_path.mkdir(exist_ok=True, parents=True) + logger.info("Running 3D reconstruction...") + if options is None: + options = {} + options = {"num_threads": min(multiprocessing.cpu_count(), 16), **options} + with OutputCapture(verbose): + with pycolmap.ostream(): + reconstructions = pycolmap.incremental_mapping( + database_path, image_dir, models_path, options=options + ) + + if len(reconstructions) == 0: + logger.error("Could not reconstruct any model!") + return None + logger.info(f"Reconstructed {len(reconstructions)} model(s).") + + largest_index = None + largest_num_images = 0 + for index, rec in reconstructions.items(): + num_images = rec.num_reg_images() + if num_images > largest_num_images: + largest_index = index + largest_num_images = num_images + assert largest_index is not None + logger.info( + f"Largest model is #{largest_index} " f"with {largest_num_images} images." 
+ ) + + for filename in ["images.bin", "cameras.bin", "points3D.bin"]: + if (sfm_dir / filename).exists(): + (sfm_dir / filename).unlink() + shutil.move(str(models_path / str(largest_index) / filename), str(sfm_dir)) + return reconstructions[largest_index] + + +def main( + sfm_dir: Path, + image_dir: Path, + pairs: Path, + features: Path, + matches: Path, + camera_mode: pycolmap.CameraMode = pycolmap.CameraMode.AUTO, + verbose: bool = False, + skip_geometric_verification: bool = False, + min_match_score: Optional[float] = None, + image_list: Optional[List[str]] = None, + image_options: Optional[Dict[str, Any]] = None, + mapper_options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + assert features.exists(), features + assert pairs.exists(), pairs + assert matches.exists(), matches + + sfm_dir.mkdir(parents=True, exist_ok=True) + database = sfm_dir / "database.db" + + create_empty_db(database) + import_images(image_dir, database, camera_mode, image_list, image_options) + image_ids = get_image_ids(database) + import_features(image_ids, database, features) + import_matches( + image_ids, + database, + pairs, + matches, + min_match_score, + skip_geometric_verification, + ) + if not skip_geometric_verification: + estimation_and_geometric_verification(database, pairs, verbose) + reconstruction = run_reconstruction( + sfm_dir, database, image_dir, verbose, mapper_options + ) + if reconstruction is not None: + logger.info( + f"Reconstruction statistics:\n{reconstruction.summary()}" + + f"\n\tnum_input_images = {len(image_ids)}" + ) + return reconstruction + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--sfm_dir", type=Path, required=True) + parser.add_argument("--image_dir", type=Path, required=True) + + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + + parser.add_argument( + "--camera_mode", + type=str, + default="AUTO", + choices=list(pycolmap.CameraMode.__members__.keys()), + ) + parser.add_argument("--skip_geometric_verification", action="store_true") + parser.add_argument("--min_match_score", type=float) + parser.add_argument("--verbose", action="store_true") + + parser.add_argument( + "--image_options", + nargs="+", + default=[], + help="List of key=value from {}".format(pycolmap.ImageReaderOptions().todict()), + ) + parser.add_argument( + "--mapper_options", + nargs="+", + default=[], + help="List of key=value from {}".format( + pycolmap.IncrementalMapperOptions().todict() + ), + ) + args = parser.parse_args().__dict__ + + image_options = parse_option_args( + args.pop("image_options"), pycolmap.ImageReaderOptions() + ) + mapper_options = parse_option_args( + args.pop("mapper_options"), pycolmap.IncrementalMapperOptions() + ) + + main(**args, image_options=image_options, mapper_options=mapper_options) diff --git a/imcui/hloc/triangulation.py b/imcui/hloc/triangulation.py new file mode 100644 index 0000000000000000000000000000000000000000..83203c38f4e4a2493b8b1b11773fb2140d76b8bc --- /dev/null +++ b/imcui/hloc/triangulation.py @@ -0,0 +1,311 @@ +import argparse +import contextlib +import io +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +import pycolmap +from tqdm import tqdm + +from . 
import logger +from .utils.database import COLMAPDatabase +from .utils.geometry import compute_epipolar_errors +from .utils.io import get_keypoints, get_matches +from .utils.parsers import parse_retrieval + + +class OutputCapture: + def __init__(self, verbose: bool): + self.verbose = verbose + + def __enter__(self): + if not self.verbose: + self.capture = contextlib.redirect_stdout(io.StringIO()) + self.out = self.capture.__enter__() + + def __exit__(self, exc_type, *args): + if not self.verbose: + self.capture.__exit__(exc_type, *args) + if exc_type is not None: + logger.error("Failed with output:\n%s", self.out.getvalue()) + sys.stdout.flush() + + +def create_db_from_model( + reconstruction: pycolmap.Reconstruction, database_path: Path +) -> Dict[str, int]: + if database_path.exists(): + logger.warning("The database already exists, deleting it.") + database_path.unlink() + + db = COLMAPDatabase.connect(database_path) + db.create_tables() + + for i, camera in reconstruction.cameras.items(): + db.add_camera( + camera.model.value, + camera.width, + camera.height, + camera.params, + camera_id=i, + prior_focal_length=True, + ) + + for i, image in reconstruction.images.items(): + db.add_image(image.name, image.camera_id, image_id=i) + + db.commit() + db.close() + return {image.name: i for i, image in reconstruction.images.items()} + + +def import_features( + image_ids: Dict[str, int], database_path: Path, features_path: Path +): + logger.info("Importing features into the database...") + db = COLMAPDatabase.connect(database_path) + + for image_name, image_id in tqdm(image_ids.items()): + keypoints = get_keypoints(features_path, image_name) + keypoints += 0.5 # COLMAP origin + db.add_keypoints(image_id, keypoints) + + db.commit() + db.close() + + +def import_matches( + image_ids: Dict[str, int], + database_path: Path, + pairs_path: Path, + matches_path: Path, + min_match_score: Optional[float] = None, + skip_geometric_verification: bool = False, +): + logger.info("Importing matches into the database...") + + with open(str(pairs_path), "r") as f: + pairs = [p.split() for p in f.readlines()] + + db = COLMAPDatabase.connect(database_path) + + matched = set() + for name0, name1 in tqdm(pairs): + id0, id1 = image_ids[name0], image_ids[name1] + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matches, scores = get_matches(matches_path, name0, name1) + if min_match_score: + matches = matches[scores > min_match_score] + db.add_matches(id0, id1, matches) + matched |= {(id0, id1), (id1, id0)} + + if skip_geometric_verification: + db.add_two_view_geometry(id0, id1, matches) + + db.commit() + db.close() + + +def estimation_and_geometric_verification( + database_path: Path, pairs_path: Path, verbose: bool = False +): + logger.info("Performing geometric verification of the matches...") + with OutputCapture(verbose): + with pycolmap.ostream(): + pycolmap.verify_matches( + database_path, + pairs_path, + options=dict(ransac=dict(max_num_trials=20000, min_inlier_ratio=0.1)), + ) + + +def geometric_verification( + image_ids: Dict[str, int], + reference: pycolmap.Reconstruction, + database_path: Path, + features_path: Path, + pairs_path: Path, + matches_path: Path, + max_error: float = 4.0, +): + logger.info("Performing geometric verification of the matches...") + + pairs = parse_retrieval(pairs_path) + db = COLMAPDatabase.connect(database_path) + + inlier_ratios = [] + matched = set() + for name0 in tqdm(pairs): + id0 = image_ids[name0] + image0 = reference.images[id0] + cam0 = 
reference.cameras[image0.camera_id] + kps0, noise0 = get_keypoints(features_path, name0, return_uncertainty=True) + noise0 = 1.0 if noise0 is None else noise0 + if len(kps0) > 0: + kps0 = np.stack(cam0.cam_from_img(kps0)) + else: + kps0 = np.zeros((0, 2)) + + for name1 in pairs[name0]: + id1 = image_ids[name1] + image1 = reference.images[id1] + cam1 = reference.cameras[image1.camera_id] + kps1, noise1 = get_keypoints(features_path, name1, return_uncertainty=True) + noise1 = 1.0 if noise1 is None else noise1 + if len(kps1) > 0: + kps1 = np.stack(cam1.cam_from_img(kps1)) + else: + kps1 = np.zeros((0, 2)) + + matches = get_matches(matches_path, name0, name1)[0] + + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matched |= {(id0, id1), (id1, id0)} + + if matches.shape[0] == 0: + db.add_two_view_geometry(id0, id1, matches) + continue + + cam1_from_cam0 = image1.cam_from_world * image0.cam_from_world.inverse() + errors0, errors1 = compute_epipolar_errors( + cam1_from_cam0, kps0[matches[:, 0]], kps1[matches[:, 1]] + ) + valid_matches = np.logical_and( + errors0 <= cam0.cam_from_img_threshold(noise0 * max_error), + errors1 <= cam1.cam_from_img_threshold(noise1 * max_error), + ) + # TODO: We could also add E to the database, but we need + # to reverse the transformations if id0 > id1 in utils/database.py. + db.add_two_view_geometry(id0, id1, matches[valid_matches, :]) + inlier_ratios.append(np.mean(valid_matches)) + logger.info( + "mean/med/min/max valid matches %.2f/%.2f/%.2f/%.2f%%.", + np.mean(inlier_ratios) * 100, + np.median(inlier_ratios) * 100, + np.min(inlier_ratios) * 100, + np.max(inlier_ratios) * 100, + ) + + db.commit() + db.close() + + +def run_triangulation( + model_path: Path, + database_path: Path, + image_dir: Path, + reference_model: pycolmap.Reconstruction, + verbose: bool = False, + options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + model_path.mkdir(parents=True, exist_ok=True) + logger.info("Running 3D triangulation...") + if options is None: + options = {} + with OutputCapture(verbose): + with pycolmap.ostream(): + reconstruction = pycolmap.triangulate_points( + reference_model, + database_path, + image_dir, + model_path, + options=options, + ) + return reconstruction + + +def main( + sfm_dir: Path, + reference_model: Path, + image_dir: Path, + pairs: Path, + features: Path, + matches: Path, + skip_geometric_verification: bool = False, + estimate_two_view_geometries: bool = False, + min_match_score: Optional[float] = None, + verbose: bool = False, + mapper_options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + assert reference_model.exists(), reference_model + assert features.exists(), features + assert pairs.exists(), pairs + assert matches.exists(), matches + + sfm_dir.mkdir(parents=True, exist_ok=True) + database = sfm_dir / "database.db" + reference = pycolmap.Reconstruction(reference_model) + + image_ids = create_db_from_model(reference, database) + import_features(image_ids, database, features) + import_matches( + image_ids, + database, + pairs, + matches, + min_match_score, + skip_geometric_verification, + ) + if not skip_geometric_verification: + if estimate_two_view_geometries: + estimation_and_geometric_verification(database, pairs, verbose) + else: + geometric_verification( + image_ids, reference, database, features, pairs, matches + ) + reconstruction = run_triangulation( + sfm_dir, database, image_dir, reference, verbose, mapper_options + ) + logger.info( + "Finished the triangulation with statistics:\n%s", 
+ reconstruction.summary(), + ) + return reconstruction + + +def parse_option_args(args: List[str], default_options) -> Dict[str, Any]: + options = {} + for arg in args: + idx = arg.find("=") + if idx == -1: + raise ValueError("Options format: key1=value1 key2=value2 etc.") + key, value = arg[:idx], arg[idx + 1 :] + if not hasattr(default_options, key): + raise ValueError( + f'Unknown option "{key}", allowed options and default values' + f" for {default_options.summary()}" + ) + value = eval(value) + target_type = type(getattr(default_options, key)) + if not isinstance(value, target_type): + raise ValueError( + f'Incorrect type for option "{key}":' f" {type(value)} vs {target_type}" + ) + options[key] = value + return options + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--sfm_dir", type=Path, required=True) + parser.add_argument("--reference_sfm_model", type=Path, required=True) + parser.add_argument("--image_dir", type=Path, required=True) + + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + + parser.add_argument("--skip_geometric_verification", action="store_true") + parser.add_argument("--min_match_score", type=float) + parser.add_argument("--verbose", action="store_true") + + parser.add_argument( + "--mapper_options", + nargs="+", + default=[], + help="List of key=value from {}".format( + pycolmap.IncrementalMapperOptions().todict() + ), + ) + args = parser.parse_args().__dict__ + + # main() expects "reference_model"; remap the --reference_sfm_model flag. + args["reference_model"] = args.pop("reference_sfm_model") + + mapper_options = parse_option_args( + args.pop("mapper_options"), pycolmap.IncrementalMapperOptions() + ) + + main(**args, mapper_options=mapper_options) diff --git a/imcui/hloc/utils/__init__.py b/imcui/hloc/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b030ce404e986f2dcf81cf39640cb8e841e87a --- /dev/null +++ b/imcui/hloc/utils/__init__.py @@ -0,0 +1,12 @@ +import os +import sys +from .. 
import logger + + +def do_system(cmd, verbose=False): + if verbose: + logger.info(f"Run cmd: `{cmd}`.") + err = os.system(cmd) + if err: + logger.info("Run cmd err.") + sys.exit(err) diff --git a/imcui/hloc/utils/base_model.py b/imcui/hloc/utils/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e6cf3971f8ea8bc6c4bf6081f82c4fd9cc4c22b6 --- /dev/null +++ b/imcui/hloc/utils/base_model.py @@ -0,0 +1,56 @@ +import sys +from abc import ABCMeta, abstractmethod +from torch import nn +from copy import copy +import inspect +from huggingface_hub import hf_hub_download + + +class BaseModel(nn.Module, metaclass=ABCMeta): + default_conf = {} + required_inputs = [] + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + self.conf = conf = {**self.default_conf, **conf} + self.required_inputs = copy(self.required_inputs) + self._init(conf) + sys.stdout.flush() + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + for key in self.required_inputs: + assert key in data, "Missing key {} in data".format(key) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + def _download_model(self, repo_id=None, filename=None, **kwargs): + """Download model from hf hub and return the path.""" + return hf_hub_download( + repo_type="model", + repo_id=repo_id, + filename=filename, + ) + + +def dynamic_load(root, model): + module_path = f"{root.__name__}.{model}" + module = __import__(module_path, fromlist=[""]) + classes = inspect.getmembers(module, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == module_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseModel)] + assert len(classes) == 1, classes + return classes[0][1] + # return getattr(module, 'Model') diff --git a/imcui/hloc/utils/database.py b/imcui/hloc/utils/database.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e5e0b7342677757f4b654c1aaeaa76cfe68187 --- /dev/null +++ b/imcui/hloc/utils/database.py @@ -0,0 +1,412 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +# This script is based on an original implementation by True Price. + +import sqlite3 +import sys + +import numpy as np + +IS_PYTHON3 = sys.version_info[0] >= 3 + +MAX_IMAGE_ID = 2**31 - 1 + +CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( + camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + model INTEGER NOT NULL, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + params BLOB, + prior_focal_length INTEGER NOT NULL)""" + +CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" + +CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( + image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + name TEXT NOT NULL UNIQUE, + camera_id INTEGER NOT NULL, + prior_qw REAL, + prior_qx REAL, + prior_qy REAL, + prior_qz REAL, + prior_tx REAL, + prior_ty REAL, + prior_tz REAL, + CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}), + FOREIGN KEY(camera_id) REFERENCES cameras(camera_id)) +""".format(MAX_IMAGE_ID) + +CREATE_TWO_VIEW_GEOMETRIES_TABLE = """ +CREATE TABLE IF NOT EXISTS two_view_geometries ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + config INTEGER NOT NULL, + F BLOB, + E BLOB, + H BLOB, + qvec BLOB, + tvec BLOB) +""" + +CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE) +""" + +CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB)""" + +CREATE_NAME_INDEX = "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" + +CREATE_ALL = "; ".join( + [ + CREATE_CAMERAS_TABLE, + CREATE_IMAGES_TABLE, + CREATE_KEYPOINTS_TABLE, + CREATE_DESCRIPTORS_TABLE, + CREATE_MATCHES_TABLE, + CREATE_TWO_VIEW_GEOMETRIES_TABLE, + CREATE_NAME_INDEX, + ] +) + + +def image_ids_to_pair_id(image_id1, image_id2): + if image_id1 > image_id2: + image_id1, image_id2 = image_id2, image_id1 + return image_id1 * MAX_IMAGE_ID + image_id2 + + +def pair_id_to_image_ids(pair_id): + image_id2 = pair_id % MAX_IMAGE_ID + image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID + return image_id1, image_id2 + + +def array_to_blob(array): + if IS_PYTHON3: + return array.tobytes() + else: + return np.getbuffer(array) + + +def blob_to_array(blob, dtype, shape=(-1,)): + if IS_PYTHON3: + return np.fromstring(blob, dtype=dtype).reshape(*shape) + else: + return np.frombuffer(blob, dtype=dtype).reshape(*shape) + + +class COLMAPDatabase(sqlite3.Connection): + @staticmethod + def connect(database_path): + return sqlite3.connect(str(database_path), factory=COLMAPDatabase) + 
+ def __init__(self, *args, **kwargs): + super(COLMAPDatabase, self).__init__(*args, **kwargs) + + self.create_tables = lambda: self.executescript(CREATE_ALL) + self.create_cameras_table = lambda: self.executescript(CREATE_CAMERAS_TABLE) + self.create_descriptors_table = lambda: self.executescript( + CREATE_DESCRIPTORS_TABLE + ) + self.create_images_table = lambda: self.executescript(CREATE_IMAGES_TABLE) + self.create_two_view_geometries_table = lambda: self.executescript( + CREATE_TWO_VIEW_GEOMETRIES_TABLE + ) + self.create_keypoints_table = lambda: self.executescript(CREATE_KEYPOINTS_TABLE) + self.create_matches_table = lambda: self.executescript(CREATE_MATCHES_TABLE) + self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) + + def add_camera( + self, model, width, height, params, prior_focal_length=False, camera_id=None + ): + params = np.asarray(params, np.float64) + cursor = self.execute( + "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", + ( + camera_id, + model, + width, + height, + array_to_blob(params), + prior_focal_length, + ), + ) + return cursor.lastrowid + + def add_image( + self, + name, + camera_id, + prior_q=np.full(4, np.NaN), + prior_t=np.full(3, np.NaN), + image_id=None, + ): + cursor = self.execute( + "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + image_id, + name, + camera_id, + prior_q[0], + prior_q[1], + prior_q[2], + prior_q[3], + prior_t[0], + prior_t[1], + prior_t[2], + ), + ) + return cursor.lastrowid + + def add_keypoints(self, image_id, keypoints): + assert len(keypoints.shape) == 2 + assert keypoints.shape[1] in [2, 4, 6] + + keypoints = np.asarray(keypoints, np.float32) + self.execute( + "INSERT INTO keypoints VALUES (?, ?, ?, ?)", + (image_id,) + keypoints.shape + (array_to_blob(keypoints),), + ) + + def add_descriptors(self, image_id, descriptors): + descriptors = np.ascontiguousarray(descriptors, np.uint8) + self.execute( + "INSERT INTO descriptors VALUES (?, ?, ?, ?)", + (image_id,) + descriptors.shape + (array_to_blob(descriptors),), + ) + + def add_matches(self, image_id1, image_id2, matches): + assert len(matches.shape) == 2 + assert matches.shape[1] == 2 + + if image_id1 > image_id2: + matches = matches[:, ::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + self.execute( + "INSERT INTO matches VALUES (?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches),), + ) + + def add_two_view_geometry( + self, + image_id1, + image_id2, + matches, + F=np.eye(3), + E=np.eye(3), + H=np.eye(3), + qvec=np.array([1.0, 0.0, 0.0, 0.0]), + tvec=np.zeros(3), + config=2, + ): + assert len(matches.shape) == 2 + assert matches.shape[1] == 2 + + if image_id1 > image_id2: + matches = matches[:, ::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + F = np.asarray(F, dtype=np.float64) + E = np.asarray(E, dtype=np.float64) + H = np.asarray(H, dtype=np.float64) + qvec = np.asarray(qvec, dtype=np.float64) + tvec = np.asarray(tvec, dtype=np.float64) + self.execute( + "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (pair_id,) + + matches.shape + + ( + array_to_blob(matches), + config, + array_to_blob(F), + array_to_blob(E), + array_to_blob(H), + array_to_blob(qvec), + array_to_blob(tvec), + ), + ) + + +def example_usage(): + import os + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--database_path", default="database.db") + args = parser.parse_args() + + if 
os.path.exists(args.database_path): + print("ERROR: database path already exists -- will not modify it.") + return + + # Open the database. + + db = COLMAPDatabase.connect(args.database_path) + + # For convenience, try creating all the tables upfront. + + db.create_tables() + + # Create dummy cameras. + + model1, width1, height1, params1 = ( + 0, + 1024, + 768, + np.array((1024.0, 512.0, 384.0)), + ) + model2, width2, height2, params2 = ( + 2, + 1024, + 768, + np.array((1024.0, 512.0, 384.0, 0.1)), + ) + + camera_id1 = db.add_camera(model1, width1, height1, params1) + camera_id2 = db.add_camera(model2, width2, height2, params2) + + # Create dummy images. + + image_id1 = db.add_image("image1.png", camera_id1) + image_id2 = db.add_image("image2.png", camera_id1) + image_id3 = db.add_image("image3.png", camera_id2) + image_id4 = db.add_image("image4.png", camera_id2) + + # Create dummy keypoints. + # + # Note that COLMAP supports: + # - 2D keypoints: (x, y) + # - 4D keypoints: (x, y, theta, scale) + # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) + + num_keypoints = 1000 + keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2) + keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2) + + db.add_keypoints(image_id1, keypoints1) + db.add_keypoints(image_id2, keypoints2) + db.add_keypoints(image_id3, keypoints3) + db.add_keypoints(image_id4, keypoints4) + + # Create dummy matches. + + M = 50 + matches12 = np.random.randint(num_keypoints, size=(M, 2)) + matches23 = np.random.randint(num_keypoints, size=(M, 2)) + matches34 = np.random.randint(num_keypoints, size=(M, 2)) + + db.add_matches(image_id1, image_id2, matches12) + db.add_matches(image_id2, image_id3, matches23) + db.add_matches(image_id3, image_id4, matches34) + + # Commit the data to the file. + + db.commit() + + # Read and check cameras. + + rows = db.execute("SELECT * FROM cameras") + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id1 + assert model == model1 and width == width1 and height == height1 + assert np.allclose(params, params1) + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id2 + assert model == model2 and width == width2 and height == height2 + assert np.allclose(params, params2) + + # Read and check keypoints. + + keypoints = dict( + (image_id, blob_to_array(data, np.float32, (-1, 2))) + for image_id, data in db.execute("SELECT image_id, data FROM keypoints") + ) + + assert np.allclose(keypoints[image_id1], keypoints1) + assert np.allclose(keypoints[image_id2], keypoints2) + assert np.allclose(keypoints[image_id3], keypoints3) + assert np.allclose(keypoints[image_id4], keypoints4) + + # Read and check matches. + + pair_ids = [ # noqa: F841 + image_ids_to_pair_id(*pair) + for pair in ( + (image_id1, image_id2), + (image_id2, image_id3), + (image_id3, image_id4), + ) + ] + + matches = dict( + (pair_id_to_image_ids(pair_id), blob_to_array(data, np.uint32, (-1, 2))) + for pair_id, data in db.execute("SELECT pair_id, data FROM matches") + ) + + assert np.all(matches[(image_id1, image_id2)] == matches12) + assert np.all(matches[(image_id2, image_id3)] == matches23) + assert np.all(matches[(image_id3, image_id4)] == matches34) + + # Clean up. 
+ + db.close() + + if os.path.exists(args.database_path): + os.remove(args.database_path) + + +if __name__ == "__main__": + example_usage() diff --git a/imcui/hloc/utils/geometry.py b/imcui/hloc/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..5995cccda40354f2346ad84fa966614b4ccbfed0 --- /dev/null +++ b/imcui/hloc/utils/geometry.py @@ -0,0 +1,16 @@ +import numpy as np +import pycolmap + + +def to_homogeneous(p): + return np.pad(p, ((0, 0),) * (p.ndim - 1) + ((0, 1),), constant_values=1) + + +def compute_epipolar_errors(j_from_i: pycolmap.Rigid3d, p2d_i, p2d_j): + j_E_i = j_from_i.essential_matrix() + l2d_j = to_homogeneous(p2d_i) @ j_E_i.T + l2d_i = to_homogeneous(p2d_j) @ j_E_i + dist = np.abs(np.sum(to_homogeneous(p2d_i) * l2d_i, axis=1)) + errors_i = dist / np.linalg.norm(l2d_i[:, :2], axis=1) + errors_j = dist / np.linalg.norm(l2d_j[:, :2], axis=1) + return errors_i, errors_j diff --git a/imcui/hloc/utils/io.py b/imcui/hloc/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..1cd55d4c30b41c3754634a164312dc5e8c294274 --- /dev/null +++ b/imcui/hloc/utils/io.py @@ -0,0 +1,77 @@ +from typing import Tuple +from pathlib import Path +import numpy as np +import cv2 +import h5py + +from .parsers import names_to_pair, names_to_pair_old + + +def read_image(path, grayscale=False): + if grayscale: + mode = cv2.IMREAD_GRAYSCALE + else: + mode = cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise ValueError(f"Cannot read image {path}.") + if not grayscale and len(image.shape) == 3: + image = image[:, :, ::-1] # BGR to RGB + return image + + +def list_h5_names(path): + names = [] + with h5py.File(str(path), "r", libver="latest") as fd: + + def visit_fn(_, obj): + if isinstance(obj, h5py.Dataset): + names.append(obj.parent.name.strip("/")) + + fd.visititems(visit_fn) + return list(set(names)) + + +def get_keypoints( + path: Path, name: str, return_uncertainty: bool = False +) -> np.ndarray: + with h5py.File(str(path), "r", libver="latest") as hfile: + dset = hfile[name]["keypoints"] + p = dset.__array__() + uncertainty = dset.attrs.get("uncertainty") + if return_uncertainty: + return p, uncertainty + return p + + +def find_pair(hfile: h5py.File, name0: str, name1: str): + pair = names_to_pair(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair(name1, name0) + if pair in hfile: + return pair, True + # older, less efficient format + pair = names_to_pair_old(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair_old(name1, name0) + if pair in hfile: + return pair, True + raise ValueError( + f"Could not find pair {(name0, name1)}... " + "Maybe you matched with a different list of pairs? 
" + ) + + +def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]: + with h5py.File(str(path), "r", libver="latest") as hfile: + pair, reverse = find_pair(hfile, name0, name1) + matches = hfile[pair]["matches0"].__array__() + scores = hfile[pair]["matching_scores0"].__array__() + idx = np.where(matches != -1)[0] + matches = np.stack([idx, matches[idx]], -1) + if reverse: + matches = np.flip(matches, -1) + scores = scores[idx] + return matches, scores diff --git a/imcui/hloc/utils/parsers.py b/imcui/hloc/utils/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..9407dcf916d67170e3f8d19581041a774a61e84b --- /dev/null +++ b/imcui/hloc/utils/parsers.py @@ -0,0 +1,59 @@ +import logging +from collections import defaultdict +from pathlib import Path + +import numpy as np +import pycolmap + +logger = logging.getLogger(__name__) + + +def parse_image_list(path, with_intrinsics=False): + images = [] + with open(path, "r") as f: + for line in f: + line = line.strip("\n") + if len(line) == 0 or line[0] == "#": + continue + name, *data = line.split() + if with_intrinsics: + model, width, height, *params = data + params = np.array(params, float) + cam = pycolmap.Camera( + model=model, width=int(width), height=int(height), params=params + ) + images.append((name, cam)) + else: + images.append(name) + + assert len(images) > 0 + logger.info(f"Imported {len(images)} images from {path.name}") + return images + + +def parse_image_lists(paths, with_intrinsics=False): + images = [] + files = list(Path(paths.parent).glob(paths.name)) + assert len(files) > 0 + for lfile in files: + images += parse_image_list(lfile, with_intrinsics=with_intrinsics) + return images + + +def parse_retrieval(path): + retrieval = defaultdict(list) + with open(path, "r") as f: + for p in f.read().rstrip("\n").split("\n"): + if len(p) == 0: + continue + q, r = p.split() + retrieval[q].append(r) + return dict(retrieval) + + +def names_to_pair(name0, name1, separator="/"): + return separator.join((name0.replace("/", "-"), name1.replace("/", "-"))) + + +def names_to_pair_old(name0, name1): + return names_to_pair(name0, name1, separator="_") diff --git a/imcui/hloc/utils/read_write_model.py b/imcui/hloc/utils/read_write_model.py new file mode 100644 index 0000000000000000000000000000000000000000..197921ded6d9cad3f365fd68a225822dc5411aee --- /dev/null +++ b/imcui/hloc/utils/read_write_model.py @@ -0,0 +1,588 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +import argparse +import collections +import logging +import os +import struct + +import numpy as np + +logger = logging.getLogger(__name__) + + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"] +) +Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] +) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] +) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), +} +CAMERA_MODEL_IDS = dict( + [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] +) +CAMERA_MODEL_NAMES = dict( + [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] +) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
+ should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera( + id=camera_id, model=model, width=width, height=height, params=params + ) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ" + ) + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes( + fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = ( + "# Camera list with one line of data per camera:\n" + + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + + "# Number of cameras: {}\n".format(len(cameras)) + ) + with open(path, "w") as fid: + fid.write(HEADER) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, model_id, cam.width, cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if 
len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack( + [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] + ) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi" + ) + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] + x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] + ) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def write_images_text(images, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum( + (len(img.point3D_ids) for _, img in images.items()) + ) / len(images) + HEADER = ( + "# Image list with two lines of data per image:\n" + + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + + "# Number of images: {}, mean observations per image: {}\n".format( + len(images), mean_observations + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, img in images.items(): + image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, 
img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def read_points3D_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd" + ) + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ + 0 + ] + track_elems = read_next_bytes( + fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * track_length, + ) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum( + (len(pt.image_ids) for _, pt in points3D.items()) + ) / len(points3D) + HEADER = ( + "# 3D point list with one line of data per point:\n" + + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" # noqa: E501 + + "# Number of points: {}, mean track length: {}\n".format( + len(points3D), mean_track_length + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3D_binary(points3D, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void 
Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def detect_model_format(path, ext): + if ( + os.path.isfile(os.path.join(path, "cameras" + ext)) + and os.path.isfile(os.path.join(path, "images" + ext)) + and os.path.isfile(os.path.join(path, "points3D" + ext)) + ): + return True + + return False + + +def read_model(path, ext=""): + # try to detect the extension automatically + if ext == "": + if detect_model_format(path, ".bin"): + ext = ".bin" + elif detect_model_format(path, ".txt"): + ext = ".txt" + else: + try: + cameras, images, points3D = read_model(os.path.join(path, "model/")) + logger.warning("This SfM file structure was deprecated in hloc v1.1") + return cameras, images, points3D + except FileNotFoundError: + raise FileNotFoundError( + f"Could not find binary or text COLMAP model at {path}" + ) + + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext=".bin"): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def main(): + parser = argparse.ArgumentParser( + description="Read and write COLMAP binary and text models" + ) + parser.add_argument("--input_model", help="path to input model folder") + parser.add_argument( + "--input_format", + choices=[".bin", ".txt"], + 
help="input model format", + default="", + ) + parser.add_argument("--output_model", help="path to output model folder") + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="outut model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, images, points3D, path=args.output_model, ext=args.output_format + ) + + +if __name__ == "__main__": + main() diff --git a/imcui/hloc/utils/viz.py b/imcui/hloc/utils/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..f87b51706652e47e6f8fe7f5f67fc5362a970ecd --- /dev/null +++ b/imcui/hloc/utils/viz.py @@ -0,0 +1,145 @@ +""" +2D visualization primitives based on Matplotlib. + +1) Plot images with `plot_images`. +2) Call `plot_keypoints` or `plot_matches` any number of times. +3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. +""" + +import matplotlib +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import numpy as np + + +def cm_RdGn(x): + """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" + x = np.clip(x, 0, 1)[..., None] * 2 + c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]]) + return np.clip(c, 0, 1) + + +def plot_images( + imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True, figsize=4.5 +): + """Plot a set of images horizontally. + Args: + imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. + adaptive: whether the figure size should fit the image aspect ratios. + """ + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + if adaptive: + ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H + else: + ratios = [4 / 3] * n + figsize = [sum(ratios) * figsize, figsize] + fig, axs = plt.subplots( + 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} + ) + if n == 1: + axs = [axs] + for i, (img, ax) in enumerate(zip(imgs, axs)): + ax.imshow(img, cmap=plt.get_cmap(cmaps[i])) + ax.set_axis_off() + if titles: + ax.set_title(titles[i]) + fig.tight_layout(pad=pad) + return fig + + +def plot_keypoints(kpts, colors="lime", ps=4): + """Plot keypoints for existing images. + Args: + kpts: list of ndarrays of size (N, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float. + """ + if not isinstance(colors, list): + colors = [colors] * len(kpts) + axes = plt.gcf().axes + try: + for a, k, c in zip(axes, kpts, colors): + a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0) + except IndexError: + pass + + +def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0): + """Plot matches for a pair of existing images. + Args: + kpts0, kpts1: corresponding keypoints of size (N, 2). + color: color of each match, string or RGB tuple. Random if not given. + lw: width of the lines. + ps: size of the end points (no endpoint if ps=0) + indices: indices of the images to draw the matches on. + a: alpha opacity of the match lines. 
+ """ + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + ax0, ax1 = ax[indices[0]], ax[indices[1]] + fig.canvas.draw() + + assert len(kpts0) == len(kpts1) + if color is None: + color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() + elif len(color) > 0 and not isinstance(color[0], (tuple, list)): + color = [color] * len(kpts0) + + if lw > 0: + # transform the points into the figure coordinate system + for i in range(len(kpts0)): + fig.add_artist( + matplotlib.patches.ConnectionPatch( + xyA=(kpts0[i, 0], kpts0[i, 1]), + coordsA=ax0.transData, + xyB=(kpts1[i, 0], kpts1[i, 1]), + coordsB=ax1.transData, + zorder=1, + color=color[i], + linewidth=lw, + alpha=a, + ) + ) + + # freeze the axes to prevent the transform to change + ax0.autoscale(enable=False) + ax1.autoscale(enable=False) + + if ps > 0: + ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) + ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) + + +def add_text( + idx, + text, + pos=(0.01, 0.99), + fs=15, + color="w", + lcolor="k", + lwidth=2, + ha="left", + va="top", +): + ax = plt.gcf().axes[idx] + t = ax.text( + *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes + ) + if lcolor is not None: + t.set_path_effects( + [ + path_effects.Stroke(linewidth=lwidth, foreground=lcolor), + path_effects.Normal(), + ] + ) + + +def save_plot(path, **kw): + """Save the current figure without any white margin.""" + plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) diff --git a/imcui/hloc/utils/viz_3d.py b/imcui/hloc/utils/viz_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..f9fd1b1a02eaee99e061bb392a561ebbc00d93b1 --- /dev/null +++ b/imcui/hloc/utils/viz_3d.py @@ -0,0 +1,203 @@ +""" +3D visualization based on plotly. +Works for a small number of points and cameras, might be slow otherwise. + +1) Initialize a figure with `init_figure` +2) Add 3D points, camera frustums, or both as a pycolmap.Reconstruction + +Written by Paul-Edouard Sarlin and Philipp Lindenberger. 
+""" + +from typing import Optional + +import numpy as np +import plotly.graph_objects as go +import pycolmap + + +def to_homogeneous(points): + pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype) + return np.concatenate([points, pad], axis=-1) + + +def init_figure(height: int = 800) -> go.Figure: + """Initialize a 3D figure.""" + fig = go.Figure() + axes = dict( + visible=False, + showbackground=False, + showgrid=False, + showline=False, + showticklabels=True, + autorange=True, + ) + fig.update_layout( + template="plotly_dark", + height=height, + scene_camera=dict( + eye=dict(x=0.0, y=-0.1, z=-2), + up=dict(x=0, y=-1.0, z=0), + projection=dict(type="orthographic"), + ), + scene=dict( + xaxis=axes, + yaxis=axes, + zaxis=axes, + aspectmode="data", + dragmode="orbit", + ), + margin=dict(l=0, r=0, b=0, t=0, pad=0), + legend=dict(orientation="h", yanchor="top", y=0.99, xanchor="left", x=0.1), + ) + return fig + + +def plot_points( + fig: go.Figure, + pts: np.ndarray, + color: str = "rgba(255, 0, 0, 1)", + ps: int = 2, + colorscale: Optional[str] = None, + name: Optional[str] = None, +): + """Plot a set of 3D points.""" + x, y, z = pts.T + tr = go.Scatter3d( + x=x, + y=y, + z=z, + mode="markers", + name=name, + legendgroup=name, + marker=dict(size=ps, color=color, line_width=0.0, colorscale=colorscale), + ) + fig.add_trace(tr) + + +def plot_camera( + fig: go.Figure, + R: np.ndarray, + t: np.ndarray, + K: np.ndarray, + color: str = "rgb(0, 0, 255)", + name: Optional[str] = None, + legendgroup: Optional[str] = None, + fill: bool = False, + size: float = 1.0, + text: Optional[str] = None, +): + """Plot a camera frustum from pose and intrinsic matrix.""" + W, H = K[0, 2] * 2, K[1, 2] * 2 + corners = np.array([[0, 0], [W, 0], [W, H], [0, H], [0, 0]]) + if size is not None: + image_extent = max(size * W / 1024.0, size * H / 1024.0) + world_extent = max(W, H) / (K[0, 0] + K[1, 1]) / 0.5 + scale = 0.5 * image_extent / world_extent + else: + scale = 1.0 + corners = to_homogeneous(corners) @ np.linalg.inv(K).T + corners = (corners / 2 * scale) @ R.T + t + legendgroup = legendgroup if legendgroup is not None else name + + x, y, z = np.concatenate(([t], corners)).T + i = [0, 0, 0, 0] + j = [1, 2, 3, 4] + k = [2, 3, 4, 1] + + if fill: + pyramid = go.Mesh3d( + x=x, + y=y, + z=z, + color=color, + i=i, + j=j, + k=k, + legendgroup=legendgroup, + name=name, + showlegend=False, + hovertemplate=text.replace("\n", "
"), + ) + fig.add_trace(pyramid) + + triangles = np.vstack((i, j, k)).T + vertices = np.concatenate(([t], corners)) + tri_points = np.array([vertices[i] for i in triangles.reshape(-1)]) + x, y, z = tri_points.T + + pyramid = go.Scatter3d( + x=x, + y=y, + z=z, + mode="lines", + legendgroup=legendgroup, + name=name, + line=dict(color=color, width=1), + showlegend=False, + hovertemplate=text.replace("\n", "
"), + ) + fig.add_trace(pyramid) + + +def plot_camera_colmap( + fig: go.Figure, + image: pycolmap.Image, + camera: pycolmap.Camera, + name: Optional[str] = None, + **kwargs, +): + """Plot a camera frustum from PyCOLMAP objects""" + world_t_camera = image.cam_from_world.inverse() + plot_camera( + fig, + world_t_camera.rotation.matrix(), + world_t_camera.translation, + camera.calibration_matrix(), + name=name or str(image.image_id), + text=str(image), + **kwargs, + ) + + +def plot_cameras(fig: go.Figure, reconstruction: pycolmap.Reconstruction, **kwargs): + """Plot a camera as a cone with camera frustum.""" + for image_id, image in reconstruction.images.items(): + plot_camera_colmap( + fig, image, reconstruction.cameras[image.camera_id], **kwargs + ) + + +def plot_reconstruction( + fig: go.Figure, + rec: pycolmap.Reconstruction, + max_reproj_error: float = 6.0, + color: str = "rgb(0, 0, 255)", + name: Optional[str] = None, + min_track_length: int = 2, + points: bool = True, + cameras: bool = True, + points_rgb: bool = True, + cs: float = 1.0, +): + # Filter outliers + bbs = rec.compute_bounding_box(0.001, 0.999) + # Filter points, use original reproj error here + p3Ds = [ + p3D + for _, p3D in rec.points3D.items() + if ( + (p3D.xyz >= bbs[0]).all() + and (p3D.xyz <= bbs[1]).all() + and p3D.error <= max_reproj_error + and p3D.track.length() >= min_track_length + ) + ] + xyzs = [p3D.xyz for p3D in p3Ds] + if points_rgb: + pcolor = [p3D.color for p3D in p3Ds] + else: + pcolor = color + if points: + plot_points(fig, np.array(xyzs), color=pcolor, ps=1, name=name) + if cameras: + plot_cameras(fig, rec, color=color, legendgroup=name, size=cs) diff --git a/imcui/hloc/visualization.py b/imcui/hloc/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..456c2ee991efe4895c664f5bbd475c24fa789bf8 --- /dev/null +++ b/imcui/hloc/visualization.py @@ -0,0 +1,178 @@ +import pickle +import random + +import numpy as np +import pycolmap +from matplotlib import cm + +from .utils.io import read_image +from .utils.viz import ( + add_text, + cm_RdGn, + plot_images, + plot_keypoints, + plot_matches, +) + + +def visualize_sfm_2d( + reconstruction, + image_dir, + color_by="visibility", + selected=[], + n=1, + seed=0, + dpi=75, +): + assert image_dir.exists() + if not isinstance(reconstruction, pycolmap.Reconstruction): + reconstruction = pycolmap.Reconstruction(reconstruction) + + if not selected: + image_ids = reconstruction.reg_image_ids() + selected = random.Random(seed).sample(image_ids, min(n, len(image_ids))) + + for i in selected: + image = reconstruction.images[i] + keypoints = np.array([p.xy for p in image.points2D]) + visible = np.array([p.has_point3D() for p in image.points2D]) + + if color_by == "visibility": + color = [(0, 0, 1) if v else (1, 0, 0) for v in visible] + text = f"visible: {np.count_nonzero(visible)}/{len(visible)}" + elif color_by == "track_length": + tl = np.array( + [ + ( + reconstruction.points3D[p.point3D_id].track.length() + if p.has_point3D() + else 1 + ) + for p in image.points2D + ] + ) + max_, med_ = np.max(tl), np.median(tl[tl > 1]) + tl = np.log(tl) + color = cm.jet(tl / tl.max()).tolist() + text = f"max/median track length: {max_}/{med_}" + elif color_by == "depth": + p3ids = [p.point3D_id for p in image.points2D if p.has_point3D()] + z = np.array( + [ + (image.cam_from_world * reconstruction.points3D[j].xyz)[-1] + for j in p3ids + ] + ) + z -= z.min() + color = cm.jet(z / np.percentile(z, 99.9)) + text = f"visible: 
{np.count_nonzero(visible)}/{len(visible)}" + keypoints = keypoints[visible] + else: + raise NotImplementedError(f"Coloring not implemented: {color_by}.") + + name = image.name + fig = plot_images([read_image(image_dir / name)], dpi=dpi) + plot_keypoints([keypoints], colors=[color], ps=4) + add_text(0, text) + add_text(0, name, pos=(0.01, 0.01), fs=5, lcolor=None, va="bottom") + return fig + + +def visualize_loc( + results, + image_dir, + reconstruction=None, + db_image_dir=None, + selected=[], + n=1, + seed=0, + prefix=None, + **kwargs, +): + assert image_dir.exists() + + with open(str(results) + "_logs.pkl", "rb") as f: + logs = pickle.load(f) + + if not selected: + queries = list(logs["loc"].keys()) + if prefix: + queries = [q for q in queries if q.startswith(prefix)] + selected = random.Random(seed).sample(queries, min(n, len(queries))) + + if reconstruction is not None: + if not isinstance(reconstruction, pycolmap.Reconstruction): + reconstruction = pycolmap.Reconstruction(reconstruction) + + for qname in selected: + loc = logs["loc"][qname] + visualize_loc_from_log( + image_dir, qname, loc, reconstruction, db_image_dir, **kwargs + ) + + +def visualize_loc_from_log( + image_dir, + query_name, + loc, + reconstruction=None, + db_image_dir=None, + top_k_db=2, + dpi=75, +): + q_image = read_image(image_dir / query_name) + if loc.get("covisibility_clustering", False): + # select the first, largest cluster if the localization failed + loc = loc["log_clusters"][loc["best_cluster"] or 0] + + inliers = np.array(loc["PnP_ret"]["inliers"]) + mkp_q = loc["keypoints_query"] + n = len(loc["db"]) + if reconstruction is not None: + # for each pair of query keypoint and its matched 3D point, + # we need to find its corresponding keypoint in each database image + # that observes it. We also count the number of inliers in each. 
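+ # ("keypoint_index_to_db" pairs every matched query keypoint with its 3D point id
+ # and the database images observing it; the loop below walks each point's track to
+ # recover the matching 2D keypoint in those images and accumulates inlier counts.)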
+ kp_idxs, kp_to_3D_to_db = loc["keypoint_index_to_db"] + counts = np.zeros(n) + dbs_kp_q_db = [[] for _ in range(n)] + inliers_dbs = [[] for _ in range(n)] + for i, (inl, (p3D_id, db_idxs)) in enumerate(zip(inliers, kp_to_3D_to_db)): + track = reconstruction.points3D[p3D_id].track + track = {el.image_id: el.point2D_idx for el in track.elements} + for db_idx in db_idxs: + counts[db_idx] += inl + kp_db = track[loc["db"][db_idx]] + dbs_kp_q_db[db_idx].append((i, kp_db)) + inliers_dbs[db_idx].append(inl) + else: + # for inloc the database keypoints are already in the logs + assert "keypoints_db" in loc + assert "indices_db" in loc + counts = np.array([np.sum(loc["indices_db"][inliers] == i) for i in range(n)]) + + # display the database images with the most inlier matches + db_sort = np.argsort(-counts) + for db_idx in db_sort[:top_k_db]: + if reconstruction is not None: + db = reconstruction.images[loc["db"][db_idx]] + db_name = db.name + db_kp_q_db = np.array(dbs_kp_q_db[db_idx]) + kp_q = mkp_q[db_kp_q_db[:, 0]] + kp_db = np.array([db.points2D[i].xy for i in db_kp_q_db[:, 1]]) + inliers_db = inliers_dbs[db_idx] + else: + db_name = loc["db"][db_idx] + kp_q = mkp_q[loc["indices_db"] == db_idx] + kp_db = loc["keypoints_db"][loc["indices_db"] == db_idx] + inliers_db = inliers[loc["indices_db"] == db_idx] + + db_image = read_image((db_image_dir or image_dir) / db_name) + color = cm_RdGn(inliers_db).tolist() + text = f"inliers: {sum(inliers_db)}/{len(inliers_db)}" + + plot_images([q_image, db_image], dpi=dpi) + plot_matches(kp_q, kp_db, color, a=0.1) + add_text(0, text) + opts = dict(pos=(0.01, 0.01), fs=5, lcolor=None, va="bottom") + add_text(0, query_name, **opts) + add_text(1, db_name, **opts) diff --git a/imcui/third_party/ALIKE/alike.py b/imcui/third_party/ALIKE/alike.py new file mode 100644 index 0000000000000000000000000000000000000000..303616d52581efce0ae0eb86af70f5ea8984909d --- /dev/null +++ b/imcui/third_party/ALIKE/alike.py @@ -0,0 +1,143 @@ +import logging +import os +import cv2 +import torch +from copy import deepcopy +import torch.nn.functional as F +from torchvision.transforms import ToTensor +import math + +from alnet import ALNet +from soft_detect import DKD +import time + +configs = { + 'alike-t': {'c1': 8, 'c2': 16, 'c3': 32, 'c4': 64, 'dim': 64, 'single_head': True, 'radius': 2, + 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-t.pth')}, + 'alike-s': {'c1': 8, 'c2': 16, 'c3': 48, 'c4': 96, 'dim': 96, 'single_head': True, 'radius': 2, + 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-s.pth')}, + 'alike-n': {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True, 'radius': 2, + 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-n.pth')}, + 'alike-l': {'c1': 32, 'c2': 64, 'c3': 128, 'c4': 128, 'dim': 128, 'single_head': False, 'radius': 2, + 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-l.pth')}, +} + + +class ALike(ALNet): + def __init__(self, + # ================================== feature encoder + c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128, + single_head: bool = False, + # ================================== detect parameters + radius: int = 2, + top_k: int = 500, scores_th: float = 0.5, + n_limit: int = 5000, + device: str = 'cpu', + model_path: str = '' + ): + super().__init__(c1, c2, c3, c4, dim, single_head) + self.radius = radius + self.top_k = top_k + self.n_limit = n_limit + self.scores_th = scores_th + self.dkd = 
DKD(radius=self.radius, top_k=self.top_k, + scores_th=self.scores_th, n_limit=self.n_limit) + self.device = device + + if model_path != '': + state_dict = torch.load(model_path, self.device) + self.load_state_dict(state_dict) + self.to(self.device) + self.eval() + logging.info(f'Loaded model parameters from {model_path}') + logging.info( + f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB") + + def extract_dense_map(self, image, ret_dict=False): + # ==================================================== + # check image size, should be integer multiples of 2^5 + # if it is not a integer multiples of 2^5, padding zeros + device = image.device + b, c, h, w = image.shape + h_ = math.ceil(h / 32) * 32 if h % 32 != 0 else h + w_ = math.ceil(w / 32) * 32 if w % 32 != 0 else w + if h_ != h: + h_padding = torch.zeros(b, c, h_ - h, w, device=device) + image = torch.cat([image, h_padding], dim=2) + if w_ != w: + w_padding = torch.zeros(b, c, h_, w_ - w, device=device) + image = torch.cat([image, w_padding], dim=3) + # ==================================================== + + scores_map, descriptor_map = super().forward(image) + + # ==================================================== + if h_ != h or w_ != w: + descriptor_map = descriptor_map[:, :, :h, :w] + scores_map = scores_map[:, :, :h, :w] # Bx1xHxW + # ==================================================== + + # BxCxHxW + descriptor_map = torch.nn.functional.normalize(descriptor_map, p=2, dim=1) + + if ret_dict: + return {'descriptor_map': descriptor_map, 'scores_map': scores_map, } + else: + return descriptor_map, scores_map + + def forward(self, img, image_size_max=99999, sort=False, sub_pixel=False): + """ + :param img: np.array HxWx3, RGB + :param image_size_max: maximum image size, otherwise, the image will be resized + :param sort: sort keypoints by scores + :param sub_pixel: whether to use sub-pixel accuracy + :return: a dictionary with 'keypoints', 'descriptors', 'scores', and 'time' + """ + H, W, three = img.shape + assert three == 3, "input image shape should be [HxWx3]" + + # ==================== image size constraint + image = deepcopy(img) + max_hw = max(H, W) + if max_hw > image_size_max: + ratio = float(image_size_max / max_hw) + image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio) + + # ==================== convert image to tensor + image = torch.from_numpy(image).to(self.device).to(torch.float32).permute(2, 0, 1)[None] / 255.0 + + # ==================== extract keypoints + start = time.time() + + with torch.no_grad(): + descriptor_map, scores_map = self.extract_dense_map(image) + keypoints, descriptors, scores, _ = self.dkd(scores_map, descriptor_map, + sub_pixel=sub_pixel) + keypoints, descriptors, scores = keypoints[0], descriptors[0], scores[0] + keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W - 1, H - 1]]) + + if sort: + indices = torch.argsort(scores, descending=True) + keypoints = keypoints[indices] + descriptors = descriptors[indices] + scores = scores[indices] + + end = time.time() + + return {'keypoints': keypoints.cpu().numpy(), + 'descriptors': descriptors.cpu().numpy(), + 'scores': scores.cpu().numpy(), + 'scores_map': scores_map.cpu().numpy(), + 'time': end - start, } + + +if __name__ == '__main__': + import numpy as np + from thop import profile + + net = ALike(c1=32, c2=64, c3=128, c4=128, dim=128, single_head=False) + + image = np.random.random((640, 480, 3)).astype(np.float32) + flops, params = profile(net, inputs=(image, 9999, False), verbose=False) 
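+ # thop.profile returns the operation count and the number of parameters;
+ # the prints below scale them by 1e9 (GFLOPs) and 1e3 respectively.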
+ print('{:<30} {:<8} GFLops'.format('Computational complexity: ', flops / 1e9)) + print('{:<30} {:<8} KB'.format('Number of parameters: ', params / 1e3)) diff --git a/imcui/third_party/ALIKE/alnet.py b/imcui/third_party/ALIKE/alnet.py new file mode 100644 index 0000000000000000000000000000000000000000..53127063233660c7b96aa15e89aa4a8a1a340dd1 --- /dev/null +++ b/imcui/third_party/ALIKE/alnet.py @@ -0,0 +1,164 @@ +import torch +from torch import nn +from torchvision.models import resnet +from typing import Optional, Callable + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None): + super().__init__() + if gate is None: + self.gate = nn.ReLU(inplace=True) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = resnet.conv3x3(in_channels, out_channels) + self.bn1 = norm_layer(out_channels) + self.conv2 = resnet.conv3x3(out_channels, out_channels) + self.bn2 = norm_layer(out_channels) + + def forward(self, x): + x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W + x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W + return x + + +# copied from torchvision\models\resnet.py#27->BasicBlock +class ResBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResBlock, self).__init__() + if gate is None: + self.gate = nn.ReLU(inplace=True) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('ResBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in ResBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = resnet.conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.conv2 = resnet.conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.gate(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.gate(out) + + return out + + +class ALNet(nn.Module): + def __init__(self, c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128, + single_head: bool = True, + ): + super().__init__() + + self.gate = nn.ReLU(inplace=True) + + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.pool4 = nn.MaxPool2d(kernel_size=4, stride=4) + + self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d) + + self.block2 = ResBlock(inplanes=c1, planes=c2, stride=1, + downsample=nn.Conv2d(c1, c2, 1), + gate=self.gate, + norm_layer=nn.BatchNorm2d) + self.block3 = ResBlock(inplanes=c2, planes=c3, stride=1, + downsample=nn.Conv2d(c2, c3, 1), + gate=self.gate, + norm_layer=nn.BatchNorm2d) + self.block4 = ResBlock(inplanes=c3, planes=c4, stride=1, + downsample=nn.Conv2d(c3, c4, 1), + gate=self.gate, + norm_layer=nn.BatchNorm2d) + + # ================================== feature aggregation + self.conv1 = resnet.conv1x1(c1, 
dim // 4) + self.conv2 = resnet.conv1x1(c2, dim // 4) + self.conv3 = resnet.conv1x1(c3, dim // 4) + self.conv4 = resnet.conv1x1(dim, dim // 4) + self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) + self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) + self.upsample32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True) + + # ================================== detector and descriptor head + self.single_head = single_head + if not self.single_head: + self.convhead1 = resnet.conv1x1(dim, dim) + self.convhead2 = resnet.conv1x1(dim, dim + 1) + + def forward(self, image): + # ================================== feature encoder + x1 = self.block1(image) # B x c1 x H x W + x2 = self.pool2(x1) + x2 = self.block2(x2) # B x c2 x H/2 x W/2 + x3 = self.pool4(x2) + x3 = self.block3(x3) # B x c3 x H/8 x W/8 + x4 = self.pool4(x3) + x4 = self.block4(x4) # B x dim x H/32 x W/32 + + # ================================== feature aggregation + x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W + x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2 + x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8 + x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32 + x2_up = self.upsample2(x2) # B x dim//4 x H x W + x3_up = self.upsample8(x3) # B x dim//4 x H x W + x4_up = self.upsample32(x4) # B x dim//4 x H x W + x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1) + + # ================================== detector and descriptor head + if not self.single_head: + x1234 = self.gate(self.convhead1(x1234)) + x = self.convhead2(x1234) # B x dim+1 x H x W + + descriptor_map = x[:, :-1, :, :] + scores_map = torch.sigmoid(x[:, -1, :, :]).unsqueeze(1) + + return scores_map, descriptor_map + + +if __name__ == '__main__': + from thop import profile + + net = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True) + + image = torch.randn(1, 3, 640, 480) + flops, params = profile(net, inputs=(image,), verbose=False) + print('{:<30} {:<8} GFLops'.format('Computational complexity: ', flops / 1e9)) + print('{:<30} {:<8} KB'.format('Number of parameters: ', params / 1e3)) diff --git a/imcui/third_party/ALIKE/demo.py b/imcui/third_party/ALIKE/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..9bfbefdd26cfeceefc75f90d1c44a7f922c624a5 --- /dev/null +++ b/imcui/third_party/ALIKE/demo.py @@ -0,0 +1,167 @@ +import copy +import os +import cv2 +import glob +import logging +import argparse +import numpy as np +from tqdm import tqdm +from alike import ALike, configs + + +class ImageLoader(object): + def __init__(self, filepath: str): + self.N = 3000 + if filepath.startswith('camera'): + camera = int(filepath[6:]) + self.cap = cv2.VideoCapture(camera) + if not self.cap.isOpened(): + raise IOError(f"Can't open camera {camera}!") + logging.info(f'Opened camera {camera}') + self.mode = 'camera' + elif os.path.exists(filepath): + if os.path.isfile(filepath): + self.cap = cv2.VideoCapture(filepath) + if not self.cap.isOpened(): + raise IOError(f"Can't open video {filepath}!") + rate = self.cap.get(cv2.CAP_PROP_FPS) + self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1 + duration = self.N / rate + logging.info(f'Opened video {filepath}') + logging.info(f'Frames: {self.N}, FPS: {rate}, Duration: {duration}s') + self.mode = 'video' + else: + self.images = glob.glob(os.path.join(filepath, '*.png')) + \ + glob.glob(os.path.join(filepath, '*.jpg')) 
+ \ + glob.glob(os.path.join(filepath, '*.ppm')) + self.images.sort() + self.N = len(self.images) + logging.info(f'Loading {self.N} images') + self.mode = 'images' + else: + raise IOError('Error filepath (camerax/path of images/path of videos): ', filepath) + + def __getitem__(self, item): + if self.mode == 'camera' or self.mode == 'video': + if item > self.N: + return None + ret, img = self.cap.read() + if not ret: + raise "Can't read image from camera" + if self.mode == 'video': + self.cap.set(cv2.CAP_PROP_POS_FRAMES, item) + elif self.mode == 'images': + filename = self.images[item] + img = cv2.imread(filename) + if img is None: + raise Exception('Error reading image %s' % filename) + return img + + def __len__(self): + return self.N + + +class SimpleTracker(object): + def __init__(self): + self.pts_prev = None + self.desc_prev = None + + def update(self, img, pts, desc): + N_matches = 0 + if self.pts_prev is None: + self.pts_prev = pts + self.desc_prev = desc + + out = copy.deepcopy(img) + for pt1 in pts: + p1 = (int(round(pt1[0])), int(round(pt1[1]))) + cv2.circle(out, p1, 1, (0, 0, 255), -1, lineType=16) + else: + matches = self.mnn_mather(self.desc_prev, desc) + mpts1, mpts2 = self.pts_prev[matches[:, 0]], pts[matches[:, 1]] + N_matches = len(matches) + + out = copy.deepcopy(img) + for pt1, pt2 in zip(mpts1, mpts2): + p1 = (int(round(pt1[0])), int(round(pt1[1]))) + p2 = (int(round(pt2[0])), int(round(pt2[1]))) + cv2.line(out, p1, p2, (0, 255, 0), lineType=16) + cv2.circle(out, p2, 1, (0, 0, 255), -1, lineType=16) + + self.pts_prev = pts + self.desc_prev = desc + + return out, N_matches + + def mnn_mather(self, desc1, desc2): + sim = desc1 @ desc2.transpose() + sim[sim < 0.9] = 0 + nn12 = np.argmax(sim, axis=1) + nn21 = np.argmax(sim, axis=0) + ids1 = np.arange(0, sim.shape[0]) + mask = (ids1 == nn21[nn12]) + matches = np.stack([ids1[mask], nn12[mask]]) + return matches.transpose() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='ALike Demo.') + parser.add_argument('input', type=str, default='', + help='Image directory or movie file or "camera0" (for webcam0).') + parser.add_argument('--model', choices=['alike-t', 'alike-s', 'alike-n', 'alike-l'], default="alike-t", + help="The model configuration") + parser.add_argument('--device', type=str, default='cuda', help="Running device (default: cuda).") + parser.add_argument('--top_k', type=int, default=-1, + help='Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)') + parser.add_argument('--scores_th', type=float, default=0.2, + help='Detector score threshold (default: 0.2).') + parser.add_argument('--n_limit', type=int, default=5000, + help='Maximum number of keypoints to be detected (default: 5000).') + parser.add_argument('--no_display', action='store_true', + help='Do not display images to screen. 
Useful if running remotely (default: False).') + parser.add_argument('--no_sub_pixel', action='store_true', + help='Do not detect sub-pixel keypoints (default: False).') + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + image_loader = ImageLoader(args.input) + model = ALike(**configs[args.model], + device=args.device, + top_k=args.top_k, + scores_th=args.scores_th, + n_limit=args.n_limit) + tracker = SimpleTracker() + + if not args.no_display: + logging.info("Press 'q' to stop!") + cv2.namedWindow(args.model) + + runtime = [] + progress_bar = tqdm(image_loader) + for img in progress_bar: + if img is None: + break + + img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + pred = model(img_rgb, sub_pixel=not args.no_sub_pixel) + kpts = pred['keypoints'] + desc = pred['descriptors'] + runtime.append(pred['time']) + + out, N_matches = tracker.update(img, kpts, desc) + + ave_fps = (1. / np.stack(runtime)).mean() + status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}" + progress_bar.set_description(status) + + if not args.no_display: + cv2.setWindowTitle(args.model, args.model + ': ' + status) + cv2.imshow(args.model, out) + if cv2.waitKey(1) == ord('q'): + break + + logging.info('Finished!') + if not args.no_display: + logging.info('Press any key to exit!') + cv2.waitKey() diff --git a/imcui/third_party/ALIKE/hseq/eval.py b/imcui/third_party/ALIKE/hseq/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..abca625044013a0cd34a518223c32d3ec8abb8a3 --- /dev/null +++ b/imcui/third_party/ALIKE/hseq/eval.py @@ -0,0 +1,162 @@ +import cv2 +import os +from tqdm import tqdm +import torch +import numpy as np +from extract import extract_method + +use_cuda = torch.cuda.is_available() +device = torch.device('cuda' if use_cuda else 'cpu') + +methods = ['d2', 'lfnet', 'superpoint', 'r2d2', 'aslfeat', 'disk', + 'alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms'] +names = ['D2-Net(MS)', 'LF-Net(MS)', 'SuperPoint', 'R2D2(MS)', 'ASLFeat(MS)', 'DISK', + 'ALike-N', 'ALike-L', 'ALike-N(MS)', 'ALike-L(MS)'] + +top_k = None +n_i = 52 +n_v = 56 +cache_dir = 'hseq/cache' +dataset_path = 'hseq/hpatches-sequences-release' + + +def generate_read_function(method, extension='ppm'): + def read_function(seq_name, im_idx): + aux = np.load(os.path.join(dataset_path, seq_name, '%d.%s.%s' % (im_idx, extension, method))) + if top_k is None: + return aux['keypoints'], aux['descriptors'] + else: + assert ('scores' in aux) + ids = np.argsort(aux['scores'])[-top_k:] + return aux['keypoints'][ids, :], aux['descriptors'][ids, :] + + return read_function + + +def mnn_matcher(descriptors_a, descriptors_b): + device = descriptors_a.device + sim = descriptors_a @ descriptors_b.t() + nn12 = torch.max(sim, dim=1)[1] + nn21 = torch.max(sim, dim=0)[1] + ids1 = torch.arange(0, sim.shape[0], device=device) + mask = (ids1 == nn21[nn12]) + matches = torch.stack([ids1[mask], nn12[mask]]) + return matches.t().data.cpu().numpy() + + +def homo_trans(coord, H): + kpt_num = coord.shape[0] + homo_coord = np.concatenate((coord, np.ones((kpt_num, 1))), axis=-1) + proj_coord = np.matmul(H, homo_coord.T).T + proj_coord = proj_coord / proj_coord[:, 2][..., None] + proj_coord = proj_coord[:, 0:2] + return proj_coord + + +def benchmark_features(read_feats): + lim = [1, 5] + rng = np.arange(lim[0], lim[1] + 1) + + seq_names = sorted(os.listdir(dataset_path)) + + n_feats = [] + n_matches = [] + seq_type = [] + i_err = {thr: 0 for thr in rng} + v_err = {thr: 0 for thr in rng} + + i_err_homo = {thr: 0 for 
thr in rng} + v_err_homo = {thr: 0 for thr in rng} + + for seq_idx, seq_name in tqdm(enumerate(seq_names), total=len(seq_names)): + keypoints_a, descriptors_a = read_feats(seq_name, 1) + n_feats.append(keypoints_a.shape[0]) + + # =========== compute homography + ref_img = cv2.imread(os.path.join(dataset_path, seq_name, '1.ppm')) + ref_img_shape = ref_img.shape + + for im_idx in range(2, 7): + keypoints_b, descriptors_b = read_feats(seq_name, im_idx) + n_feats.append(keypoints_b.shape[0]) + + matches = mnn_matcher( + torch.from_numpy(descriptors_a).to(device=device), + torch.from_numpy(descriptors_b).to(device=device) + ) + + homography = np.loadtxt(os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx))) + + pos_a = keypoints_a[matches[:, 0], : 2] + pos_a_h = np.concatenate([pos_a, np.ones([matches.shape[0], 1])], axis=1) + pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h))) + pos_b_proj = pos_b_proj_h[:, : 2] / pos_b_proj_h[:, 2:] + + pos_b = keypoints_b[matches[:, 1], : 2] + + dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1)) + + n_matches.append(matches.shape[0]) + seq_type.append(seq_name[0]) + + if dist.shape[0] == 0: + dist = np.array([float("inf")]) + + for thr in rng: + if seq_name[0] == 'i': + i_err[thr] += np.mean(dist <= thr) + else: + v_err[thr] += np.mean(dist <= thr) + + # =========== compute homography + gt_homo = homography + pred_homo, _ = cv2.findHomography(keypoints_a[matches[:, 0], : 2], keypoints_b[matches[:, 1], : 2], + cv2.RANSAC) + if pred_homo is None: + homo_dist = np.array([float("inf")]) + else: + corners = np.array([[0, 0], + [ref_img_shape[1] - 1, 0], + [0, ref_img_shape[0] - 1], + [ref_img_shape[1] - 1, ref_img_shape[0] - 1]]) + real_warped_corners = homo_trans(corners, gt_homo) + warped_corners = homo_trans(corners, pred_homo) + homo_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1)) + + for thr in rng: + if seq_name[0] == 'i': + i_err_homo[thr] += np.mean(homo_dist <= thr) + else: + v_err_homo[thr] += np.mean(homo_dist <= thr) + + seq_type = np.array(seq_type) + n_feats = np.array(n_feats) + n_matches = np.array(n_matches) + + return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches] + + +if __name__ == '__main__': + errors = {} + for method in methods: + output_file = os.path.join(cache_dir, method + '.npy') + read_function = generate_read_function(method) + if os.path.exists(output_file): + errors[method] = np.load(output_file, allow_pickle=True) + else: + extract_method(method) + errors[method] = benchmark_features(read_function) + np.save(output_file, errors[method]) + + for name, method in zip(names, methods): + i_err, v_err, i_err_hom, v_err_hom, _ = errors[method] + + print(f"====={name}=====") + print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end='') + for thr in range(1, 4): + err = (i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5) + print(f"{err * 100:.2f}%", end=' ') + for thr in range(1, 4): + err_hom = (i_err_hom[thr] + v_err_hom[thr]) / ((n_i + n_v) * 5) + print(f"{err_hom * 100:.2f}%", end=' ') + print('') diff --git a/imcui/third_party/ALIKE/hseq/extract.py b/imcui/third_party/ALIKE/hseq/extract.py new file mode 100644 index 0000000000000000000000000000000000000000..1342e40dd2d0e1d1986e90f995c95b17972ec4e1 --- /dev/null +++ b/imcui/third_party/ALIKE/hseq/extract.py @@ -0,0 +1,159 @@ +import os +import sys +import cv2 +from pathlib import Path +import numpy as np +import torch +import torch.utils.data as data +from tqdm import tqdm +from copy import deepcopy +from 
torchvision.transforms import ToTensor + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from alike import ALike, configs + +dataset_root = 'hseq/hpatches-sequences-release' +use_cuda = torch.cuda.is_available() +device = 'cuda' if use_cuda else 'cpu' +methods = ['alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms'] + + +class HPatchesDataset(data.Dataset): + def __init__(self, root: str = dataset_root, alteration: str = 'all'): + """ + Args: + root: dataset root path + alteration: # 'all', 'i' for illumination or 'v' for viewpoint + """ + assert (Path(root).exists()), f"Dataset root path {root} dose not exist!" + self.root = root + + # get all image file name + self.image0_list = [] + self.image1_list = [] + self.homographies = [] + folders = [x for x in Path(self.root).iterdir() if x.is_dir()] + self.seqs = [] + for folder in folders: + if alteration == 'i' and folder.stem[0] != 'i': + continue + if alteration == 'v' and folder.stem[0] != 'v': + continue + + self.seqs.append(folder) + + self.len = len(self.seqs) + assert (self.len > 0), f'Can not find PatchDataset in path {self.root}' + + def __getitem__(self, item): + folder = self.seqs[item] + + imgs = [] + homos = [] + for i in range(1, 7): + img = cv2.imread(str(folder / f'{i}.ppm'), cv2.IMREAD_COLOR) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # HxWxC + imgs.append(img) + + if i != 1: + homo = np.loadtxt(str(folder / f'H_1_{i}')).astype('float32') + homos.append(homo) + + return imgs, homos, folder.stem + + def __len__(self): + return self.len + + def name(self): + return self.__class__ + + +def extract_multiscale(model, img, scale_f=2 ** 0.5, + min_scale=1., max_scale=1., + min_size=0., max_size=99999., + image_size_max=99999, + n_k=0, sort=False): + H_, W_, three = img.shape + assert three == 3, "input image shape should be [HxWx3]" + + old_bm = torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = False # speedup + + # ==================== image size constraint + image = deepcopy(img) + max_hw = max(H_, W_) + if max_hw > image_size_max: + ratio = float(image_size_max / max_hw) + image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio) + + # ==================== convert image to tensor + H, W, three = image.shape + image = ToTensor()(image).unsqueeze(0) + image = image.to(device) + + s = 1.0 # current scale factor + keypoints, descriptors, scores, scores_maps, descriptor_maps = [], [], [], [], [] + while s + 0.001 >= max(min_scale, min_size / max(H, W)): + if s - 0.001 <= min(max_scale, max_size / max(H, W)): + nh, nw = image.shape[2:] + + # extract descriptors + with torch.no_grad(): + descriptor_map, scores_map = model.extract_dense_map(image) + keypoints_, descriptors_, scores_, _ = model.dkd(scores_map, descriptor_map) + + keypoints.append(keypoints_[0]) + descriptors.append(descriptors_[0]) + scores.append(scores_[0]) + + s /= scale_f + + # down-scale the image for next iteration + nh, nw = round(H * s), round(W * s) + image = torch.nn.functional.interpolate(image, (nh, nw), mode='bilinear', align_corners=False) + + # restore value + torch.backends.cudnn.benchmark = old_bm + + keypoints = torch.cat(keypoints) + descriptors = torch.cat(descriptors) + scores = torch.cat(scores) + keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W_ - 1, H_ - 1]]) + + if sort or 0 < n_k < len(keypoints): + indices = torch.argsort(scores, descending=True) + keypoints = keypoints[indices] + descriptors = descriptors[indices] + scores = scores[indices] + + if 0 < n_k < len(keypoints): + keypoints = keypoints[0:n_k] 
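+ # (descriptors and scores are truncated below to stay aligned with the top n_k
+ # keypoints, which were sorted by descending score above)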
+ descriptors = descriptors[0:n_k] + scores = scores[0:n_k] + + return {'keypoints': keypoints, 'descriptors': descriptors, 'scores': scores} + + +def extract_method(m): + hpatches = HPatchesDataset(root=dataset_root, alteration='all') + model = m[:7] + min_scale = 0.3 if m[8:] == 'ms' else 1.0 + + model = ALike(**configs[model], device=device, top_k=0, scores_th=0.2, n_limit=5000) + + progbar = tqdm(hpatches, desc='Extracting for {}'.format(m)) + for imgs, homos, seq_name in progbar: + for i in range(1, 7): + img = imgs[i - 1] + pred = extract_multiscale(model, img, min_scale=min_scale, max_scale=1, sort=False, n_k=5000) + kpts, descs, scores = pred['keypoints'], pred['descriptors'], pred['scores'] + + with open(os.path.join(dataset_root, seq_name, f'{i}.ppm.{m}'), 'wb') as f: + np.savez(f, keypoints=kpts.cpu().numpy(), + scores=scores.cpu().numpy(), + descriptors=descs.cpu().numpy()) + + +if __name__ == '__main__': + for method in methods: + extract_method(method) diff --git a/imcui/third_party/ALIKE/soft_detect.py b/imcui/third_party/ALIKE/soft_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..2d23cd13b8a7db9b0398fdc1b235564222d30c90 --- /dev/null +++ b/imcui/third_party/ALIKE/soft_detect.py @@ -0,0 +1,194 @@ +import torch +from torch import nn +import torch.nn.functional as F + + +# coordinates system +# ------------------------------> [ x: range=-1.0~1.0; w: range=0~W ] +# | ----------------------------- +# | | | +# | | | +# | | | +# | | image | +# | | | +# | | | +# | | | +# | |---------------------------| +# v +# [ y: range=-1.0~1.0; h: range=0~H ] + +def simple_nms(scores, nms_radius: int): + """ Fast Non-maximum suppression to remove nearby points """ + assert (nms_radius >= 0) + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def sample_descriptor(descriptor_map, kpts, bilinear_interp=False): + """ + :param descriptor_map: BxCxHxW + :param kpts: list, len=B, each is Nx2 (keypoints) [h,w] + :param bilinear_interp: bool, whether to use bilinear interpolation + :return: descriptors: list, len=B, each is NxD + """ + batch_size, channel, height, width = descriptor_map.shape + + descriptors = [] + for index in range(batch_size): + kptsi = kpts[index] # Nx2,(x,y) + + if bilinear_interp: + descriptors_ = torch.nn.functional.grid_sample(descriptor_map[index].unsqueeze(0), kptsi.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, :, 0, :] # CxN + else: + kptsi = (kptsi + 1) / 2 * kptsi.new_tensor([[width - 1, height - 1]]) + kptsi = kptsi.long() + descriptors_ = descriptor_map[index, :, kptsi[:, 1], kptsi[:, 0]] # CxN + + descriptors_ = torch.nn.functional.normalize(descriptors_, p=2, dim=0) + descriptors.append(descriptors_.t()) + + return descriptors + + +class DKD(nn.Module): + def __init__(self, radius=2, top_k=0, scores_th=0.2, n_limit=20000): + """ + Args: + radius: soft detection radius, kernel size is (2 * radius + 1) + top_k: top_k > 0: return top k keypoints + scores_th: top_k <= 0 threshold mode: scores_th > 0: return keypoints with scores>scores_th + else: return keypoints with scores > scores.mean() + n_limit: 
max number of keypoint in threshold mode + """ + super().__init__() + self.radius = radius + self.top_k = top_k + self.scores_th = scores_th + self.n_limit = n_limit + self.kernel_size = 2 * self.radius + 1 + self.temperature = 0.1 # tuned temperature + self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius) + + # local xy grid + x = torch.linspace(-self.radius, self.radius, self.kernel_size) + # (kernel_size*kernel_size) x 2 : (w,h) + self.hw_grid = torch.stack(torch.meshgrid([x, x])).view(2, -1).t()[:, [1, 0]] + + def detect_keypoints(self, scores_map, sub_pixel=True): + b, c, h, w = scores_map.shape + scores_nograd = scores_map.detach() + # nms_scores = simple_nms(scores_nograd, self.radius) + nms_scores = simple_nms(scores_nograd, 2) + + # remove border + nms_scores[:, :, :self.radius + 1, :] = 0 + nms_scores[:, :, :, :self.radius + 1] = 0 + nms_scores[:, :, h - self.radius:, :] = 0 + nms_scores[:, :, :, w - self.radius:] = 0 + + # detect keypoints without grad + if self.top_k > 0: + topk = torch.topk(nms_scores.view(b, -1), self.top_k) + indices_keypoints = topk.indices # B x top_k + else: + if self.scores_th > 0: + masks = nms_scores > self.scores_th + if masks.sum() == 0: + th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th + masks = nms_scores > th.reshape(b, 1, 1, 1) + else: + th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th + masks = nms_scores > th.reshape(b, 1, 1, 1) + masks = masks.reshape(b, -1) + + indices_keypoints = [] # list, B x (any size) + scores_view = scores_nograd.reshape(b, -1) + for mask, scores in zip(masks, scores_view): + indices = mask.nonzero(as_tuple=False)[:, 0] + if len(indices) > self.n_limit: + kpts_sc = scores[indices] + sort_idx = kpts_sc.sort(descending=True)[1] + sel_idx = sort_idx[:self.n_limit] + indices = indices[sel_idx] + indices_keypoints.append(indices) + + keypoints = [] + scoredispersitys = [] + kptscores = [] + if sub_pixel: + # detect soft keypoints with grad backpropagation + patches = self.unfold(scores_map) # B x (kernel**2) x (H*W) + self.hw_grid = self.hw_grid.to(patches) # to device + for b_idx in range(b): + patch = patches[b_idx].t() # (H*W) x (kernel**2) + indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M + patch_scores = patch[indices_kpt] # M x (kernel**2) + + # max is detached to prevent undesired backprop loops in the graph + max_v = patch_scores.max(dim=1).values.detach()[:, None] + x_exp = ((patch_scores - max_v) / self.temperature).exp() # M * (kernel**2), in [0, 1] + + # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} } + xy_residual = x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] # Soft-argmax, Mx2 + + hw_grid_dist2 = torch.norm((self.hw_grid[None, :, :] - xy_residual[:, None, :]) / self.radius, + dim=-1) ** 2 + scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1) + + # compute result keypoints + keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2 + keypoints_xy = keypoints_xy_nms + xy_residual + keypoints_xy = keypoints_xy / keypoints_xy.new_tensor( + [w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1) + + kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0), + keypoints_xy.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN + + keypoints.append(keypoints_xy) + scoredispersitys.append(scoredispersity) + kptscores.append(kptscore) + else: + for b_idx in range(b): + indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size 
is M + keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2 + keypoints_xy = keypoints_xy_nms / keypoints_xy_nms.new_tensor( + [w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1) + kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0), + keypoints_xy.view(1, 1, -1, 2), + mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN + keypoints.append(keypoints_xy) + scoredispersitys.append(None) + kptscores.append(kptscore) + + return keypoints, scoredispersitys, kptscores + + def forward(self, scores_map, descriptor_map, sub_pixel=False): + """ + :param scores_map: Bx1xHxW + :param descriptor_map: BxCxHxW + :param sub_pixel: whether to use sub-pixel keypoint detection + :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0 + """ + keypoints, scoredispersitys, kptscores = self.detect_keypoints(scores_map, + sub_pixel) + + descriptors = sample_descriptor(descriptor_map, keypoints, sub_pixel) + + # keypoints: B M 2 + # descriptors: B M D + # scoredispersitys: + return keypoints, descriptors, kptscores, scoredispersitys diff --git a/imcui/third_party/ASpanFormer/.github/workflows/sync.yml b/imcui/third_party/ASpanFormer/.github/workflows/sync.yml new file mode 100644 index 0000000000000000000000000000000000000000..42e762d5299095226503f3a8cebfeef440ef68d7 --- /dev/null +++ b/imcui/third_party/ASpanFormer/.github/workflows/sync.yml @@ -0,0 +1,39 @@ +name: Upstream Sync + +permissions: + contents: write + +on: + schedule: + - cron: "0 0 * * *" # every day + workflow_dispatch: + +jobs: + sync_latest_from_upstream: + name: Sync latest commits from upstream repo + runs-on: ubuntu-latest + if: ${{ github.event.repository.fork }} + + steps: + # Step 1: run a standard checkout action + - name: Checkout target repo + uses: actions/checkout@v3 + + # Step 2: run the sync action + - name: Sync upstream changes + id: sync + uses: aormsby/Fork-Sync-With-Upstream-action@v3.4 + with: + upstream_sync_repo: apple/ml-aspanformer + upstream_sync_branch: main + target_sync_branch: main + target_repo_token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, no need to set + + # Set test_mode true to run tests instead of the true action!! + test_mode: false + + - name: Sync check + if: failure() + run: | + echo "::error::Due to insufficient permissions, synchronization failed (as expected). Please go to the repository homepage and manually perform [Sync fork]." 
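
The `DKD` detector in `soft_detect.py` above refines each NMS keypoint with a temperature-scaled soft-argmax over its local score patch. The snippet below is a minimal standalone sketch of just that step: the 5x5 `patch` values are invented for illustration, while `radius=2` and `temperature=0.1` mirror the DKD defaults, and the grid construction follows `DKD.__init__`.

```python
# Minimal sketch (invented toy patch) of DKD's soft-argmax sub-pixel refinement.
import torch

radius, temperature = 2, 0.1          # DKD defaults above
kernel_size = 2 * radius + 1

# local (x, y) offset of every cell in the patch, built as in DKD.__init__
x = torch.linspace(-radius, radius, kernel_size)
hw_grid = torch.stack(torch.meshgrid([x, x])).view(2, -1).t()[:, [1, 0]]

# a fake 5x5 score patch: main peak at the centre, a weaker one just to its right
patch = torch.zeros(kernel_size, kernel_size)
patch[2, 2], patch[2, 3] = 0.9, 0.8
patch_scores = patch.view(1, -1)      # M x kernel_size**2, here M = 1

# same arithmetic as DKD.detect_keypoints
max_v = patch_scores.max(dim=1).values.detach()[:, None]
x_exp = ((patch_scores - max_v) / temperature).exp()
xy_residual = x_exp @ hw_grid / x_exp.sum(dim=1)[:, None]   # soft-argmax, M x 2

print(xy_residual)   # ~[[0.27, 0.00]]: the keypoint is pulled toward the second peak
```
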
+ exit 1 diff --git a/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py b/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2b44807696ec280672c8f40650fd04fa4d8a36 --- /dev/null +++ b/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py @@ -0,0 +1,10 @@ +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent / '../../../')) +from src.config.default import _CN as cfg + +cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' + +cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0 +cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20] +cfg.ASPAN.COARSE.TRAIN_RES = [480,640] diff --git a/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py b/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py new file mode 100644 index 0000000000000000000000000000000000000000..886d10d8f55533c8021bcca8395b5a2897fb8734 --- /dev/null +++ b/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py @@ -0,0 +1,11 @@ +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent / '../../../')) +from src.config.default import _CN as cfg + +cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20] +cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' + +cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False +cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0 +cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] diff --git a/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py b/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f0b9c04cbf3f466e413b345272afe7d7fe4274ea --- /dev/null +++ b/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py @@ -0,0 +1,21 @@ +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent / '../../../')) +from src.config.default import _CN as cfg + +cfg.ASPAN.COARSE.COARSEST_LEVEL= [36,36] +cfg.ASPAN.COARSE.TRAIN_RES = [832,832] +cfg.ASPAN.COARSE.TEST_RES = [1152,1152] +cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' + +cfg.TRAINER.CANONICAL_LR = 8e-3 +cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +cfg.TRAINER.WARMUP_RATIO = 0.1 +cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] + +# pose estimation +cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 + +cfg.TRAINER.OPTIMIZER = "adamw" +cfg.TRAINER.ADAMW_DECAY = 0.1 +cfg.ASPAN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 diff --git a/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py b/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py new file mode 100644 index 0000000000000000000000000000000000000000..1202080b234562d8cc65d924d7cccf0336b9f7c0 --- /dev/null +++ b/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py @@ -0,0 +1,20 @@ +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent / '../../../')) +from src.config.default import _CN as cfg + +cfg.ASPAN.COARSE.COARSEST_LEVEL= [26,26] +cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False + +cfg.TRAINER.CANONICAL_LR = 8e-3 +cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +cfg.TRAINER.WARMUP_RATIO = 0.1 +cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] + +# pose estimation +cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 + +cfg.TRAINER.OPTIMIZER = "adamw" +cfg.TRAINER.ADAMW_DECAY = 0.1 +cfg.ASPAN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 diff --git a/imcui/third_party/ASpanFormer/configs/data/__init__.py b/imcui/third_party/ASpanFormer/configs/data/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/ASpanFormer/configs/data/base.py b/imcui/third_party/ASpanFormer/configs/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..03aab160fa4137ccc04380f94854a56fbb549074 --- /dev/null +++ b/imcui/third_party/ASpanFormer/configs/data/base.py @@ -0,0 +1,35 @@ +""" +The data config will be the last one merged into the main config. +Setups in data configs will override all existed setups! +""" + +from yacs.config import CfgNode as CN +_CN = CN() +_CN.DATASET = CN() +_CN.TRAINER = CN() + +# training data config +_CN.DATASET.TRAIN_DATA_ROOT = None +_CN.DATASET.TRAIN_POSE_ROOT = None +_CN.DATASET.TRAIN_NPZ_ROOT = None +_CN.DATASET.TRAIN_LIST_PATH = None +_CN.DATASET.TRAIN_INTRINSIC_PATH = None +# validation set config +_CN.DATASET.VAL_DATA_ROOT = None +_CN.DATASET.VAL_POSE_ROOT = None +_CN.DATASET.VAL_NPZ_ROOT = None +_CN.DATASET.VAL_LIST_PATH = None +_CN.DATASET.VAL_INTRINSIC_PATH = None + +# testing data config +_CN.DATASET.TEST_DATA_ROOT = None +_CN.DATASET.TEST_POSE_ROOT = None +_CN.DATASET.TEST_NPZ_ROOT = None +_CN.DATASET.TEST_LIST_PATH = None +_CN.DATASET.TEST_INTRINSIC_PATH = None + +# dataset config +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 +_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val + +cfg = _CN diff --git a/imcui/third_party/ASpanFormer/configs/data/megadepth_test_1500.py b/imcui/third_party/ASpanFormer/configs/data/megadepth_test_1500.py new file mode 100644 index 0000000000000000000000000000000000000000..9616432f52a693ed84f3f12b9b85470b23410eee --- /dev/null +++ b/imcui/third_party/ASpanFormer/configs/data/megadepth_test_1500.py @@ -0,0 +1,13 @@ +from configs.data.base import cfg + +TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info" + +cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" +cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" +cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" +cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt" + +cfg.DATASET.MGDPT_IMG_RESIZE = 1152 +cfg.DATASET.MGDPT_IMG_PAD=True +cfg.DATASET.MGDPT_DF =8 +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py b/imcui/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py new file mode 100644 index 0000000000000000000000000000000000000000..8f9b01fdaed254e10b3d55980499b88a00060f04 --- /dev/null +++ b/imcui/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py @@ -0,0 +1,22 @@ +from configs.data.base import cfg + + +TRAIN_BASE_PATH = "data/megadepth/index" +cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth" +cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train" +cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" +cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0 + +TEST_BASE_PATH = "data/megadepth/index" +cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" +cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" +cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500" +cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val + +# 368 scenes in total for MegaDepth +# (with difficulty balanced (further split each scene to 3 sub-scenes)) 
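
These ASpanFormer config files all mutate one shared yacs `CfgNode` tree: `configs/data/base.py` declares the nodes, the per-dataset files assign onto them, and entry points such as `demo.py` below obtain a copy via `get_cfg_defaults()` and merge a test config on top. A minimal sketch of that layering follows; only the key names `TEST_DATA_ROOT` and `MGDPT_IMG_RESIZE` come from the files above, the default values and the final `704` override are made up for the example.

```python
# Minimal sketch (invented defaults) of how these yacs configs layer on top of
# each other; the last merge applied wins.
from yacs.config import CfgNode as CN

_CN = CN()
_CN.DATASET = CN()
_CN.DATASET.TEST_DATA_ROOT = None
_CN.DATASET.MGDPT_IMG_RESIZE = 1152

def get_cfg_defaults():
    # yacs convention: hand out a clone so callers can mutate it freely
    return _CN.clone()

cfg = get_cfg_defaults()
# a data config such as megadepth_trainval_832.py simply assigns onto the node...
cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
cfg.DATASET.MGDPT_IMG_RESIZE = 832
# ...and an entry point may merge further overrides on top
cfg.merge_from_list(["DATASET.MGDPT_IMG_RESIZE", 704])
print(cfg.DATASET.MGDPT_IMG_RESIZE)   # 704
```
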
+cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100 + +cfg.DATASET.MGDPT_IMG_RESIZE = 832 # for training on 32GB meme GPUs diff --git a/imcui/third_party/ASpanFormer/configs/data/scannet_test_1500.py b/imcui/third_party/ASpanFormer/configs/data/scannet_test_1500.py new file mode 100644 index 0000000000000000000000000000000000000000..60e560fa01d73345200aaca10961449fdf3e9fbe --- /dev/null +++ b/imcui/third_party/ASpanFormer/configs/data/scannet_test_1500.py @@ -0,0 +1,11 @@ +from configs.data.base import cfg + +TEST_BASE_PATH = "assets/scannet_test_1500" + +cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" +cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test" +cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" +cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt" +cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" + +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 diff --git a/imcui/third_party/ASpanFormer/configs/data/scannet_trainval.py b/imcui/third_party/ASpanFormer/configs/data/scannet_trainval.py new file mode 100644 index 0000000000000000000000000000000000000000..c38d6440e2b4ec349e5f168909c7f8c367408813 --- /dev/null +++ b/imcui/third_party/ASpanFormer/configs/data/scannet_trainval.py @@ -0,0 +1,17 @@ +from configs.data.base import cfg + + +TRAIN_BASE_PATH = "data/scannet/index" +cfg.DATASET.TRAINVAL_DATA_SOURCE = "ScanNet" +cfg.DATASET.TRAIN_DATA_ROOT = "data/scannet/train" +cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_data/train" +cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/scene_data/train_list/scannet_all.txt" +cfg.DATASET.TRAIN_INTRINSIC_PATH = f"{TRAIN_BASE_PATH}/intrinsics.npz" + +TEST_BASE_PATH = "assets/scannet_test_1500" +cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" +cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test" +cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH +cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt" +cfg.DATASET.VAL_INTRINSIC_PATH = cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val diff --git a/imcui/third_party/ASpanFormer/demo/demo.py b/imcui/third_party/ASpanFormer/demo/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d8a91f131b30e131cbdd6bf8ee44d53a0b256d --- /dev/null +++ b/imcui/third_party/ASpanFormer/demo/demo.py @@ -0,0 +1,64 @@ +import os +import sys +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, ROOT_DIR) + +from src.ASpanFormer.aspanformer import ASpanFormer +from src.config.default import get_cfg_defaults +from src.utils.misc import lower_config +import demo_utils +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +import cv2 +import torch +import numpy as np + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--config_path', type=str, default='../configs/aspan/outdoor/aspan_test.py', + help='path for config file.') +parser.add_argument('--img0_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg', + help='path for image0.') +parser.add_argument('--img1_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg', + help='path for image1.') +parser.add_argument('--weights_path', type=str, default='../weights/outdoor.ckpt', + help='path for model weights.') +parser.add_argument('--long_dim0', type=int, default=1024, + help='resize for longest dim of 
image0.') +parser.add_argument('--long_dim1', type=int, default=1024, + help='resize for longest dim of image1.') + +args = parser.parse_args() + + +if __name__=='__main__': + config = get_cfg_defaults() + config.merge_from_file(args.config_path) + _config = lower_config(config) + matcher = ASpanFormer(config=_config['aspan']) + state_dict = torch.load(args.weights_path, map_location='cpu')['state_dict'] + matcher.load_state_dict(state_dict,strict=False) + matcher.to(device),matcher.eval() + + img0,img1=cv2.imread(args.img0_path),cv2.imread(args.img1_path) + img0_g,img1_g=cv2.imread(args.img0_path,0),cv2.imread(args.img1_path,0) + img0,img1=demo_utils.resize(img0,args.long_dim0),demo_utils.resize(img1,args.long_dim1) + img0_g,img1_g=demo_utils.resize(img0_g,args.long_dim0),demo_utils.resize(img1_g,args.long_dim1) + data={'image0':torch.from_numpy(img0_g/255.)[None,None].to(device).float(), + 'image1':torch.from_numpy(img1_g/255.)[None,None].to(device).float()} + with torch.no_grad(): + matcher(data,online_resize=True) + corr0,corr1=data['mkpts0_f'].cpu().numpy(),data['mkpts1_f'].cpu().numpy() + + F_hat,mask_F=cv2.findFundamentalMat(corr0,corr1,method=cv2.FM_RANSAC,ransacReprojThreshold=1) + if mask_F is not None: + mask_F=mask_F[:,0].astype(bool) + else: + mask_F=np.zeros_like(corr0[:,0]).astype(bool) + + #visualize match + display=demo_utils.draw_match(img0,img1,corr0,corr1) + display_ransac=demo_utils.draw_match(img0,img1,corr0[mask_F],corr1[mask_F]) + cv2.imwrite('match.png',display) + cv2.imwrite('match_ransac.png',display_ransac) + print(len(corr1),len(corr1[mask_F])) \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/demo/demo_utils.py b/imcui/third_party/ASpanFormer/demo/demo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a104e25d3f5ee8b7efb6cc5fa0dc27378e22c83f --- /dev/null +++ b/imcui/third_party/ASpanFormer/demo/demo_utils.py @@ -0,0 +1,44 @@ +import cv2 +import numpy as np + +def resize(image,long_dim): + h,w=image.shape[0],image.shape[1] + image=cv2.resize(image,(int(w*long_dim/max(h,w)),int(h*long_dim/max(h,w)))) + return image + +def draw_points(img,points,color=(0,255,0),radius=3): + dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])] + for i in range(points.shape[0]): + cv2.circle(img, dp[i],radius=radius,color=color) + return img + + +def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None): + if resize is not None: + scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]] + img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA) + corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis] + corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])] + corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])] + + assert len(corr1) == len(corr2) + + draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))] + if color is None: + color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier] + if len(color)==1: + display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None, + matchColor=color[0], + singlePointColor=color[0], + flags=4 + ) + else: + height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1] + display=np.zeros([height,width,3],np.uint8) + 
display[:img1.shape[0],:img1.shape[1]]=img1 + display[:img2.shape[0],img1.shape[1]:]=img2 + for i in range(len(corr1)): + left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1]) + cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2])) + cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA) + return display \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/environment.yaml b/imcui/third_party/ASpanFormer/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c52328762e971c94b447198869ec0036771bf76 --- /dev/null +++ b/imcui/third_party/ASpanFormer/environment.yaml @@ -0,0 +1,12 @@ +name: ASpanFormer +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - python=3.8 + - cudatoolkit=10.2 + - pytorch=1.8.1 + - pip + - pip: + - -r requirements.txt diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/__init__.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bfd5a901e83c7e8d3b439f21afa20ac8237635e --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/__init__.py @@ -0,0 +1,2 @@ +from .aspanformer import LocalFeatureTransformer_Flow +from .utils.cvpr_ds_config import default_cfg diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dff6704976cbe9e916c6de6af9e3b755dfbd20bf --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py @@ -0,0 +1,3 @@ +from .transformer import LocalFeatureTransformer_Flow +from .loftr import LocalFeatureTransformer +from .fine_preprocess import FinePreprocess diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1fb6794461d043b0df4a20664e974a38240727 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py @@ -0,0 +1,199 @@ +import torch +from torch.nn import Module +import torch.nn as nn +from itertools import product +from torch.nn import functional as F +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +class layernorm2d(nn.Module): + + def __init__(self,dim) : + super().__init__() + self.dim=dim + self.affine=nn.parameter.Parameter(torch.ones(dim), requires_grad=True) + self.bias=nn.parameter.Parameter(torch.zeros(dim), requires_grad=True) + + def forward(self,x): + #x: B*C*H*W + mean,std=x.mean(dim=1,keepdim=True),x.std(dim=1,keepdim=True) + return self.affine[None,:,None,None]*(x-mean)/(std+1e-6)+self.bias[None,:,None,None] + + +class HierachicalAttention(Module): + def __init__(self,d_model,nhead,nsample,radius_scale,nlevel=3): + super().__init__() + self.d_model=d_model + self.nhead=nhead + self.nsample=nsample + self.nlevel=nlevel + self.radius_scale=radius_scale + self.merge_head = nn.Sequential( + nn.Conv1d(d_model*3, d_model, kernel_size=1,bias=False), + nn.ReLU(True), + nn.Conv1d(d_model, d_model, kernel_size=1,bias=False), + ) + self.fullattention=FullAttention(d_model,nhead) + self.temp=nn.parameter.Parameter(torch.tensor(1.),requires_grad=True) + sample_offset=torch.tensor([[pos[0]-nsample[1]/2+0.5, pos[1]-nsample[1]/2+0.5] for pos in product(range(nsample[1]), 
range(nsample[1]))]) #r^2*2 + self.sample_offset=nn.parameter.Parameter(sample_offset,requires_grad=False) + + def forward(self,query,key,value,flow,size_q,size_kv,mask0=None, mask1=None,ds0=[4,4],ds1=[4,4]): + """ + Args: + q,k,v (torch.Tensor): [B, C, L] + mask (torch.Tensor): [B, L] + flow (torch.Tensor): [B, H, W, 4] + Return: + all_message (torch.Tensor): [B, C, H, W] + """ + + variance=flow[:,:,:,2:] + offset=flow[:,:,:,:2] #B*H*W*2 + bs=query.shape[0] + h0,w0=size_q[0],size_q[1] + h1,w1=size_kv[0],size_kv[1] + variance=torch.exp(0.5*variance)*self.radius_scale #b*h*w*2(pixel scale) + span_scale=torch.clamp((variance*2/self.nsample[1]),min=1) #b*h*w*2 + + sub_sample0,sub_sample1=[ds0,2,1],[ds1,2,1] + q_list=[F.avg_pool2d(query.view(bs,-1,h0,w0),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0] + k_list=[F.avg_pool2d(key.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1] + v_list=[F.avg_pool2d(value.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1] #n_level + + offset_list=[F.avg_pool2d(offset.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1)/sub_size for sub_size in sub_sample0[1:]] #n_level-1 + span_list=[F.avg_pool2d(span_scale.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1) for sub_size in sub_sample0[1:]] #n_level-1 + + if mask0 is not None: + mask0,mask1=mask0.view(bs,1,h0,w0),mask1.view(bs,1,h1,w1) + mask0_list=[-F.max_pool2d(-mask0,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0] + mask1_list=[-F.max_pool2d(-mask1,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1] + else: + mask0_list=mask1_list=[None,None,None] + + message_list=[] + #full attention at coarse scale + mask0_flatten=mask0_list[0].view(bs,-1) if mask0 is not None else None + mask1_flatten=mask1_list[0].view(bs,-1) if mask1 is not None else None + message_list.append(self.fullattention(q_list[0],k_list[0],v_list[0],mask0_flatten,mask1_flatten,self.temp).view(bs,self.d_model,h0//ds0[0],w0//ds0[1])) + + for index in range(1,self.nlevel): + q,k,v=q_list[index],k_list[index],v_list[index] + mask0,mask1=mask0_list[index],mask1_list[index] + s,o=span_list[index-1],offset_list[index-1] #B*h*w(*2) + q,k,v,sample_pixel,mask_sample=self.partition_token(q,k,v,o,s,mask0) #B*Head*D*G*N(G*N=H*W for q) + message_list.append(self.group_attention(q,k,v,1,mask_sample).view(bs,self.d_model,h0//sub_sample0[index],w0//sub_sample0[index])) + #fuse + all_message=torch.cat([F.upsample(message_list[idx],scale_factor=sub_sample0[idx],mode='nearest') \ + for idx in range(self.nlevel)],dim=1).view(bs,-1,h0*w0) #b*3d*H*W + + all_message=self.merge_head(all_message).view(bs,-1,h0,w0) #b*d*H*W + return all_message + + def partition_token(self,q,k,v,offset,span_scale,maskv): + #q,k,v: B*C*H*W + #o: B*H/2*W/2*2 + #span_scale:B*H*W + bs=q.shape[0] + h,w=q.shape[2],q.shape[3] + hk,wk=k.shape[2],k.shape[3] + offset=offset.view(bs,-1,2) + span_scale=span_scale.view(bs,-1,1,2) + #B*G*2 + offset_sample=self.sample_offset[None,None]*span_scale + sample_pixel=offset[:,:,None]+offset_sample#B*G*r^2*2 + sample_norm=sample_pixel/torch.tensor([wk/2,hk/2]).to(device)[None,None,None]-1 + + q = q.view(bs, -1 , h // self.nsample[0], self.nsample[0], w // self.nsample[0], self.nsample[0]).\ + permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, self.nhead,self.d_model//self.nhead, -1,self.nsample[0]**2)#B*head*D*G*N(G*N=H*W for q) + #sample 
token + k=F.grid_sample(k, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2 + v=F.grid_sample(v, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2 + #import pdb;pdb.set_trace() + if maskv is not None: + mask_sample=F.grid_sample(maskv.view(bs,-1,h,w).float(),grid=sample_norm,mode='nearest')==1 #B*1*G*r^2 + else: + mask_sample=None + return q,k,v,sample_pixel,mask_sample + + + def group_attention(self,query,key,value,temp,mask_sample=None): + #q,k,v: B*Head*D*G*N(G*N=H*W for q) + bs=query.shape[0] + #import pdb;pdb.set_trace() + QK = torch.einsum("bhdgn,bhdgm->bhgnm", query, key) + if mask_sample is not None: + num_head,number_n=QK.shape[1],QK.shape[3] + QK.masked_fill_(~(mask_sample[:,:,:,None]).expand(-1,num_head,-1,number_n,-1).bool(), float(-1e8)) + # Compute the attention and the weighted average + softmax_temp = temp / query.size(2)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=-1) + queried_values = torch.einsum("bhgnm,bhdgm->bhdgn", A, value).contiguous().view(bs,self.d_model,-1) + return queried_values + + + +class FullAttention(Module): + def __init__(self,d_model,nhead): + super().__init__() + self.d_model=d_model + self.nhead=nhead + + def forward(self, q, k,v , mask0=None, mask1=None, temp=1): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + q,k,v: [N, D, L] + mask: [N, L] + Returns: + msg: [N,L] + """ + bs=q.shape[0] + q,k,v=q.view(bs,self.nhead,self.d_model//self.nhead,-1),k.view(bs,self.nhead,self.d_model//self.nhead,-1),v.view(bs,self.nhead,self.d_model//self.nhead,-1) + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nhdl,nhds->nhls", q, k) + if mask0 is not None: + QK.masked_fill_(~(mask0[:,None, :, None] * mask1[:, None, None]).bool(), float(-1e8)) + # Compute the attention and the weighted average + softmax_temp = temp / q.size(2)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=-1) + queried_values = torch.einsum("nhls,nhds->nhdl", A, v).contiguous().view(bs,self.d_model,-1) + return queried_values + + + +def elu_feature_map(x): + return F.elu(x) + 1 + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb8eefd362240a9901a335f0e6e07770ff04567 --- /dev/null 
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange, repeat + + +class FinePreprocess(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.cat_c_feat = config['fine_concat_coarse_feat'] + self.W = self.config['fine_window_size'] + + d_model_c = self.config['coarse']['d_model'] + d_model_f = self.config['fine']['d_model'] + self.d_model_f = d_model_f + if self.cat_c_feat: + self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) + self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") + + def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): + W = self.W + stride = data['hw0_f'][0] // data['hw0_c'][0] + + data.update({'W': W}) + if data['b_ids'].shape[0] == 0: + feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + return feat0, feat1 + + # 1. unfold(crop) all local windows + feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + + # 2. select only the predicted matches + feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + + # option: use coarse-level loftr feature as context: concat and linear + if self.cat_c_feat: + feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], + feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] + feat_cf_win = self.merge_feat(torch.cat([ + torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] + repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] + ], -1)) + feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) + + return feat_f0_unfold, feat_f1_unfold diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcebaa7beee978b9b8abcec8bb1bd2cc6b60870 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py @@ -0,0 +1,112 @@ +import copy +import torch +import torch.nn as nn +from .attention import LinearAttention + +class LoFTREncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + attention='linear'): + super(LoFTREncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = LinearAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, 
source, x_mask=None, source_mask=None, type=None, index=0): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view( + bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, + self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + + message = self.attention( + query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view( + bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + encoder_layer = LoFTREncoderLayer( + config['d_model'], config['nhead'], config['attention']) + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size( + 2), "the feature number of src and transformer must be equal" + + index = 0 + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0, + type='self', index=index) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0, + type='cross', index=index) + index += 1 + else: + raise KeyError + return feat0, feat1 + diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1bed7b4f65c6b5936e9e265dfefc0d058dbfa33f --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py @@ -0,0 +1,245 @@ +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F +from .attention import FullAttention, HierachicalAttention ,layernorm2d +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class messageLayer_ini(nn.Module): + + def __init__(self, d_model, d_flow,d_value, nhead): + super().__init__() + super(messageLayer_ini, self).__init__() + + self.d_model = d_model + self.d_flow = d_flow + self.d_value=d_value + self.nhead = nhead + self.attention = FullAttention(d_model,nhead) + + self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False) + self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False) + self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False) + self.merge_head=nn.Conv1d(d_model,d_model,kernel_size=1,bias=False) + + self.merge_f= self.merge_f = nn.Sequential( + 
nn.Conv2d(d_model*2, d_model*2, kernel_size=1, bias=False), + nn.ReLU(True), + nn.Conv2d(d_model*2, d_model, kernel_size=1, bias=False), + ) + + self.norm1 = layernorm2d(d_model) + self.norm2 = layernorm2d(d_model) + + + def forward(self, x0, x1,pos0,pos1,mask0=None,mask1=None): + #x1,x2: b*d*L + x0,x1=self.update(x0,x1,pos1,mask0,mask1),\ + self.update(x1,x0,pos0,mask1,mask0) + return x0,x1 + + + def update(self,f0,f1,pos1,mask0,mask1): + """ + Args: + f0: [N, D, H, W] + f1: [N, D, H, W] + Returns: + f0_new: (N, d, h, w) + """ + bs,h,w=f0.shape[0],f0.shape[2],f0.shape[3] + + f0_flatten,f1_flatten=f0.view(bs,self.d_model,-1),f1.view(bs,self.d_model,-1) + pos1_flatten=pos1.view(bs,self.d_value-self.d_model,-1) + f1_flatten_v=torch.cat([f1_flatten,pos1_flatten],dim=1) + + queries,keys=self.q_proj(f0_flatten),self.k_proj(f1_flatten) + values=self.v_proj(f1_flatten_v).view(bs,self.nhead,self.d_model//self.nhead,-1) + + queried_values=self.attention(queries,keys,values,mask0,mask1) + msg=self.merge_head(queried_values).view(bs,-1,h,w) + msg=self.norm2(self.merge_f(torch.cat([f0,self.norm1(msg)],dim=1))) + return f0+msg + + + +class messageLayer_gla(nn.Module): + + def __init__(self,d_model,d_flow,d_value, + nhead,radius_scale,nsample,update_flow=True): + super().__init__() + self.d_model = d_model + self.d_flow=d_flow + self.d_value=d_value + self.nhead = nhead + self.radius_scale=radius_scale + self.update_flow=update_flow + self.flow_decoder=nn.Sequential( + nn.Conv1d(d_flow, d_flow//2, kernel_size=1, bias=False), + nn.ReLU(True), + nn.Conv1d(d_flow//2, 4, kernel_size=1, bias=False)) + self.attention=HierachicalAttention(d_model,nhead,nsample,radius_scale) + + self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False) + self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False) + self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False) + + d_extra=d_flow if update_flow else 0 + self.merge_f=nn.Sequential( + nn.Conv2d(d_model*2+d_extra, d_model+d_flow, kernel_size=1, bias=False), + nn.ReLU(True), + nn.Conv2d(d_model+d_flow, d_model+d_extra, kernel_size=3,padding=1, bias=False), + ) + self.norm1 = layernorm2d(d_model) + self.norm2 = layernorm2d(d_model+d_extra) + + def forward(self, x0, x1, flow_feature0,flow_feature1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]): + """ + Args: + x0 (torch.Tensor): [B, C, H, W] + x1 (torch.Tensor): [B, C, H, W] + flow_feature0 (torch.Tensor): [B, C', H, W] + flow_feature1 (torch.Tensor): [B, C', H, W] + """ + flow0,flow1=self.decode_flow(flow_feature0,flow_feature1.shape[2:]),self.decode_flow(flow_feature1,flow_feature0.shape[2:]) + x0_new,flow_feature0_new=self.update(x0,x1,flow0.detach(),flow_feature0,pos1,mask0,mask1,ds0,ds1) + x1_new,flow_feature1_new=self.update(x1,x0,flow1.detach(),flow_feature1,pos0,mask1,mask0,ds1,ds0) + return x0_new,x1_new,flow_feature0_new,flow_feature1_new,flow0,flow1 + + def update(self,x0,x1,flow0,flow_feature0,pos1,mask0,mask1,ds0,ds1): + bs=x0.shape[0] + queries,keys=self.q_proj(x0.view(bs,self.d_model,-1)),self.k_proj(x1.view(bs,self.d_model,-1)) + x1_pos=torch.cat([x1,pos1],dim=1) + values=self.v_proj(x1_pos.view(bs,self.d_value,-1)) + msg=self.attention(queries,keys,values,flow0,x0.shape[2:],x1.shape[2:],mask0,mask1,ds0,ds1) + + if self.update_flow: + update_feature=torch.cat([x0,flow_feature0],dim=1) + else: + update_feature=x0 + msg=self.norm2(self.merge_f(torch.cat([update_feature,self.norm1(msg)],dim=1))) + update_feature=update_feature+msg + + 
x0_new,flow_feature0_new=update_feature[:,:self.d_model],update_feature[:,self.d_model:] + return x0_new,flow_feature0_new + + def decode_flow(self,flow_feature,kshape): + bs,h,w=flow_feature.shape[0],flow_feature.shape[2],flow_feature.shape[3] + scale_factor=torch.tensor([kshape[1],kshape[0]]).to(device)[None,None,None] + flow=self.flow_decoder(flow_feature.view(bs,-1,h*w)).permute(0,2,1).view(bs,h,w,4) + flow_coordinates=torch.sigmoid(flow[:,:,:,:2])*scale_factor + flow_var=flow[:,:,:,2:] + flow=torch.cat([flow_coordinates,flow_var],dim=-1) #B*H*W*4 + return flow + + +class flow_initializer(nn.Module): + + def __init__(self, dim, dim_flow, nhead, layer_num): + super().__init__() + self.layer_num= layer_num + self.dim = dim + self.dim_flow = dim_flow + + encoder_layer = messageLayer_ini( + dim ,dim_flow,dim+dim_flow , nhead) + self.layers_coarse = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(layer_num)]) + self.decoupler = nn.Conv2d( + self.dim, self.dim+self.dim_flow, kernel_size=1) + self.up_merge = nn.Conv2d(2*dim, dim, kernel_size=1) + + def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]): + # feat0: [B, C, H0, W0] + # feat1: [B, C, H1, W1] + # use low-res MHA to initialize flow feature + bs = feat0.size(0) + h0,w0,h1,w1=feat0.shape[2],feat0.shape[3],feat1.shape[2],feat1.shape[3] + + # coarse level + sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0), \ + F.avg_pool2d(feat1, ds1, stride=ds1) + + sub_pos0,sub_pos1=F.avg_pool2d(pos0, ds0, stride=ds0), \ + F.avg_pool2d(pos1, ds1, stride=ds1) + + if mask0 is not None: + mask0,mask1=-F.max_pool2d(-mask0.view(bs,1,h0,w0),ds0,stride=ds0).view(bs,-1),\ + -F.max_pool2d(-mask1.view(bs,1,h1,w1),ds1,stride=ds1).view(bs,-1) + + for layer in self.layers_coarse: + sub_feat0, sub_feat1 = layer(sub_feat0, sub_feat1,sub_pos0,sub_pos1,mask0,mask1) + # decouple flow and visual features + decoupled_feature0, decoupled_feature1 = self.decoupler(sub_feat0),self.decoupler(sub_feat1) + + sub_feat0, sub_flow_feature0 = decoupled_feature0[:,:self.dim], decoupled_feature0[:, self.dim:] + sub_feat1, sub_flow_feature1 = decoupled_feature1[:,:self.dim], decoupled_feature1[:, self.dim:] + update_feat0, flow_feature0 = F.upsample(sub_feat0, scale_factor=ds0, mode='bilinear'),\ + F.upsample(sub_flow_feature0, scale_factor=ds0, mode='bilinear') + update_feat1, flow_feature1 = F.upsample(sub_feat1, scale_factor=ds1, mode='bilinear'),\ + F.upsample(sub_flow_feature1, scale_factor=ds1, mode='bilinear') + + feat0 = feat0+self.up_merge(torch.cat([feat0, update_feat0], dim=1)) + feat1 = feat1+self.up_merge(torch.cat([feat1, update_feat1], dim=1)) + + return feat0,feat1,flow_feature0,flow_feature1 #b*c*h*w + + +class LocalFeatureTransformer_Flow(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer_Flow, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + + self.pos_transform=nn.Conv2d(config['d_model'],config['d_flow'],kernel_size=1,bias=False) + self.ini_layer = flow_initializer(self.d_model, config['d_flow'], config['nhead'],config['ini_layer_num']) + + encoder_layer = messageLayer_gla( + config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], config['nhead'],config['radius_scale'],config['nsample']) + encoder_layer_last=messageLayer_gla( + config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], 
config['nhead'],config['radius_scale'],config['nsample'],update_flow=False) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(config['layer_num']-1)]+[encoder_layer_last]) + self._reset_parameters() + + def _reset_parameters(self): + for name,p in self.named_parameters(): + if 'temp' in name or 'sample_offset' in name: + continue + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]): + """ + Args: + feat0 (torch.Tensor): [N, C, H, W] + feat1 (torch.Tensor): [N, C, H, W] + pos1,pos2: [N, C, H, W] + Outputs: + feat0: [N,-1,C] + feat1: [N,-1,C] + flow_list: [L,N,H,W,4]*1(2) + """ + bs = feat0.size(0) + + pos0,pos1=self.pos_transform(pos0),self.pos_transform(pos1) + pos0,pos1=pos0.expand(bs,-1,-1,-1),pos1.expand(bs,-1,-1,-1) + assert self.d_model == feat0.size( + 1), "the feature number of src and transformer must be equal" + + flow_list=[[],[]]# [px,py,sx,sy] + if mask0 is not None: + mask0,mask1=mask0[:,None].float(),mask1[:,None].float() + feat0,feat1, flow_feature0, flow_feature1 = self.ini_layer(feat0, feat1,pos0,pos1,mask0,mask1,ds0,ds1) + for layer in self.layers: + feat0,feat1,flow_feature0,flow_feature1,flow0,flow1=layer(feat0,feat1,flow_feature0,flow_feature1,pos0,pos1,mask0,mask1,ds0,ds1) + flow_list[0].append(flow0) + flow_list[1].append(flow1) + flow_list[0]=torch.stack(flow_list[0],dim=0) + flow_list[1]=torch.stack(flow_list[1],dim=0) + feat0, feat1 = feat0.permute(0, 2, 3, 1).view(bs, -1, self.d_model), feat1.permute(0, 2, 3, 1).view(bs, -1, self.d_model) + return feat0, feat1, flow_list \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e42f2438abda5883796cea9f379380fa6ad7d7c1 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn +from torchvision import transforms +from einops.einops import rearrange + +from .backbone import build_backbone +from .utils.position_encoding import PositionEncodingSine +from .aspan_module import LocalFeatureTransformer_Flow, LocalFeatureTransformer, FinePreprocess +from .utils.coarse_matching import CoarseMatching +from .utils.fine_matching import FineMatching +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class ASpanFormer(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + + # Modules + self.backbone = build_backbone(config) + self.pos_encoding = PositionEncodingSine( + config['coarse']['d_model'],pre_scaling=[config['coarse']['train_res'],config['coarse']['test_res']]) + self.loftr_coarse = LocalFeatureTransformer_Flow(config['coarse']) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_preprocess = FinePreprocess(config) + self.loftr_fine = LocalFeatureTransformer(config["fine"]) + self.fine_matching = FineMatching() + self.coarsest_level=config['coarse']['coarsest_level'] + + def forward(self, data, online_resize=False): + """ + Update: + data (dict): { + 'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + if online_resize: + assert data['image0'].shape[0]==1 and data['image1'].shape[1]==1 + 
self.resize_input(data,self.config['coarse']['train_res']) + else: + data['pos_scale0'],data['pos_scale1']=None,None + + # 1. Local Feature CNN + data.update({ + 'bs': data['image0'].size(0), + 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] + }) + + if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + feats_c, feats_f = self.backbone( + torch.cat([data['image0'], data['image1']], dim=0)) + (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split( + data['bs']), feats_f.split(data['bs']) + else: # handle different input shapes + (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone( + data['image0']), self.backbone(data['image1']) + + data.update({ + 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], + 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] + }) + + # 2. coarse-level loftr module + # add featmap with positional encoding, then flatten it to sequence [N, HW, C] + [feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding(feat_c0,data['pos_scale0']), self.pos_encoding(feat_c1,data['pos_scale1']) + feat_c0 = rearrange(feat_c0, 'n c h w -> n c h w ') + feat_c1 = rearrange(feat_c1, 'n c h w -> n c h w ') + + #TODO:adjust ds + ds0=[int(data['hw0_c'][0]/self.coarsest_level[0]),int(data['hw0_c'][1]/self.coarsest_level[1])] + ds1=[int(data['hw1_c'][0]/self.coarsest_level[0]),int(data['hw1_c'][1]/self.coarsest_level[1])] + if online_resize: + ds0,ds1=[4,4],[4,4] + + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'].flatten( + -2), data['mask1'].flatten(-2) + feat_c0, feat_c1, flow_list = self.loftr_coarse( + feat_c0, feat_c1,pos_encoding0,pos_encoding1,mask_c0,mask_c1,ds0,ds1) + + # 3. match coarse-level and register predicted offset + self.coarse_matching(feat_c0, feat_c1, flow_list,data, + mask_c0=mask_c0, mask_c1=mask_c1) + + # 4. fine-level refinement + feat_f0_unfold, feat_f1_unfold = self.fine_preprocess( + feat_f0, feat_f1, feat_c0, feat_c1, data) + if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted + feat_f0_unfold, feat_f1_unfold = self.loftr_fine( + feat_f0_unfold, feat_f1_unfold) + + # 5. match fine-level + self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) + + # 6. 
resize match coordinates back to input resolution + if online_resize: + data['mkpts0_f']*=data['online_resize_scale0'] + data['mkpts1_f']*=data['online_resize_scale1'] + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('matcher.'): + if 'sample_offset' in k: + state_dict.pop(k) + else: + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) + + def resize_input(self,data,train_res,df=32): + h0,w0,h1,w1=data['image0'].shape[2],data['image0'].shape[3],data['image1'].shape[2],data['image1'].shape[3] + data['image0'],data['image1']=self.resize_df(data['image0'],df),self.resize_df(data['image1'],df) + + if len(train_res)==1: + train_res_h=train_res_w=train_res + else: + train_res_h,train_res_w=train_res[0],train_res[1] + data['pos_scale0'],data['pos_scale1']=[train_res_h/data['image0'].shape[2],train_res_w/data['image0'].shape[3]],\ + [train_res_h/data['image1'].shape[2],train_res_w/data['image1'].shape[3]] + data['online_resize_scale0'],data['online_resize_scale1']=torch.tensor([w0/data['image0'].shape[3],h0/data['image0'].shape[2]])[None].to(device),\ + torch.tensor([w1/data['image1'].shape[3],h1/data['image1'].shape[2]])[None].to(device) + + def resize_df(self,image,df=32): + h,w=image.shape[2],image.shape[3] + h_new,w_new=h//df*df,w//df*df + if h!=h_new or w!=w_new: + img_new=transforms.Resize([h_new,w_new]).forward(image) + else: + img_new=image + return img_new diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e731b3f53ab367c89ef0ea8e1cbffb0d990775 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py @@ -0,0 +1,11 @@ +from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4 + + +def build_backbone(config): + if config['backbone_type'] == 'ResNetFPN': + if config['resolution'] == (8, 2): + return ResNetFPN_8_2(config['resnetfpn']) + elif config['resolution'] == (16, 4): + return ResNetFPN_16_4(config['resnetfpn']) + else: + raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..985e5b3f273a51e51447a8025ca3aadbe46752eb --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.conv2 = conv3x3(planes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + conv1x1(in_planes, planes, stride=stride), + nn.BatchNorm2d(planes) + ) + + def forward(self, 
x): + y = x + y = self.relu(self.bn1(self.conv1(y))) + y = self.bn2(self.conv2(y)) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResNetFPN_8_2(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. FPN upsample + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + x3_out = self.layer3_outconv(x3) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] + + +class ResNetFPN_16_4(nn.Module): + """ + ResNet+FPN, output resolution are 1/16 and 1/4. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 + + # 3. 
FPN upsample + self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) + self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) + self.layer3_outconv2 = nn.Sequential( + conv3x3(block_dims[3], block_dims[3]), + nn.BatchNorm2d(block_dims[3]), + nn.LeakyReLU(), + conv3x3(block_dims[3], block_dims[2]), + ) + + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + x4 = self.layer4(x3) # 1/16 + + # FPN + x4_out = self.layer4_outconv(x4) + + x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out = self.layer3_outconv(x3) + x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + return [x4_out, x2_out] diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..281a410e02465dec1d68ab69f48673268d1d3002 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py @@ -0,0 +1,331 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange + +from time import time +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +INF = 1e9 + +def mask_border(m, b: int, v): + """ Mask borders with value + Args: + m (torch.Tensor): [N, H0, W0, H1, W1] + b (int) + v (m.dtype) + """ + if b <= 0: + return + + m[:, :b] = v + m[:, :, :b] = v + m[:, :, :, :b] = v + m[:, :, :, :, :b] = v + m[:, -b:] = v + m[:, :, -b:] = v + m[:, :, :, -b:] = v + m[:, :, :, :, -b:] = v + + +def mask_border_with_padding(m, bd, v, p_m0, p_m1): + if bd <= 0: + return + + m[:, :bd] = v + m[:, :, :bd] = v + m[:, :, :, :bd] = v + m[:, :, :, :, :bd] = v + + h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() + h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() + for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): + m[b_idx, h0 - bd:] = v + m[b_idx, :, w0 - bd:] = v + m[b_idx, :, :, h1 - bd:] = v + m[b_idx, :, :, :, w1 - bd:] = v + + +def compute_max_candidates(p_m0, p_m1): + """Compute the max candidates of all pairs within a batch + + Args: + p_m0, p_m1 (torch.Tensor): padded masks + """ + h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] + h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] + max_cand = torch.sum( + torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + return max_cand + + +class CoarseMatching(nn.Module): + def __init__(self, config): + super().__init__() 
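The two ResNet-FPN variants above return a coarse and a fine feature map at [1/8, 1/2] (or [1/16, 1/4]) of the input resolution, with channel widths taken from `block_dims`. A minimal shape check, assuming the package is importable as `src.ASpanFormer.backbone` (the layout used by the rest of this diff) and using the default `RESNETFPN` dimensions:

```python
import torch

from src.ASpanFormer.backbone import build_backbone  # import path assumed from this diff

config = {
    "backbone_type": "ResNetFPN",
    "resolution": (8, 2),
    "resnetfpn": {"initial_dim": 128, "block_dims": [128, 196, 256]},
}

backbone = build_backbone(config).eval()
image = torch.randn(1, 1, 480, 640)  # grayscale input; H and W divisible by 8

with torch.no_grad():
    feat_coarse, feat_fine = backbone(image)

print(feat_coarse.shape)  # torch.Size([1, 256, 60, 80])   -> 1/8 resolution, block_dims[2] channels
print(feat_fine.shape)    # torch.Size([1, 128, 240, 320]) -> 1/2 resolution, block_dims[0] channels
```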
+ self.config = config + # general config + self.thr = config['thr'] + self.border_rm = config['border_rm'] + # -- # for trainig fine-level LoFTR + self.train_coarse_percent = config['train_coarse_percent'] + self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + + # we provide 2 options for differentiable matching + self.match_type = config['match_type'] + if self.match_type == 'dual_softmax': + self.temperature=nn.parameter.Parameter(torch.tensor(10.), requires_grad=True) + elif self.match_type == 'sinkhorn': + try: + from .superglue import log_optimal_transport + except ImportError: + raise ImportError("download superglue.py first!") + self.log_optimal_transport = log_optimal_transport + self.bin_score = nn.Parameter( + torch.tensor(config['skh_init_bin_score'], requires_grad=True)) + self.skh_iters = config['skh_iters'] + self.skh_prefilter = config['skh_prefilter'] + else: + raise NotImplementedError() + + def forward(self, feat_c0, feat_c1, flow_list, data, mask_c0=None, mask_c1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + offset: [layer, B, H, W, 4] (*2) + data (dict) + mask_c0 (torch.Tensor): [N, L] (optional) + mask_c1 (torch.Tensor): [N, S] (optional) + Update: + data (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + NOTE: M' != M during training. + """ + N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) + # normalize + feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, + [feat_c0, feat_c1]) + + if self.match_type == 'dual_softmax': + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, + feat_c1) * self.temperature + if mask_c0 is not None: + sim_matrix.masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) + + elif self.match_type == 'sinkhorn': + # sinkhorn, dustbin included + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) + if mask_c0 is not None: + sim_matrix[:, :L, :S].masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + + # build uniform prior & use sinkhorn + log_assign_matrix = self.log_optimal_transport( + sim_matrix, self.bin_score, self.skh_iters) + assign_matrix = log_assign_matrix.exp() + conf_matrix = assign_matrix[:, :-1, :-1] + + # filter prediction with dustbin score (only in evaluation mode) + if not self.training and self.skh_prefilter: + filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L] + filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S] + conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0 + conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0 + + if self.config['sparse_spvs']: + data.update({'conf_matrix_with_bin': assign_matrix.clone()}) + + data.update({'conf_matrix': conf_matrix}) + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match(conf_matrix, data)) + + #update predicted offset + if flow_list[0].shape[2]==flow_list[1].shape[2] and flow_list[0].shape[3]==flow_list[1].shape[3]: + flow_list=torch.stack(flow_list,dim=0) + data.update({'predict_flow':flow_list}) #[2*L*B*H*W*4] + self.get_offset_match(flow_list,data,mask_c0,mask_c1) + + @torch.no_grad() + def get_coarse_match(self, conf_matrix, data): + """ + Args: + conf_matrix (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 
'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + _device = conf_matrix.device + # 1. confidence thresholding + mask = conf_matrix > self.thr + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + **axes_lengths) + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # 2. mutual nearest + mask = mask \ + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + + # 3. find all valid coarse matches + # this only works when at most one `True` in each row + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + mconf = conf_matrix[b_ids, i_ids, j_ids] + + # 4. Random sampling of training samples for fine-level LoFTR + # (optional) pad samples with gt coarse-level matches + if self.training: + # NOTE: + # The sampling is performed across all pairs in a batch without manually balancing + # #samples for fine-level increases w.r.t. batch_size + if 'mask0' not in data: + num_candidates_max = mask.size(0) * max( + mask.size(1), mask.size(2)) + else: + num_candidates_max = compute_max_candidates( + data['mask0'], data['mask1']) + num_matches_train = int(num_candidates_max * + self.train_coarse_percent) + num_matches_pred = len(b_ids) + assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + + # pred_indices is to select from prediction + if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: + pred_indices = torch.arange(num_matches_pred, device=_device) + else: + pred_indices = torch.randint( + num_matches_pred, + (num_matches_train - self.train_pad_num_gt_min, ), + device=_device) + + # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) + gt_pad_indices = torch.randint( + len(data['spv_b_ids']), + (max(num_matches_train - num_matches_pred, + self.train_pad_num_gt_min), ), + device=_device) + mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + + b_ids, i_ids, j_ids, mconf = map( + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], + dim=0), + *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], + [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + + # These matches select patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # 4. 
Update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], + dim=1) * scale1 + + # These matches is the current prediction (for visualization) + coarse_matches.update({ + 'gt_mask': mconf == 0, + 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c[mconf != 0], + 'mkpts1_c': mkpts1_c[mconf != 0], + 'mconf': mconf[mconf != 0] + }) + + return coarse_matches + + @torch.no_grad() + def get_offset_match(self, flow_list, data,mask1,mask2): + """ + Args: + offset (torch.Tensor): [L, B, H, W, 2] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + offset1=flow_list[0] + bs,layer_num=offset1.shape[1],offset1.shape[0] + + #left side + offset1=offset1.view(layer_num,bs,-1,4) + conf1=offset1[:,:,:,2:].mean(dim=-1) + if mask1 is not None: + conf1.masked_fill_(~mask1.bool()[None].expand(layer_num,-1,-1),100) + offset1=offset1[:,:,:,:2] + self.get_offset_match_work(offset1,conf1,data,'left') + + #rihgt side + if len(flow_list)==2: + offset2=flow_list[1].view(layer_num,bs,-1,4) + conf2=offset2[:,:,:,2:].mean(dim=-1) + if mask2 is not None: + conf2.masked_fill_(~mask2.bool()[None].expand(layer_num,-1,-1),100) + offset2=offset2[:,:,:,:2] + self.get_offset_match_work(offset2,conf2,data,'right') + + + @torch.no_grad() + def get_offset_match_work(self, offset,conf, data,side): + bs,layer_num=offset.shape[1],offset.shape[0] + # 1. confidence thresholding + mask_conf= conf<2 + for index in range(bs): + mask_conf[:,index,0]=True #safe guard in case that no match survives + # 3. 
find offset matches + scale = data['hw0_i'][0] / data['hw0_c'][0] + l_ids,b_ids,i_ids = torch.where(mask_conf) + j_coor=offset[l_ids,b_ids,i_ids,:2] *scale#[N,2] + i_coor=torch.stack([i_ids%data['hw0_c'][1],i_ids//data['hw0_c'][1]],dim=1)*scale + #i_coor=torch.as_tensor([[index%data['hw0_c'][1],index//data['hw0_c'][1]] for index in i_ids]).to(device).float()*scale #[N,2] + # These matches is the current prediction (for visualization) + data.update({ + 'offset_bids_'+side: b_ids, # mconf == 0 => gt matches + 'offset_lids_'+side: l_ids, + 'conf'+side: conf[mask_conf] + }) + + if side=='right': + data.update({'offset_kpts0_f_'+side: j_coor.detach(), + 'offset_kpts1_f_'+side: i_coor}) + else: + data.update({'offset_kpts0_f_'+side: i_coor, + 'offset_kpts1_f_'+side: j_coor.detach()}) + + diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py new file mode 100644 index 0000000000000000000000000000000000000000..fdc57e84936c805cb387b6239ca4a5ff6154e22e --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py @@ -0,0 +1,50 @@ +from yacs.config import CfgNode as CN + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +_CN = CN() +_CN.BACKBONE_TYPE = 'ResNetFPN' +_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] +_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.FINE_CONCAT_COARSE_FEAT = True + +# 1. LoFTR-backbone (local feature CNN) config +_CN.RESNETFPN = CN() +_CN.RESNETFPN.INITIAL_DIM = 128 +_CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 + +# 2. LoFTR-coarse module config +_CN.COARSE = CN() +_CN.COARSE.D_MODEL = 256 +_CN.COARSE.D_FFN = 256 +_CN.COARSE.NHEAD = 8 +_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 +_CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] +_CN.COARSE.TEMP_BUG_FIX = False + +# 3. Coarse-Matching config +_CN.MATCH_COARSE = CN() +_CN.MATCH_COARSE.THR = 0.1 +_CN.MATCH_COARSE.BORDER_RM = 2 +_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] +_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.MATCH_COARSE.SKH_ITERS = 3 +_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 +_CN.MATCH_COARSE.SKH_PREFILTER = True +_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory +_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock + +# 4. 
LoFTR-fine module config +_CN.FINE = CN() +_CN.FINE.D_MODEL = 128 +_CN.FINE.D_FFN = 128 +_CN.FINE.NHEAD = 8 +_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 +_CN.FINE.ATTENTION = 'linear' + +default_cfg = lower_config(_CN) diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..6e77aded52e1eb5c01e22c2738104f3b09d6922a --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py @@ -0,0 +1,74 @@ +import math +import torch +import torch.nn as nn + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + + +class FineMatching(nn.Module): + """FineMatching with s2d paradigm""" + + def __init__(self): + super().__init__() + + def forward(self, feat_f0, feat_f1, data): + """ + Args: + feat0 (torch.Tensor): [M, WW, C] + feat1 (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + M, WW, C = feat_f0.shape + W = int(math.sqrt(WW)) + scale = data['hw0_i'][0] / data['hw0_f'][0] + self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale + + # corner case: if no coarse matches found + if M == 0: + assert self.training == False, "M is always >0, when training, see coarse_matching.py" + # logger.warning('No matches found in coarse-level.') + data.update({ + 'expec_f': torch.empty(0, 3, device=feat_f0.device), + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] + sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) + softmax_temp = 1. / C**.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + + # compute std over + var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability + + # for fine-level supervision + data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) + + # compute absolute kpt coords + self.get_fine_match(coords_normalized, data) + + @torch.no_grad() + def get_fine_match(self, coords_normed, data): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + + # mkpts0_f and mkpts1_f + mkpts0_f = data['mkpts0_c'] + scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale + mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] + + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f + }) diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py @@ -0,0 +1,54 @@ +import torch + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """ Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). 
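The `FineMatching` module above refines each coarse match with a soft-argmax over a W x W correlation heatmap. Below is a self-contained sketch of that step in plain PyTorch; kornia's `spatial_expectation2d` and `create_meshgrid` are replaced by an explicit normalized grid, so treat the numbers as illustrative only.

```python
import torch

M, W, C = 4, 5, 128          # matched windows, window size, descriptor dim
WW = W * W
feat_f0 = torch.randn(M, WW, C)
feat_f1 = torch.randn(M, WW, C)

# correlate the centre descriptor of window 0 with every cell of window 1
feat_f0_center = feat_f0[:, WW // 2, :]                      # [M, C]
sim = torch.einsum("mc,mrc->mr", feat_f0_center, feat_f1)    # [M, WW]
heatmap = torch.softmax(sim / C**0.5, dim=1)                 # temperature 1/sqrt(C)

# normalized grid in [-1, 1], matching kornia.create_meshgrid(W, W, True)
xs = torch.linspace(-1.0, 1.0, W)
grid_y, grid_x = torch.meshgrid(xs, xs, indexing="ij")
grid = torch.stack([grid_x, grid_y], dim=-1).view(1, WW, 2)  # [1, WW, 2]

coords = (heatmap.unsqueeze(-1) * grid).sum(dim=1)           # [M, 2], soft-argmax in [-1, 1]
var = (grid**2 * heatmap.unsqueeze(-1)).sum(dim=1) - coords**2
std = torch.sqrt(var.clamp(min=1e-10)).sum(-1)               # [M], per-match uncertainty
print(coords.shape, std.shape)
```

The refined offset `coords` is later rescaled by `W // 2 * scale` and added to the coarse keypoint, exactly as `get_fine_match` does above.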
+ + Args: + kpts0 (torch.Tensor): [N, L, 2] - , + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + kpts0_long = kpts0.round().long() + + # Sample depth, get calculable_mask on depth != 0 + kpts0_depth = torch.stack( + [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + ) # (N, L) + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ + (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + w_kpts0_long = w_kpts0.long() + w_kpts0_long[~covisible_mask, :] = 0 + + w_kpts0_depth = torch.stack( + [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + ) # (N, L) + consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..07d384ae18370acb99ef00a788f628c967249ace --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py @@ -0,0 +1,61 @@ +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256),pre_scaling=None): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), + the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact + on the final performance. For now, we keep both impls for backward compatability. + We will remove the buggy impl after re-training all variants of our released models. 
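`warp_kpts` above is vectorised over batches and keypoints; the per-point geometry it implements is simply unproject, rigid transform, project. A tiny numerical illustration with made-up intrinsics and pose:

```python
import torch

K0 = torch.tensor([[500.0, 0.0, 320.0],
                   [0.0, 500.0, 240.0],
                   [0.0, 0.0, 1.0]])
K1 = K0.clone()
T_0to1 = torch.eye(4)
T_0to1[0, 3] = 0.1                       # 10 cm baseline along x

kpt0 = torch.tensor([320.0, 240.0])      # pixel in image 0
depth0 = 2.0                             # metres

# unproject into the camera-0 frame
pt_cam0 = K0.inverse() @ (torch.cat([kpt0, torch.ones(1)]) * depth0)

# rigid transform into the camera-1 frame, then project
pt_cam1 = T_0to1[:3, :3] @ pt_cam0 + T_0to1[:3, 3]
w_kpt0 = (K1 @ pt_cam1)[:2] / (pt_cam1[2] + 1e-4)
print(w_kpt0)  # where the pixel lands in image 1

# the depth-consistency check would compare pt_cam1[2] with the depth sampled
# from depth1 at w_kpt0 (relative error < 0.2)
```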
+ """ + super().__init__() + self.d_model=d_model + self.max_shape=max_shape + self.pre_scaling=pre_scaling + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + + if pre_scaling[0] is not None and pre_scaling[1] is not None: + train_res,test_res=pre_scaling[0],pre_scaling[1] + x_position,y_position=x_position*train_res[1]/test_res[1],y_position*train_res[0]/test_res[0] + + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x,scaling=None): + """ + Args: + x: [N, C, H, W] + """ + if scaling is None: #onliner scaling overwrites pre_scaling + return x + self.pe[:, :, :x.size(2), :x.size(3)],self.pe[:, :, :x.size(2), :x.size(3)] + else: + pe = torch.zeros((self.d_model, *self.max_shape)) + y_position = torch.ones(self.max_shape).cumsum(0).float().unsqueeze(0)*scaling[0] + x_position = torch.ones(self.max_shape).cumsum(1).float().unsqueeze(0)*scaling[1] + + div_term = torch.exp(torch.arange(0, self.d_model//2, 2).float() * (-math.log(10000.0) / (self.d_model//2))) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + pe=pe.unsqueeze(0).to(x.device) + return x + pe[:, :, :x.size(2), :x.size(3)],pe[:, :, :x.size(2), :x.size(3)] \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py new file mode 100644 index 0000000000000000000000000000000000000000..5cef3a7968413136f6dc9f52b6a1ec87192b006b --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py @@ -0,0 +1,151 @@ +from math import log +from loguru import logger + +import torch +from einops import repeat +from kornia.utils import create_meshgrid + +from .geometry import warp_kpts + +############## ↓ Coarse-Level supervision ↓ ############## + + +@torch.no_grad() +def mask_pts_at_padded_regions(grid_pt, mask): + """For megadepth dataset, zero-padding exists in images""" + mask = repeat(mask, 'n h w -> n (h w) c', c=2) + grid_pt[~mask.bool()] = 0 + return grid_pt + + +@torch.no_grad() +def spvs_coarse(data, config): + """ + Update: + data (dict): { + "conf_matrix_gt": [N, hw0, hw1], + 'spv_b_ids': [M] + 'spv_i_ids': [M] + 'spv_j_ids': [M] + 'spv_w_pt0_i': [N, hw0, 2], in original image resolution + 'spv_pt1_i': [N, hw1, 2], in original image resolution + } + + NOTE: + - for scannet dataset, there're 3 kinds of resolution {i, c, f} + - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} + """ + # 1. 
misc + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + scale = config['ASPAN']['RESOLUTION'][0] + scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale + scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + + # 2. warp grids + # create kpts in meshgrid and resize them to image resolution + grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_i = scale0 * grid_pt0_c + grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_i = scale1 * grid_pt1_c + + # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt + if 'mask0' in data: + grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) + grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) + + # warp kpts bi-directionally and resize them to coarse-level resolution + # (no depth consistency check, since it leads to worse results experimentally) + # (unhandled edge case: points with 0-depth will be warped to the left-up corner) + _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) + _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) + w_pt0_c = w_pt0_i / scale1 + w_pt1_c = w_pt1_i / scale0 + + # 3. check if mutual nearest neighbor + w_pt0_c_round = w_pt0_c[:, :, :].round().long() + nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 + w_pt1_c_round = w_pt1_c[:, :, :].round().long() + nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0 + + # corner case: out of boundary + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 + nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 + + loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) + correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1) + correct_0to1[:, 0] = False # ignore the top-left corner + + # 4. construct a gt conf_matrix + conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) + b_ids, i_ids = torch.where(correct_0to1 != 0) + j_ids = nearest_index1[b_ids, i_ids] + + conf_matrix_gt[b_ids, i_ids, j_ids] = 1 + data.update({'conf_matrix_gt': conf_matrix_gt}) + + # 5. save coarse matches(gt) for training fine level + if len(b_ids) == 0: + logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}") + # this won't affect fine-level loss calculation + b_ids = torch.tensor([0], device=device) + i_ids = torch.tensor([0], device=device) + j_ids = torch.tensor([0], device=device) + + data.update({ + 'spv_b_ids': b_ids, + 'spv_i_ids': i_ids, + 'spv_j_ids': j_ids + }) + + # 6. save intermediate results (for fast fine-level computation) + data.update({ + 'spv_w_pt0_i': w_pt0_i, + 'spv_pt1_i': grid_pt1_i + }) + + +def compute_supervision_coarse(data, config): + assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" 
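The mutual-nearest-neighbour test in `spvs_coarse` above is the core of the ground-truth construction: coarse cell i of image 0 and cell j of image 1 form a positive only if warping i lands in j and warping j lands back in i. A toy, self-contained version of that loop-back check:

```python
import torch

N, h0, w0, h1, w1 = 1, 4, 4, 4, 4                    # tiny coarse grids
# nearest_index1[b, i]: the image-1 cell that cell i of image 0 warps into
nearest_index1 = torch.randint(0, h1 * w1, (N, h0 * w0))
nearest_index0 = torch.randint(0, h0 * w0, (N, h1 * w1))

loop_back = torch.gather(nearest_index0, 1, nearest_index1)   # warp 0 -> 1 -> 0
correct_0to1 = loop_back == torch.arange(h0 * w0)[None]
correct_0to1[:, 0] = False                                    # cell 0 collects all invalid warps

conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1)
b_ids, i_ids = torch.where(correct_0to1)
j_ids = nearest_index1[b_ids, i_ids]
conf_matrix_gt[b_ids, i_ids, j_ids] = 1
print(int(conf_matrix_gt.sum()), "mutual matches in this random toy example")
```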
+ data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_coarse(data, config) + else: + raise ValueError(f'Unknown data source: {data_source}') + + +############## ↓ Fine-Level supervision ↓ ############## + +@torch.no_grad() +def spvs_fine(data, config): + """ + Update: + data (dict):{ + "expec_f_gt": [M, 2]} + """ + # 1. misc + # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i') + w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i'] + scale = config['ASPAN']['RESOLUTION'][1] + radius = config['ASPAN']['FINE_WINDOW_SIZE'] // 2 + + # 2. get coarse prediction + b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + + # 3. compute gt + scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale + # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later + expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2] + data.update({"expec_f_gt": expec_f_gt}) + + +def compute_supervision_fine(data, config): + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_fine(data, config) + else: + raise NotImplementedError diff --git a/imcui/third_party/ASpanFormer/src/__init__.py b/imcui/third_party/ASpanFormer/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/ASpanFormer/src/config/default.py b/imcui/third_party/ASpanFormer/src/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..40abd51c3f28ea6dee3c4e9fcee6efac5c080a2f --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/config/default.py @@ -0,0 +1,180 @@ +from yacs.config import CfgNode as CN +_CN = CN() + +############## ↓ ASPAN Pipeline ↓ ############## +_CN.ASPAN = CN() +_CN.ASPAN.BACKBONE_TYPE = 'ResNetFPN' +_CN.ASPAN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] +_CN.ASPAN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.ASPAN.FINE_CONCAT_COARSE_FEAT = True + +# 1. ASPAN-backbone (local feature CNN) config +_CN.ASPAN.RESNETFPN = CN() +_CN.ASPAN.RESNETFPN.INITIAL_DIM = 128 +_CN.ASPAN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 + +# 2. ASPAN-coarse module config +_CN.ASPAN.COARSE = CN() +_CN.ASPAN.COARSE.D_MODEL = 256 +_CN.ASPAN.COARSE.D_FFN = 256 +_CN.ASPAN.COARSE.D_FLOW= 128 +_CN.ASPAN.COARSE.NHEAD = 8 +_CN.ASPAN.COARSE.NLEVEL= 3 +_CN.ASPAN.COARSE.INI_LAYER_NUM = 2 +_CN.ASPAN.COARSE.LAYER_NUM = 4 +_CN.ASPAN.COARSE.NSAMPLE = [2,8] +_CN.ASPAN.COARSE.RADIUS_SCALE= 5 +_CN.ASPAN.COARSE.COARSEST_LEVEL= [26,26] +_CN.ASPAN.COARSE.TRAIN_RES = None +_CN.ASPAN.COARSE.TEST_RES = None + +# 3. Coarse-Matching config +_CN.ASPAN.MATCH_COARSE = CN() +_CN.ASPAN.MATCH_COARSE.THR = 0.2 +_CN.ASPAN.MATCH_COARSE.BORDER_RM = 2 +_CN.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] +_CN.ASPAN.MATCH_COARSE.SKH_ITERS = 3 +_CN.ASPAN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 +_CN.ASPAN.MATCH_COARSE.SKH_PREFILTER = False +_CN.ASPAN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory +_CN.ASPAN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock +_CN.ASPAN.MATCH_COARSE.SPARSE_SPVS = True +_CN.ASPAN.MATCH_COARSE.LEARNABLE_DS_TEMP = True + +# 4. 
ASPAN-fine module config +_CN.ASPAN.FINE = CN() +_CN.ASPAN.FINE.D_MODEL = 128 +_CN.ASPAN.FINE.D_FFN = 128 +_CN.ASPAN.FINE.NHEAD = 8 +_CN.ASPAN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 +_CN.ASPAN.FINE.ATTENTION = 'linear' + +# 5. ASPAN Losses +# -- # coarse-level +_CN.ASPAN.LOSS = CN() +_CN.ASPAN.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy'] +_CN.ASPAN.LOSS.COARSE_WEIGHT = 1.0 +# _CN.ASPAN.LOSS.SPARSE_SPVS = False +# -- - -- # focal loss (coarse) +_CN.ASPAN.LOSS.FOCAL_ALPHA = 0.25 +_CN.ASPAN.LOSS.FOCAL_GAMMA = 2.0 +_CN.ASPAN.LOSS.POS_WEIGHT = 1.0 +_CN.ASPAN.LOSS.NEG_WEIGHT = 1.0 +# _CN.ASPAN.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not. +# use `_CN.ASPAN.MATCH_COARSE.MATCH_TYPE` + +# -- # fine-level +_CN.ASPAN.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2'] +_CN.ASPAN.LOSS.FINE_WEIGHT = 1.0 +_CN.ASPAN.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window) + +# -- # flow-sloss +_CN.ASPAN.LOSS.FLOW_WEIGHT = 0.1 + + +############## Dataset ############## +_CN.DATASET = CN() +# 1. data config +# training and validating +_CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth'] +_CN.DATASET.TRAIN_DATA_ROOT = None +_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TRAIN_NPZ_ROOT = None +_CN.DATASET.TRAIN_LIST_PATH = None +_CN.DATASET.TRAIN_INTRINSIC_PATH = None +_CN.DATASET.VAL_DATA_ROOT = None +_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.VAL_NPZ_ROOT = None +_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file +_CN.DATASET.VAL_INTRINSIC_PATH = None +# testing +_CN.DATASET.TEST_DATA_SOURCE = None +_CN.DATASET.TEST_DATA_ROOT = None +_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TEST_NPZ_ROOT = None +_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file +_CN.DATASET.TEST_INTRINSIC_PATH = None + +# 2. dataset config +# general options +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score +_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 +_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile'] + +# MegaDepth options +_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. +_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE +_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 +_CN.DATASET.MGDPT_DF = 8 + +############## Trainer ############## +_CN.TRAINER = CN() +_CN.TRAINER.WORLD_SIZE = 1 +_CN.TRAINER.CANONICAL_BS = 64 +_CN.TRAINER.CANONICAL_LR = 6e-3 +_CN.TRAINER.SCALING = None # this will be calculated automatically +_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning + +# optimizer +_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] +_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime +_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam +_CN.TRAINER.ADAMW_DECAY = 0.1 + +# step-based warm-up +_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] +_CN.TRAINER.WARMUP_RATIO = 0. 
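The trainer settings above scale `CANONICAL_LR` by the ratio of the effective batch size to `CANONICAL_BS`, then ramp the learning rate up over `WARMUP_STEP` steps (defined just below). The exact wiring lives in the training scripts and optimizer builders, which are not part of this diff, so the following is only a sketch of that convention:

```python
def scaled_lr(canonical_lr=6e-3, canonical_bs=64, batch_size=2, world_size=8):
    # TRUE_LR = CANONICAL_LR * (effective batch size / CANONICAL_BS)
    return canonical_lr * (batch_size * world_size) / canonical_bs

def linear_warmup(base_lr, step, warmup_step=4800, warmup_ratio=0.0):
    # ramp from warmup_ratio * base_lr to base_lr over the first warmup_step steps
    if step >= warmup_step:
        return base_lr
    alpha = step / warmup_step
    return base_lr * (warmup_ratio + (1.0 - warmup_ratio) * alpha)

base = scaled_lr()                              # 6e-3 * 16 / 64 = 1.5e-3
print(base, linear_warmup(base, step=1200))     # a quarter of the way through warm-up
```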
+_CN.TRAINER.WARMUP_STEP = 4800 + +# learning rate scheduler +_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR] +_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] +_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR +_CN.TRAINER.MSLR_GAMMA = 0.5 +_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing +_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval + +# plotting related +_CN.TRAINER.ENABLE_PLOTTING = True +_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test paris for plotting +_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence'] +_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic' + +# geometric metrics and pose solver +_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) +_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] +_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] +_CN.TRAINER.RANSAC_PIXEL_THR = 0.5 +_CN.TRAINER.RANSAC_CONF = 0.99999 +_CN.TRAINER.RANSAC_MAX_ITERS = 10000 +_CN.TRAINER.USE_MAGSACPP = False + +# data sampler for train_dataloader +_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] +# 'scene_balance' config +_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200 +_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not +_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not +_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data +# 'random' config +_CN.TRAINER.RDM_REPLACEMENT = True +_CN.TRAINER.RDM_NUM_SAMPLES = None + +# gradient clipping +_CN.TRAINER.GRADIENT_CLIPPING = 0.5 + +# reproducibility +# This seed affects the data sampling. With the same seed, the data sampling is promised +# to be the same. When resume training from a checkpoint, it's better to use a different +# seed, otherwise the sampled data will be exactly the same as before resuming, which will +# cause less unique data items sampled during the entire training. +# Use of different seed values might affect the final training result, since not all data items +# are used during training on ScanNet. (60M pairs of images sampled during traing from 230M pairs in total.) 
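The pose-solver settings above (`POSE_GEO_MODEL = 'E'`, the RANSAC pixel threshold and confidence) are consumed by the evaluation metrics. The sketch below is the standard OpenCV essential-matrix pipeline with synthetic correspondences so it runs stand-alone; it illustrates the recipe, not this repository's exact metrics code:

```python
import cv2
import numpy as np

rng = np.random.default_rng(0)
K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])

# synthesize a two-view scene: random 3D points observed under a known relative pose
pts3d = rng.uniform([-1.0, -1.0, 4.0], [1.0, 1.0, 8.0], size=(200, 3))
R_gt, _ = cv2.Rodrigues(np.array([0.0, 0.1, 0.0]))
t_gt = np.array([[0.2], [0.0], [0.0]])

proj0 = (K @ pts3d.T).T
mkpts0 = proj0[:, :2] / proj0[:, 2:]
proj1 = (K @ (R_gt @ pts3d.T + t_gt)).T
mkpts1 = proj1[:, :2] / proj1[:, 2:]

# RANSAC essential-matrix estimation with the thresholds from the config above
E, inliers = cv2.findEssentialMat(mkpts0, mkpts1, K,
                                  method=cv2.RANSAC, prob=0.99999, threshold=0.5)
_, R, t, _ = cv2.recoverPose(E, mkpts0, mkpts1, K, mask=inliers)
print(np.allclose(R, R_gt, atol=1e-3), t.ravel())   # translation is recovered up to scale
```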
+_CN.TRAINER.SEED = 66 + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _CN.clone() diff --git a/imcui/third_party/ASpanFormer/src/datasets/__init__.py b/imcui/third_party/ASpanFormer/src/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1860e3ae060a26e4625925861cecdc355f2b08b7 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/datasets/__init__.py @@ -0,0 +1,3 @@ +from .scannet import ScanNetDataset +from .megadepth import MegaDepthDataset + diff --git a/imcui/third_party/ASpanFormer/src/datasets/megadepth.py b/imcui/third_party/ASpanFormer/src/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..a70ac715a3f807e37bc5b87ae9446ddd2aa4fc86 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/datasets/megadepth.py @@ -0,0 +1,127 @@ +import os.path as osp +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset +from loguru import logger + +from src.utils.dataset import read_megadepth_gray, read_megadepth_depth + + +class MegaDepthDataset(Dataset): + def __init__(self, + root_dir, + npz_path, + mode='train', + min_overlap_score=0.4, + img_resize=None, + df=None, + img_padding=False, + depth_padding=False, + augment_fn=None, + **kwargs): + """ + Manage one scene(npz_path) of MegaDepth dataset. + + Args: + root_dir (str): megadepth root directory that has `phoenix`. + npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. + mode (str): options are ['train', 'val', 'test'] + min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing. + img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. + This is useful during training with batches and testing with memory intensive algorithms. + df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. + img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. + depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training. + augment_fn (callable, optional): augments images with pre-defined visual effects. + """ + super().__init__() + self.root_dir = root_dir + self.mode = mode + self.scene_id = npz_path.split('.')[0] + + # prepare scene_info and pair_info + if mode == 'test' and min_overlap_score != 0: + logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.") + min_overlap_score = 0 + self.scene_info = np.load(npz_path, allow_pickle=True) + self.pair_infos = self.scene_info['pair_infos'].copy() + del self.scene_info['pair_infos'] + self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] + + # parameters for image resizing, padding and depthmap padding + if mode == 'train': + assert img_resize is not None and img_padding and depth_padding + self.img_resize = img_resize + self.df = df + self.img_padding = img_padding + self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. 
+ + # for training LoFTR + self.augment_fn = augment_fn if mode == 'train' else None + self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) + + def __len__(self): + return len(self.pair_infos) + + def __getitem__(self, idx): + (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx] + + # read grayscale image and mask. (1, h, w) and (h, w) + img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) + img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) + + # TODO: Support augmentation & handle seeds for each worker correctly. + image0, mask0, scale0 = read_megadepth_gray( + img_name0, self.img_resize, self.df, self.img_padding, None) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + image1, mask1, scale1 = read_megadepth_gray( + img_name1, self.img_resize, self.df, self.img_padding, None) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + + # read depth. shape: (h, w) + if self.mode in ['train', 'val']: + depth0 = read_megadepth_depth( + osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) + depth1 = read_megadepth_depth( + osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) + else: + depth0 = depth1 = torch.tensor([]) + + # read intrinsics of original size + K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) + K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T0 = self.scene_info['poses'][idx0] + T1 = self.scene_info['poses'][idx1] + T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) + T_1to0 = T_0to1.inverse() + + data = { + 'image0': image0, # (1, h, w) + 'depth0': depth0, # (h, w) + 'image1': image1, + 'depth1': depth1, + 'T_0to1': T_0to1, # (4, 4) + 'T_1to0': T_1to0, + 'K0': K_0, # (3, 3) + 'K1': K_1, + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'MegaDepth', + 'scene_id': self.scene_id, + 'pair_id': idx, + 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), + } + + # for LoFTR training + if mask0 is not None: # img_padding is True + if self.coarse_scale: + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.coarse_scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/ASpanFormer/src/datasets/sampler.py b/imcui/third_party/ASpanFormer/src/datasets/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..81b6f435645632a013476f9a665a0861ab7fcb61 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/datasets/sampler.py @@ -0,0 +1,77 @@ +import torch +from torch.utils.data import Sampler, ConcatDataset + + +class RandomConcatSampler(Sampler): + """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset + in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. + However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase. + + For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not. 
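`__getitem__` above downsamples the zero-padding masks to the coarse (1/8) resolution so the matcher can ignore padded cells. The same operation in isolation, with concrete shapes:

```python
import torch
import torch.nn.functional as F

mask0 = torch.zeros(480, 640, dtype=torch.bool)
mask0[:, :512] = True                         # valid region before zero-padding
mask1 = mask0.clone()

coarse_scale = 0.125                          # 1 / ASPAN.RESOLUTION[0]
ts_mask_0, ts_mask_1 = F.interpolate(
    torch.stack([mask0, mask1], dim=0)[None].float(),
    scale_factor=coarse_scale,
    mode="nearest",
    recompute_scale_factor=False,
)[0].bool()

print(ts_mask_0.shape, int(ts_mask_0.sum()))  # torch.Size([60, 80]), 60 * 64 valid coarse cells
```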
+ Args: + shuffle (bool): shuffle the random sampled indices across all sub-datsets. + repeat (int): repeatedly use the sampled indices multiple times for training. + [arXiv:1902.05509, arXiv:1901.09335] + NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples) + NOTE: This sampler behaves differently with DistributedSampler. + It assume the dataset is splitted across ranks instead of replicated. + TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs. + ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 + """ + def __init__(self, + data_source: ConcatDataset, + n_samples_per_subset: int, + subset_replacement: bool=True, + shuffle: bool=True, + repeat: int=1, + seed: int=None): + if not isinstance(data_source, ConcatDataset): + raise TypeError("data_source should be torch.utils.data.ConcatDataset") + + self.data_source = data_source + self.n_subset = len(self.data_source.datasets) + self.n_samples_per_subset = n_samples_per_subset + self.n_samples = self.n_subset * self.n_samples_per_subset * repeat + self.subset_replacement = subset_replacement + self.repeat = repeat + self.shuffle = shuffle + self.generator = torch.manual_seed(seed) + assert self.repeat >= 1 + + def __len__(self): + return self.n_samples + + def __iter__(self): + indices = [] + # sample from each sub-dataset + for d_idx in range(self.n_subset): + low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1] + high = self.data_source.cumulative_sizes[d_idx] + if self.subset_replacement: + rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ), + generator=self.generator, dtype=torch.int64) + else: # sample without replacement + len_subset = len(self.data_source.datasets[d_idx]) + rand_tensor = torch.randperm(len_subset, generator=self.generator) + low + if len_subset >= self.n_samples_per_subset: + rand_tensor = rand_tensor[:self.n_samples_per_subset] + else: # padding with replacement + rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ), + generator=self.generator, dtype=torch.int64) + rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) + indices.append(rand_tensor) + indices = torch.cat(indices) + if self.shuffle: # shuffle the sampled dataset (from multiple subsets) + rand_tensor = torch.randperm(len(indices), generator=self.generator) + indices = indices[rand_tensor] + + # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling) + if self.repeat > 1: + repeat_indices = [indices.clone() for _ in range(self.repeat - 1)] + if self.shuffle: + _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)] + repeat_indices = map(_choice, repeat_indices) + indices = torch.cat([indices, *repeat_indices], 0) + + assert indices.shape[0] == self.n_samples + return iter(indices.tolist()) diff --git a/imcui/third_party/ASpanFormer/src/datasets/scannet.py b/imcui/third_party/ASpanFormer/src/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..3520d34c0f08a784ddbf923846a7cb2a847b1787 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/datasets/scannet.py @@ -0,0 +1,113 @@ +from os import path as osp +from typing import Dict +from unicodedata import name + +import numpy as np +import torch +import torch.utils as utils +from numpy.linalg import inv +from src.utils.dataset import ( + read_scannet_gray, + read_scannet_depth, 
+ read_scannet_pose, + read_scannet_intrinsic +) + + +class ScanNetDataset(utils.data.Dataset): + def __init__(self, + root_dir, + npz_path, + intrinsic_path, + mode='train', + min_overlap_score=0.4, + augment_fn=None, + pose_dir=None, + **kwargs): + """Manage one scene of ScanNet Dataset. + Args: + root_dir (str): ScanNet root directory that contains scene folders. + npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. + intrinsic_path (str): path to depth-camera intrinsic file. + mode (str): options are ['train', 'val', 'test']. + augment_fn (callable, optional): augments images with pre-defined visual effects. + pose_dir (str): ScanNet root directory that contains all poses. + (we use a separate (optional) pose_dir since we store images and poses separately.) + """ + super().__init__() + self.root_dir = root_dir + self.pose_dir = pose_dir if pose_dir is not None else root_dir + self.mode = mode + + # prepare data_names, intrinsics and extrinsics(T) + with np.load(npz_path) as data: + self.data_names = data['name'] + if 'score' in data.keys() and mode not in ['val' or 'test']: + kept_mask = data['score'] > min_overlap_score + self.data_names = self.data_names[kept_mask] + self.intrinsics = dict(np.load(intrinsic_path)) + + # for training LoFTR + self.augment_fn = augment_fn if mode == 'train' else None + + def __len__(self): + return len(self.data_names) + + def _read_abs_pose(self, scene_name, name): + pth = osp.join(self.pose_dir, + scene_name, + 'pose', f'{name}.txt') + return read_scannet_pose(pth) + + def _compute_rel_pose(self, scene_name, name0, name1): + pose0 = self._read_abs_pose(scene_name, name0) + pose1 = self._read_abs_pose(scene_name, name1) + + return np.matmul(pose1, inv(pose0)) # (4, 4) + + def __getitem__(self, idx): + data_name = self.data_names[idx] + scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name + scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + + # read the grayscale image which will be resized to (1, 480, 640) + img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg') + img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg') + # TODO: Support augmentation & handle seeds for each worker correctly. 
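A usage sketch for the `RandomConcatSampler` defined a little earlier in this diff, with two toy datasets standing in for per-scene datasets (the import path is assumed from the repository layout):

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from src.datasets.sampler import RandomConcatSampler  # path assumed from this diff

scene_a = TensorDataset(torch.arange(10))
scene_b = TensorDataset(torch.arange(100, 140))
concat = ConcatDataset([scene_a, scene_b])

sampler = RandomConcatSampler(concat,
                              n_samples_per_subset=4,
                              subset_replacement=True,
                              shuffle=True,
                              repeat=1,
                              seed=66)

loader = DataLoader(concat, sampler=sampler, batch_size=4)
for (batch,) in loader:
    print(batch)          # 8 samples per epoch in total: 4 drawn from each scene
```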
+ image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + + # read the depthmap which is stored as (480, 640) + if self.mode in ['train', 'val']: + depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png')) + depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png')) + else: + depth0 = depth1 = torch.tensor([]) + + # read the intrinsic of depthmap + K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), + dtype=torch.float32) + T_1to0 = T_0to1.inverse() + + data = { + 'image0': image0, # (1, h, w) + 'depth0': depth0, # (h, w) + 'image1': image1, + 'depth1': depth1, + 'T_0to1': T_0to1, # (4, 4) + 'T_1to0': T_1to0, + 'K0': K_0, # (3, 3) + 'K1': K_1, + 'dataset_name': 'ScanNet', + 'scene_id': scene_name, + 'pair_id': idx, + 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'), + osp.join(scene_name, 'color', f'{stem_name_1}.jpg')) + } + + return data diff --git a/imcui/third_party/ASpanFormer/src/lightning/data.py b/imcui/third_party/ASpanFormer/src/lightning/data.py new file mode 100644 index 0000000000000000000000000000000000000000..73db514b8924d647814e6c5def919c23393d3ccf --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/lightning/data.py @@ -0,0 +1,326 @@ +import os +import math +from collections import abc +from loguru import logger +from torch.utils.data.dataset import Dataset +from tqdm import tqdm +from os import path as osp +from pathlib import Path +from joblib import Parallel, delayed + +import pytorch_lightning as pl +from torch import distributed as dist +from torch.utils.data import ( + Dataset, + DataLoader, + ConcatDataset, + DistributedSampler, + RandomSampler, + dataloader +) + +from src.utils.augment import build_augmentor +from src.utils.dataloader import get_local_split +from src.utils.misc import tqdm_joblib +from src.utils import comm +from src.datasets.megadepth import MegaDepthDataset +from src.datasets.scannet import ScanNetDataset +from src.datasets.sampler import RandomConcatSampler + + +class MultiSceneDataModule(pl.LightningDataModule): + """ + For distributed training, each training process is assgined + only a part of the training scenes to reduce memory overhead. + """ + def __init__(self, args, config): + super().__init__() + + # 1. 
data config + # Train and Val should from the same data source + self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE + self.test_data_source = config.DATASET.TEST_DATA_SOURCE + # training and validating + self.train_data_root = config.DATASET.TRAIN_DATA_ROOT + self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional) + self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT + self.train_list_path = config.DATASET.TRAIN_LIST_PATH + self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH + self.val_data_root = config.DATASET.VAL_DATA_ROOT + self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional) + self.val_npz_root = config.DATASET.VAL_NPZ_ROOT + self.val_list_path = config.DATASET.VAL_LIST_PATH + self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH + # testing + self.test_data_root = config.DATASET.TEST_DATA_ROOT + self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional) + self.test_npz_root = config.DATASET.TEST_NPZ_ROOT + self.test_list_path = config.DATASET.TEST_LIST_PATH + self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH + + # 2. dataset config + # general options + self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score + self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN + self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile'] + + # MegaDepth options + self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840 + self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True + self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True + self.mgdpt_df = config.DATASET.MGDPT_DF # 8 + self.coarse_scale = 1 / config.ASPAN.RESOLUTION[0] # 0.125. for training loftr. + + # 3.loader parameters + self.train_loader_params = { + 'batch_size': args.batch_size, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + self.val_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + self.test_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': args.num_workers, + 'pin_memory': True + } + + # 4. sampler + self.data_sampler = config.TRAINER.DATA_SAMPLER + self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET + self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT + self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE + self.repeat = config.TRAINER.SB_REPEAT + + # (optional) RandomSampler for debugging + + # misc configurations + self.parallel_load_data = getattr(args, 'parallel_load_data', False) + self.seed = config.TRAINER.SEED # 66 + + def setup(self, stage=None): + """ + Setup train / val / test dataset. This method will be called by PL automatically. + Args: + stage (str): 'fit' in training phase, and 'test' in testing phase. 
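`setup()` and `_setup_dataset()` below assign each training rank its own subset of the scene list via `get_local_split` (imported from `src.utils.dataloader`, not shown in this diff). A stand-alone illustration of that idea, not the repository's actual implementation:

```python
import random

def local_scene_split(scene_names, world_size, rank, seed=66):
    # every rank shuffles with the same seed, so the global order is identical,
    # then takes a strided slice -> disjoint, near-balanced subsets per rank
    rng = random.Random(seed)
    shuffled = list(scene_names)
    rng.shuffle(shuffled)
    return shuffled[rank::world_size]

scenes = [f"scene{i:04d}_00" for i in range(10)]
for rank in range(4):
    print(rank, local_scene_split(scenes, world_size=4, rank=rank))
```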
+ """ + + assert stage in ['fit', 'test'], "stage must be either fit or test" + + try: + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + logger.info(f"[rank:{self.rank}] world_size: {self.world_size}") + except AssertionError as ae: + self.world_size = 1 + self.rank = 0 + logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)") + + if stage == 'fit': + self.train_dataset = self._setup_dataset( + self.train_data_root, + self.train_npz_root, + self.train_list_path, + self.train_intrinsic_path, + mode='train', + min_overlap_score=self.min_overlap_score_train, + pose_dir=self.train_pose_root) + # setup multiple (optional) validation subsets + if isinstance(self.val_list_path, (list, tuple)): + self.val_dataset = [] + if not isinstance(self.val_npz_root, (list, tuple)): + self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))] + for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root): + self.val_dataset.append(self._setup_dataset( + self.val_data_root, + npz_root, + npz_list, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root)) + else: + self.val_dataset = self._setup_dataset( + self.val_data_root, + self.val_npz_root, + self.val_list_path, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root) + logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!') + else: # stage == 'test + self.test_dataset = self._setup_dataset( + self.test_data_root, + self.test_npz_root, + self.test_list_path, + self.test_intrinsic_path, + mode='test', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.test_pose_root) + logger.info(f'[rank:{self.rank}]: Test Dataset loaded!') + + def _setup_dataset(self, + data_root, + split_npz_root, + scene_list_path, + intri_path, + mode='train', + min_overlap_score=0., + pose_dir=None): + """ Setup train / val / test set""" + with open(scene_list_path, 'r') as f: + npz_names = [name.split()[0] for name in f.readlines()] + + if mode == 'train': + local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed) + else: + local_npz_names = npz_names + logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.') + + dataset_builder = self._build_concat_dataset_parallel \ + if self.parallel_load_data \ + else self._build_concat_dataset + return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path, + mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir) + + def _build_concat_dataset( + self, + data_root, + npz_names, + npz_dir, + intrinsic_path, + mode, + min_overlap_score=0., + pose_dir=None + ): + datasets = [] + augment_fn = self.augment_fn if mode == 'train' else None + data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source + if data_source=='GL3D' and mode=='val': + data_source='MegaDepth' + if str(data_source).lower() == 'megadepth': + npz_names = [f'{n}.npz' for n in npz_names] + if str(data_source).lower() == 'gl3d': + npz_names = [f'{n}.txt' for n in npz_names] + #npz_names=npz_names[:8] + for npz_name in tqdm(npz_names, + desc=f'[rank:{self.rank}] loading {mode} datasets', + disable=int(self.rank) != 0): + # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time. 
+ npz_path = osp.join(npz_dir, npz_name) + if data_source == 'ScanNet': + datasets.append( + ScanNetDataset(data_root, + npz_path, + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir)) + elif data_source == 'MegaDepth': + datasets.append( + MegaDepthDataset(data_root, + npz_path, + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale)) + else: + raise NotImplementedError() + return ConcatDataset(datasets) + + def _build_concat_dataset_parallel( + self, + data_root, + npz_names, + npz_dir, + intrinsic_path, + mode, + min_overlap_score=0., + pose_dir=None, + ): + augment_fn = self.augment_fn if mode == 'train' else None + data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source + if str(data_source).lower() == 'megadepth': + npz_names = [f'{n}.npz' for n in npz_names] + #npz_names=npz_names[:8] + with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets', + total=len(npz_names), disable=int(self.rank) != 0)): + if data_source == 'ScanNet': + datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( + delayed(lambda x: _build_dataset( + ScanNetDataset, + data_root, + osp.join(npz_dir, x), + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir))(name) + for name in npz_names) + elif data_source == 'MegaDepth': + # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers. + raise NotImplementedError() + datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( + delayed(lambda x: _build_dataset( + MegaDepthDataset, + data_root, + osp.join(npz_dir, x), + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale))(name) + for name in npz_names) + else: + raise ValueError(f'Unknown dataset: {data_source}') + return ConcatDataset(datasets) + + def train_dataloader(self): + """ Build training dataloader for ScanNet / MegaDepth. """ + assert self.data_sampler in ['scene_balance'] + logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).') + if self.data_sampler == 'scene_balance': + sampler = RandomConcatSampler(self.train_dataset, + self.n_samples_per_subset, + self.subset_replacement, + self.shuffle, self.repeat, self.seed) + else: + sampler = None + dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params) + return dataloader + + def val_dataloader(self): + """ Build validation dataloader for ScanNet / MegaDepth. 
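+        When several validation npz lists are configured, `setup` builds one dataset per
+        subset and this method returns a matching list of DataLoaders, each wrapped in a
+        non-shuffling DistributedSampler so ranks evaluate disjoint slices.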
""" + logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.') + if not isinstance(self.val_dataset, abc.Sequence): + sampler = DistributedSampler(self.val_dataset, shuffle=False) + return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params) + else: + dataloaders = [] + for dataset in self.val_dataset: + sampler = DistributedSampler(dataset, shuffle=False) + dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params)) + return dataloaders + + def test_dataloader(self, *args, **kwargs): + logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.') + sampler = DistributedSampler(self.test_dataset, shuffle=False) + return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params) + + +def _build_dataset(dataset: Dataset, *args, **kwargs): + return dataset(*args, **kwargs) diff --git a/imcui/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py b/imcui/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ee20cbec4628b73c08358ebf1e1906fb2c0ac13c --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py @@ -0,0 +1,276 @@ + +from collections import defaultdict +import pprint +from loguru import logger +from pathlib import Path + +import torch +import numpy as np +import pytorch_lightning as pl +from matplotlib import pyplot as plt + +from src.ASpanFormer.aspanformer import ASpanFormer +from src.ASpanFormer.utils.supervision import compute_supervision_coarse, compute_supervision_fine +from src.losses.aspan_loss import ASpanLoss +from src.optimizers import build_optimizer, build_scheduler +from src.utils.metrics import ( + compute_symmetrical_epipolar_errors,compute_symmetrical_epipolar_errors_offset_bidirectional, + compute_pose_errors, + aggregate_metrics +) +from src.utils.plotting import make_matching_figures,make_matching_figures_offset +from src.utils.comm import gather, all_gather +from src.utils.misc import lower_config, flattenList +from src.utils.profiler import PassThroughProfiler + + +class PL_ASpanFormer(pl.LightningModule): + def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None): + """ + TODO: + - use the new version of PL logging API. 
+ """ + super().__init__() + # Misc + self.config = config # full config + _config = lower_config(self.config) + self.loftr_cfg = lower_config(_config['aspan']) + self.profiler = profiler or PassThroughProfiler() + self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1) + + # Matcher: LoFTR + self.matcher = ASpanFormer(config=_config['aspan']) + self.loss = ASpanLoss(_config) + + # Pretrained weights + print(pretrained_ckpt) + if pretrained_ckpt: + print('load') + state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict'] + msg=self.matcher.load_state_dict(state_dict, strict=False) + print(msg) + logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") + + # Testing + self.dump_dir = dump_dir + + def configure_optimizers(self): + # FIXME: The scheduler did not work properly when `--resume_from_checkpoint` + optimizer = build_optimizer(self, self.config) + scheduler = build_scheduler(self.config, optimizer) + return [optimizer], [scheduler] + + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + # learning rate warm up + warmup_step = self.config.TRAINER.WARMUP_STEP + if self.trainer.global_step < warmup_step: + if self.config.TRAINER.WARMUP_TYPE == 'linear': + base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR + lr = base_lr + \ + (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \ + abs(self.config.TRAINER.TRUE_LR - base_lr) + for pg in optimizer.param_groups: + pg['lr'] = lr + elif self.config.TRAINER.WARMUP_TYPE == 'constant': + pass + else: + raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}') + + # update params + optimizer.step(closure=optimizer_closure) + optimizer.zero_grad() + + def _trainval_inference(self, batch): + with self.profiler.profile("Compute coarse supervision"): + compute_supervision_coarse(batch, self.config) + + with self.profiler.profile("LoFTR"): + self.matcher(batch) + + with self.profiler.profile("Compute fine supervision"): + compute_supervision_fine(batch, self.config) + + with self.profiler.profile("Compute losses"): + self.loss(batch) + + def _compute_metrics(self, batch): + with self.profiler.profile("Copmute metrics"): + compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match + compute_symmetrical_epipolar_errors_offset_bidirectional(batch) # compute epi_errs for offset match + compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair + + rel_pair_names = list(zip(*batch['pair_names'])) + bs = batch['image0'].size(0) + metrics = { + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)], + 'epi_errs_offset': [batch['epi_errs_offset_left'][batch['offset_bids_left'] == b].cpu().numpy() for b in range(bs)], #only consider left side + 'R_errs': batch['R_errs'], + 't_errs': batch['t_errs'], + 'inliers': batch['inliers']} + ret_dict = {'metrics': metrics} + return ret_dict, rel_pair_names + + + def training_step(self, batch, batch_idx): + self._trainval_inference(batch) + + # logging + if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0: + # scalars + for k, v in batch['loss_scalars'].items(): + if not k.startswith('loss_flow') and not k.startswith('conf_'): + self.logger.experiment.add_scalar(f'train/{k}', 
v, self.global_step) + + #log offset_loss and conf for each layer and level + layer_num=self.loftr_cfg['coarse']['layer_num'] + for layer_index in range(layer_num): + log_title='layer_'+str(layer_index) + self.logger.experiment.add_scalar(log_title+'/offset_loss', batch['loss_scalars']['loss_flow_'+str(layer_index)], self.global_step) + self.logger.experiment.add_scalar(log_title+'/conf_', batch['loss_scalars']['conf_'+str(layer_index)],self.global_step) + + # net-params + if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE == 'sinkhorn': + self.logger.experiment.add_scalar( + f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step) + + # figures + if self.config.TRAINER.ENABLE_PLOTTING: + compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match + figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE) + for k, v in figures.items(): + self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step) + + #plot offset + if self.global_step%200==0: + compute_symmetrical_epipolar_errors_offset_bidirectional(batch) + figures_left = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_left') + figures_right = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right') + for k, v in figures_left.items(): + self.logger.experiment.add_figure(f'train_offset/{k}'+'_left', v, self.global_step) + figures = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right') + for k, v in figures_right.items(): + self.logger.experiment.add_figure(f'train_offset/{k}'+'_right', v, self.global_step) + + return {'loss': batch['loss']} + + def training_epoch_end(self, outputs): + avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + if self.trainer.global_rank == 0: + self.logger.experiment.add_scalar( + 'train/avg_loss_on_epoch', avg_loss, + global_step=self.current_epoch) + + def validation_step(self, batch, batch_idx): + self._trainval_inference(batch) + + ret_dict, _ = self._compute_metrics(batch) #this func also compute the epi_errors + + val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1) + figures = {self.config.TRAINER.PLOT_MODE: []} + figures_offset = {self.config.TRAINER.PLOT_MODE: []} + if batch_idx % val_plot_interval == 0: + figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE) + figures_offset=make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,'_left') + return { + **ret_dict, + 'loss_scalars': batch['loss_scalars'], + 'figures': figures, + 'figures_offset_left':figures_offset + } + + def validation_epoch_end(self, outputs): + # handle multiple validation sets + multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs + multi_val_metrics = defaultdict(list) + + for valset_idx, outputs in enumerate(multi_outputs): + # since pl performs sanity_check at the very begining of the training + cur_epoch = self.trainer.current_epoch + if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check: + cur_epoch = -1 + + # 1. loss_scalars: dict of list, on cpu + _loss_scalars = [o['loss_scalars'] for o in outputs] + loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]} + + # 2. 
val metrics: dict of list, numpy + _metrics = [o['metrics'] for o in outputs] + metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 + val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + for thr in [5, 10, 20]: + multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}']) + + # 3. figures + _figures = [o['figures'] for o in outputs] + figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]} + + # tensorboard records only on rank 0 + if self.trainer.global_rank == 0: + for k, v in loss_scalars.items(): + mean_v = torch.stack(v).mean() + self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch) + + for k, v in val_metrics_4tb.items(): + self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch) + + for k, v in figures.items(): + if self.trainer.global_rank == 0: + for plot_idx, fig in enumerate(v): + self.logger.experiment.add_figure( + f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True) + plt.close('all') + + for thr in [5, 10, 20]: + # log on all ranks for ModelCheckpoint callback to work properly + self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this + + def test_step(self, batch, batch_idx): + with self.profiler.profile("LoFTR"): + self.matcher(batch) + + ret_dict, rel_pair_names = self._compute_metrics(batch) + + with self.profiler.profile("dump_results"): + if self.dump_dir is not None: + # dump results for further analysis + keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'} + pair_names = list(zip(*batch['pair_names'])) + bs = batch['image0'].shape[0] + dumps = [] + for b_id in range(bs): + item = {} + mask = batch['m_bids'] == b_id + item['pair_names'] = pair_names[b_id] + item['identifier'] = '#'.join(rel_pair_names[b_id]) + for key in keys_to_save: + item[key] = batch[key][mask].cpu().numpy() + for key in ['R_errs', 't_errs', 'inliers']: + item[key] = batch[key][b_id] + dumps.append(item) + ret_dict['dumps'] = dumps + + return ret_dict + + def test_epoch_end(self, outputs): + # metrics: dict of list, numpy + _metrics = [o['metrics'] for o in outputs] + metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + + # [{key: [{...}, *#bs]}, *#batch] + if self.dump_dir is not None: + Path(self.dump_dir).mkdir(parents=True, exist_ok=True) + _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch] + dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch] + logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}') + + if self.trainer.global_rank == 0: + print(self.profiler.summary()) + val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + logger.info('\n' + pprint.pformat(val_metrics_4tb)) + if self.dump_dir is not None: + np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps) diff --git a/imcui/third_party/ASpanFormer/src/losses/aspan_loss.py b/imcui/third_party/ASpanFormer/src/losses/aspan_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0cca52b36fc997415937969f26caba8c41ac2b8e --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/losses/aspan_loss.py @@ -0,0 +1,231 @@ +from loguru import logger + +import torch +import torch.nn as nn + +class ASpanLoss(nn.Module): + def __init__(self, config): + 
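+        # Rough decomposition assembled in forward():
+        #   loss = coarse_weight * loss_c + fine_weight * loss_f + sum(per-layer flow losses)
+        # Note: flow_weight is applied inside flow_loss_worker and again in forward(),
+        # so the flow term is effectively scaled by flow_weight squared.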
super().__init__() + self.config = config # config under the global namespace + self.loss_config = config['aspan']['loss'] + self.match_type = self.config['aspan']['match_coarse']['match_type'] + self.sparse_spvs = self.config['aspan']['match_coarse']['sparse_spvs'] + self.flow_weight=self.config['aspan']['loss']['flow_weight'] + + # coarse-level + self.correct_thr = self.loss_config['fine_correct_thr'] + self.c_pos_w = self.loss_config['pos_weight'] + self.c_neg_w = self.loss_config['neg_weight'] + # fine-level + self.fine_type = self.loss_config['fine_type'] + + def compute_flow_loss(self,coarse_corr_gt,flow_list,h0,w0,h1,w1): + #coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]] + #flow_list: [L,B,H,W,4] + loss1=self.flow_loss_worker(flow_list[0],coarse_corr_gt[0],coarse_corr_gt[1],coarse_corr_gt[2],w1) + loss2=self.flow_loss_worker(flow_list[1],coarse_corr_gt[0],coarse_corr_gt[2],coarse_corr_gt[1],w0) + total_loss=(loss1+loss2)/2 + return total_loss + + def flow_loss_worker(self,flow,batch_indicies,self_indicies,cross_indicies,w): + bs,layer_num=flow.shape[1],flow.shape[0] + flow=flow.view(layer_num,bs,-1,4) + gt_flow=torch.stack([cross_indicies%w,cross_indicies//w],dim=1) + + total_loss_list=[] + for layer_index in range(layer_num): + cur_flow_list=flow[layer_index] + spv_flow=cur_flow_list[batch_indicies,self_indicies][:,:2] + spv_conf=cur_flow_list[batch_indicies,self_indicies][:,2:]#[#coarse,2] + l2_flow_dis=((gt_flow-spv_flow)**2) #[#coarse,2] + total_loss=(spv_conf+torch.exp(-spv_conf)*l2_flow_dis) #[#coarse,2] + total_loss_list.append(total_loss.mean()) + total_loss=torch.stack(total_loss_list,dim=-1)*self.flow_weight + return total_loss + + def compute_coarse_loss(self, conf, conf_gt, weight=None): + """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt. + Args: + conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1) + conf_gt (torch.Tensor): (N, HW0, HW1) + weight (torch.Tensor): (N, HW0, HW1) + """ + pos_mask, neg_mask = conf_gt == 1, conf_gt == 0 + c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w + # corner case: no gt coarse-level match at all + if not pos_mask.any(): # assign a wrong gt + pos_mask[0, 0, 0] = True + if weight is not None: + weight[0, 0, 0] = 0. + c_pos_w = 0. + if not neg_mask.any(): + neg_mask[0, 0, 0] = True + if weight is not None: + weight[0, 0, 0] = 0. + c_neg_w = 0. + + if self.loss_config['coarse_type'] == 'cross_entropy': + assert not self.sparse_spvs, 'Sparse Supervision for cross-entropy not implemented!' 
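+            # Plain binary cross-entropy over the dense (N, HW0, HW1) confidence matrix;
+            # confidences are clamped away from 0 / 1 below for numerical stability.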
+ conf = torch.clamp(conf, 1e-6, 1-1e-6) + loss_pos = - torch.log(conf[pos_mask]) + loss_neg = - torch.log(1 - conf[neg_mask]) + if weight is not None: + loss_pos = loss_pos * weight[pos_mask] + loss_neg = loss_neg * weight[neg_mask] + return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() + elif self.loss_config['coarse_type'] == 'focal': + conf = torch.clamp(conf, 1e-6, 1-1e-6) + alpha = self.loss_config['focal_alpha'] + gamma = self.loss_config['focal_gamma'] + + if self.sparse_spvs: + pos_conf = conf[:, :-1, :-1][pos_mask] \ + if self.match_type == 'sinkhorn' \ + else conf[pos_mask] + loss_pos = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log() + # calculate losses for negative samples + if self.match_type == 'sinkhorn': + neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0 + neg_conf = torch.cat([conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0) + loss_neg = - alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log() + else: + # These is no dustbin for dual_softmax, so we left unmatchable patches without supervision. + # we could also add 'pseudo negtive-samples' + pass + # handle loss weights + if weight is not None: + # Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out, + # but only through manually setting corresponding regions in sim_matrix to '-inf'. + loss_pos = loss_pos * weight[pos_mask] + if self.match_type == 'sinkhorn': + neg_w0 = (weight.sum(-1) != 0)[neg0] + neg_w1 = (weight.sum(1) != 0)[neg1] + neg_mask = torch.cat([neg_w0, neg_w1], 0) + loss_neg = loss_neg[neg_mask] + + loss = c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() \ + if self.match_type == 'sinkhorn' \ + else c_pos_w * loss_pos.mean() + return loss + # positive and negative elements occupy similar propotions. => more balanced loss weights needed + else: # dense supervision (in the case of match_type=='sinkhorn', the dustbin is not supervised.) + loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log() + loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log() + if weight is not None: + loss_pos = loss_pos * weight[pos_mask] + loss_neg = loss_neg * weight[neg_mask] + return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() + # each negative element occupy a smaller propotion than positive elements. 
=> higher negative loss weight needed + else: + raise ValueError('Unknown coarse loss: {type}'.format(type=self.loss_config['coarse_type'])) + + def compute_fine_loss(self, expec_f, expec_f_gt): + if self.fine_type == 'l2_with_std': + return self._compute_fine_loss_l2_std(expec_f, expec_f_gt) + elif self.fine_type == 'l2': + return self._compute_fine_loss_l2(expec_f, expec_f_gt) + else: + raise NotImplementedError() + + def _compute_fine_loss_l2(self, expec_f, expec_f_gt): + """ + Args: + expec_f (torch.Tensor): [M, 2] + expec_f_gt (torch.Tensor): [M, 2] + """ + correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr + if correct_mask.sum() == 0: + if self.training: # this seldomly happen when training, since we pad prediction with gt + logger.warning("assign a false supervision to avoid ddp deadlock") + correct_mask[0] = True + else: + return None + flow_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask]) ** 2).sum(-1) + return flow_l2.mean() + + def _compute_fine_loss_l2_std(self, expec_f, expec_f_gt): + """ + Args: + expec_f (torch.Tensor): [M, 3] + expec_f_gt (torch.Tensor): [M, 2] + """ + # correct_mask tells you which pair to compute fine-loss + correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr + + # use std as weight that measures uncertainty + std = expec_f[:, 2] + inverse_std = 1. / torch.clamp(std, min=1e-10) + weight = (inverse_std / torch.mean(inverse_std)).detach() # avoid minizing loss through increase std + + # corner case: no correct coarse match found + if not correct_mask.any(): + if self.training: # this seldomly happen during training, since we pad prediction with gt + # sometimes there is not coarse-level gt at all. + logger.warning("assign a false supervision to avoid ddp deadlock") + correct_mask[0] = True + weight[0] = 0. + else: + return None + + # l2 loss with std + flow_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum(-1) + loss = (flow_l2 * weight[correct_mask]).mean() + + return loss + + @torch.no_grad() + def compute_c_weight(self, data): + """ compute element-wise weights for computing coarse-level loss. """ + if 'mask0' in data: + c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float() + else: + c_weight = None + return c_weight + + def forward(self, data): + """ + Update: + data (dict): update{ + 'loss': [1] the reduced loss across a batch, + 'loss_scalars' (dict): loss scalars for tensorboard_record + } + """ + loss_scalars = {} + # 0. compute element-wise loss weight + c_weight = self.compute_c_weight(data) + + # 1. coarse-level loss + loss_c = self.compute_coarse_loss( + data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn' \ + else data['conf_matrix'], + data['conf_matrix_gt'], + weight=c_weight) + loss = loss_c * self.loss_config['coarse_weight'] + loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()}) + + # 2. fine-level loss + loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt']) + if loss_f is not None: + loss += loss_f * self.loss_config['fine_weight'] + loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()}) + else: + assert self.training is False + loss_scalars.update({'loss_f': torch.tensor(1.)}) # 1 is the upper bound + + # 3. 
flow loss + coarse_corr=[data['spv_b_ids'],data['spv_i_ids'],data['spv_j_ids']] + loss_flow = self.compute_flow_loss(coarse_corr,data['predict_flow'],\ + data['hw0_c'][0],data['hw0_c'][1],data['hw1_c'][0],data['hw1_c'][1]) + loss_flow=loss_flow*self.flow_weight + for index,loss_off in enumerate(loss_flow): + loss_scalars.update({'loss_flow_'+str(index): loss_off.clone().detach().cpu()}) # 1 is the upper bound + conf=data['predict_flow'][0][:,:,:,:,2:] + layer_num=conf.shape[0] + for layer_index in range(layer_num): + loss_scalars.update({'conf_'+str(layer_index): conf[layer_index].mean().clone().detach().cpu()}) # 1 is the upper bound + + + loss+=loss_flow.sum() + #print((loss_c * self.loss_config['coarse_weight']).data,loss_flow.data) + loss_scalars.update({'loss': loss.clone().detach().cpu()}) + data.update({"loss": loss, "loss_scalars": loss_scalars}) diff --git a/imcui/third_party/ASpanFormer/src/optimizers/__init__.py b/imcui/third_party/ASpanFormer/src/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1db2285352586c250912bdd2c4ae5029620ab5f --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/optimizers/__init__.py @@ -0,0 +1,42 @@ +import torch +from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR + + +def build_optimizer(model, config): + name = config.TRAINER.OPTIMIZER + lr = config.TRAINER.TRUE_LR + + if name == "adam": + return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY) + elif name == "adamw": + return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY) + else: + raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") + + +def build_scheduler(config, optimizer): + """ + Returns: + scheduler (dict):{ + 'scheduler': lr_scheduler, + 'interval': 'step', # or 'epoch' + 'monitor': 'val_f1', (optional) + 'frequency': x, (optional) + } + """ + scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} + name = config.TRAINER.SCHEDULER + + if name == 'MultiStepLR': + scheduler.update( + {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) + elif name == 'CosineAnnealing': + scheduler.update( + {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) + elif name == 'ExponentialLR': + scheduler.update( + {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) + else: + raise NotImplementedError() + + return scheduler diff --git a/imcui/third_party/ASpanFormer/src/utils/augment.py b/imcui/third_party/ASpanFormer/src/utils/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/utils/augment.py @@ -0,0 +1,55 @@ +import albumentations as A + + +class DarkAug(object): + """ + Extreme dark augmentation aiming at Aachen Day-Night + """ + + def __init__(self) -> None: + self.augmentor = A.Compose([ + A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), + A.Blur(p=0.1, blur_limit=(3, 9)), + A.MotionBlur(p=0.2, blur_limit=(3, 25)), + A.RandomGamma(p=0.1, gamma_limit=(15, 65)), + A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) + ], p=0.75) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + + +class MobileAug(object): + """ + Random augmentations aiming at images of mobile/handhold devices. 
+ """ + + def __init__(self): + self.augmentor = A.Compose([ + A.MotionBlur(p=0.25), + A.ColorJitter(p=0.5), + A.RandomRain(p=0.1), # random occlusion + A.RandomSunFlare(p=0.1), + A.JpegCompression(p=0.25), + A.ISONoise(p=0.25) + ], p=1.0) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + + +def build_augmentor(method=None, **kwargs): + if method is not None: + raise NotImplementedError('Using of augmentation functions are not supported yet!') + if method == 'dark': + return DarkAug() + elif method == 'mobile': + return MobileAug() + elif method is None: + return None + else: + raise ValueError(f'Invalid augmentation method: {method}') + + +if __name__ == '__main__': + augmentor = build_augmentor('FDA') diff --git a/imcui/third_party/ASpanFormer/src/utils/comm.py b/imcui/third_party/ASpanFormer/src/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/utils/comm.py @@ -0,0 +1,265 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +[Copied from detectron2] +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import functools +import logging +import numpy as np +import pickle +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. 
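+    Gathered Python objects are serialized to CPU byte tensors (see _serialize_to_tensor),
+    which is why a gloo group is preferred even when the default backend is NCCL.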
+ """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" + local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. 
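+    Example (illustrative):
+        results = gather(local_metrics, dst=0)  # list of per-rank objects on rank 0, [] elsewhere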
+ """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group=group) == 1: + return [data] + rank = dist.get_rank(group=group) + + tensor = _serialize_to_tensor(data, group) + size_list, tensor = _pad_to_largest_tensor(tensor, group) + + # receiving Tensor from all ranks + if rank == dst: + max_size = max(size_list) + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.gather(tensor, tensor_list, dst=dst, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + return data_list + else: + dist.gather(tensor, [], dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2 ** 31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + + Returns: + a dict with the same keys as input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/imcui/third_party/ASpanFormer/src/utils/dataloader.py b/imcui/third_party/ASpanFormer/src/utils/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..6da37b880a290c2bb3ebb028d0c8dab592acc5c1 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/utils/dataloader.py @@ -0,0 +1,23 @@ +import numpy as np + + +# --- PL-DATAMODULE --- + +def get_local_split(items: list, world_size: int, rank: int, seed: int): + """ The local rank only loads a split of the dataset. 
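+    Example (illustrative): with 5 scene names and world_size=2, the permuted list is
+    padded to 6 entries, so ranks 0 and 1 each receive 3 names.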
""" + n_items = len(items) + items_permute = np.random.RandomState(seed).permutation(items) + if n_items % world_size == 0: + padded_items = items_permute + else: + padding = np.random.RandomState(seed).choice( + items, + world_size - (n_items % world_size), + replace=True) + padded_items = np.concatenate([items_permute, padding]) + assert len(padded_items) % world_size == 0, \ + f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' + n_per_rank = len(padded_items) // world_size + local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] + + return local_items diff --git a/imcui/third_party/ASpanFormer/src/utils/dataset.py b/imcui/third_party/ASpanFormer/src/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..209bf554acc20e33ea89eb9e7024ba68d0b3a30b --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/utils/dataset.py @@ -0,0 +1,222 @@ +import io +import cv2 +import numpy as np +import h5py +import torch +from numpy.linalg import inv +import re + + +try: + # for internel use only + from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT +except Exception: + MEGADEPTH_CLIENT = SCANNET_CLIENT = None + +# --- DATA IO --- + +def load_array_from_s3( + path, client, cv_type, + use_h5py=False, +): + byte_str = client.Get(path) + try: + if not use_h5py: + raw_array = np.fromstring(byte_str, np.uint8) + data = cv2.imdecode(raw_array, cv_type) + else: + f = io.BytesIO(byte_str) + data = np.array(h5py.File(f, 'r')['/depth']) + except Exception as ex: + print(f"==> Data loading failure: {path}") + raise ex + + assert data is not None + return data + + +def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): + cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ + else cv2.IMREAD_COLOR + if str(path).startswith('s3://'): + image = load_array_from_s3(str(path), client, cv_type) + else: + image = cv2.imread(str(path), cv_type) + + if augment_fn is not None: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = augment_fn(image) + image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + return image # (h, w) + + +def get_resized_wh(w, h, resize=None): + if resize is not None: # resize the longer edge + scale = resize / max(h, w) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + else: + w_new, h_new = w, h + return w_new, h_new + + +def get_divisible_wh(w, h, df=None): + if df is not None: + w_new, h_new = map(lambda x: int(x // df * df), [w, h]) + else: + w_new, h_new = w, h + return w_new, h_new + + +def pad_bottom_right(inp, pad_size, ret_mask=False): + assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + mask = None + if inp.ndim == 2: + padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + if ret_mask: + mask = np.zeros((pad_size, pad_size), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + elif inp.ndim == 3: + padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) + padded[:, :inp.shape[1], :inp.shape[2]] = inp + if ret_mask: + mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) + mask[:, :inp.shape[1], :inp.shape[2]] = True + else: + raise NotImplementedError() + return padded, mask + + +# --- MEGADEPTH --- + +def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): + """ + Args: + resize (int, optional): the longer edge of resized images. None for no resize. 
+ padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) + + # resize image + w, h = image.shape[1], image.shape[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + + if padding: # padding + pad_to = max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + else: + mask = None + + image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + if mask is not None: + mask = torch.from_numpy(mask) + + return image, mask, scale + + +def read_megadepth_depth(path, pad_to=None): + if str(path).startswith('s3://'): + depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) + else: + depth = np.array(h5py.File(path, 'r')['depth']) + if pad_to is not None: + depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +# --- ScanNet --- + +def read_scannet_gray(path, resize=(640, 480), augment_fn=None): + """ + Args: + resize (tuple): align image to depthmap, in (w, h). + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read and resize image + image = imread_gray(path, augment_fn) + image = cv2.resize(image, resize) + + # (h, w) -> (1, h, w) and normalized + image = torch.from_numpy(image).float()[None] / 255 + return image + + +def read_scannet_depth(path): + if str(path).startswith('s3://'): + depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) + else: + depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) + depth = depth / 1000 + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +def read_scannet_pose(path): + """ Read ScanNet's Camera2World pose and transform it to World2Camera. + + Returns: + pose_w2c (np.ndarray): (4, 4) + """ + cam2world = np.loadtxt(path, delimiter=' ') + world2cam = inv(cam2world) + return world2cam + + +def read_scannet_intrinsic(path): + """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 
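+    The stored file is assumed to hold a homogeneous 4x4 matrix; the last row and
+    column are dropped below to obtain the 3x3 calibration matrix K.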
+ """ + intrinsic = np.loadtxt(path, delimiter=' ') + return intrinsic[:-1, :-1] + + +def read_gl3d_gray(path,resize): + img=cv2.resize(cv2.imread(path,cv2.IMREAD_GRAYSCALE),(int(resize),int(resize))) + img = torch.from_numpy(img).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + return img + +def read_gl3d_depth(file_path): + with open(file_path, 'rb') as fin: + color = None + width = None + height = None + scale = None + data_type = None + header = str(fin.readline().decode('UTF-8')).rstrip() + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8')) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + scale = float((fin.readline().decode('UTF-8')).rstrip()) + if scale < 0: # little-endian + data_type = ' best_num_inliers: + ret = (R, t[:, 0], mask.ravel() > 0) + best_num_inliers = n + + return ret + + +def compute_pose_errors(data, config): + """ + Update: + data (dict):{ + "R_errs" List[float]: [N] + "t_errs" List[float]: [N] + "inliers" List[np.ndarray]: [N] + } + """ + pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5 + conf = config.TRAINER.RANSAC_CONF # 0.99999 + data.update({'R_errs': [], 't_errs': [], 'inliers': []}) + + m_bids = data['m_bids'].cpu().numpy() + pts0 = data['mkpts0_f'].cpu().numpy() + pts1 = data['mkpts1_f'].cpu().numpy() + K0 = data['K0'].cpu().numpy() + K1 = data['K1'].cpu().numpy() + T_0to1 = data['T_0to1'].cpu().numpy() + + for bs in range(K0.shape[0]): + mask = m_bids == bs + ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf) + + if ret is None: + data['R_errs'].append(np.inf) + data['t_errs'].append(np.inf) + data['inliers'].append(np.array([]).astype(np.bool)) + else: + R, t, inliers = ret + t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) + data['R_errs'].append(R_err) + data['t_errs'].append(t_err) + data['inliers'].append(inliers) + + +# --- METRIC AGGREGATION --- + +def error_auc(errors, thresholds): + """ + Args: + errors (list): [N,] + thresholds (list) + """ + errors = [0] + sorted(list(errors)) + recall = list(np.linspace(0, 1, len(errors))) + + aucs = [] + thresholds = [5, 10, 20] + for thr in thresholds: + last_index = np.searchsorted(errors, thr) + y = recall[:last_index] + [recall[last_index-1]] + x = errors[:last_index] + [thr] + aucs.append(np.trapz(y, x) / thr) + + return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} + + +def epidist_prec(errors, thresholds, ret_dict=False,offset=False): + precs = [] + for thr in thresholds: + prec_ = [] + for errs in errors: + correct_mask = errs < thr + prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) + precs.append(np.mean(prec_) if len(prec_) > 0 else 0) + if ret_dict: + return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} if not offset else {f'prec_flow@{t:.0e}': prec for t, prec in zip(thresholds, precs)} + else: + return precs + + +def aggregate_metrics(metrics, epi_err_thr=5e-4): + """ Aggregate metrics for the whole dataset: + (This method should be called once per dataset) + 1. AUC of the pose error (angular) at the threshold [5, 10, 20] + 2. 
Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) + """ + # filter duplicates + unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) + unq_ids = list(unq_ids.values()) + logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') + + # pose auc + angular_thresholds = [5, 10, 20] + pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] + aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) + + # matching precision + dist_thresholds = [epi_err_thr] + precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) + + #offset precision + try: + precs_offset = epidist_prec(np.array(metrics['epi_errs_offset'], dtype=object)[unq_ids], [2e-3], True,offset=True) + return {**aucs, **precs,**precs_offset} + except: + return {**aucs, **precs} diff --git a/imcui/third_party/ASpanFormer/src/utils/misc.py b/imcui/third_party/ASpanFormer/src/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..25e4433f5ffa41adc4c0435cfe2b5696e43b58b3 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/utils/misc.py @@ -0,0 +1,139 @@ +import os +import contextlib +import joblib +from typing import Union +from loguru import _Logger, logger +from itertools import chain + +import torch +from yacs.config import CfgNode as CN +from pytorch_lightning.utilities import rank_zero_only +import cv2 +import numpy as np + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +def upper_config(dict_cfg): + if not isinstance(dict_cfg, dict): + return dict_cfg + return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} + + +def log_on(condition, message, level): + if condition: + assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] + logger.log(level, message) + + +def get_rank_zero_only_logger(logger: _Logger): + if rank_zero_only.rank == 0: + return logger + else: + for _level in logger._core.levels.keys(): + level = _level.lower() + setattr(logger, level, + lambda x: None) + logger._log = lambda x: None + return logger + + +def setup_gpus(gpus: Union[str, int]) -> int: + """ A temporary fix for pytorch-lighting 1.3.x """ + gpus = str(gpus) + gpu_ids = [] + + if ',' not in gpus: + n_gpus = int(gpus) + return n_gpus if n_gpus != -1 else torch.cuda.device_count() + else: + gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] + + # setup environment variables + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + if visible_devices is None: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') + else: + logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') + return len(gpu_ids) + + +def flattenList(x): + return list(chain(*x)) + + +@contextlib.contextmanager +def tqdm_joblib(tqdm_object): + """Context manager to patch joblib to report into tqdm progress bar given as argument + + Usage: + with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: + Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) + + When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of 
finishing) + ret_vals = Parallel(n_jobs=args.world_size)( + delayed(lambda x: _compute_cov_score(pid, *x))(param) + for param in tqdm(combinations(image_ids, 2), + desc=f'Computing cov_score of [{pid}]', + total=len(image_ids)*(len(image_ids)-1)/2)) + Src: https://stackoverflow.com/a/58936697 + """ + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() + + +def draw_points(img,points,color=(0,255,0),radius=3): + dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])] + for i in range(points.shape[0]): + cv2.circle(img, dp[i],radius=radius,color=color) + return img + + +def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None): + if resize is not None: + scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]] + img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA) + corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis] + corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])] + corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])] + + assert len(corr1) == len(corr2) + + draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))] + if color is None: + color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier] + if len(color)==1: + display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None, + matchColor=color[0], + singlePointColor=color[0], + flags=4 + ) + else: + height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1] + display=np.zeros([height,width,3],np.uint8) + display[:img1.shape[0],:img1.shape[1]]=img1 + display[:img2.shape[0],img1.shape[1]:]=img2 + for i in range(len(corr1)): + left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1]) + cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2])) + cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA) + return display diff --git a/imcui/third_party/ASpanFormer/src/utils/plotting.py b/imcui/third_party/ASpanFormer/src/utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..8696880237b6ad9fe48d3c1fc44ed13b691a6c4d --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/utils/plotting.py @@ -0,0 +1,219 @@ +import bisect +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +from copy import deepcopy + +def _compute_conf_thresh(data): + dataset_name = data['dataset_name'][0].lower() + if dataset_name == 'scannet': + thr = 5e-4 + elif dataset_name == 'megadepth' or dataset_name=='gl3d': + thr = 1e-4 + else: + raise ValueError(f'Unknown dataset: {dataset_name}') + return thr + + +# --- VISUALIZATION --- # + +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], 
f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0, cmap='gray') + axes[1].imshow(img1, cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=1) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + + +def _make_evaluation_figure(data, b_id, alpha='dynamic'): + b_mask = data['m_bids'] == b_id + conf_thr = _compute_conf_thresh(data) + + img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() + kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() + + # for megadepth, we visualize matches on the resized image + if 'scale0' in data: + kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]] + kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]] + epi_errs = data['epi_errs'][b_mask].cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) + recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. 
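+    # The error_colormap call below renders low-epipolar-error matches in green and
+    # high-error matches in red, with alpha optionally reduced by dynamic_alpha when
+    # many matches are plotted.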
+ + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + ] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text) + return figure + +def _make_evaluation_figure_offset(data, b_id, alpha='dynamic',side=''): + layer_num=data['predict_flow'][0].shape[0] + + b_mask = data['offset_bids'+side] == b_id + conf_thr = 2e-3 #hardcode for scannet(coarse level) + img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + + figure_list=[] + #draw offset matches in different layers + for layer_index in range(layer_num): + l_mask=data['offset_lids'+side]==layer_index + mask=l_mask&b_mask + kpts0 = data['offset_kpts0_f'+side][mask].cpu().numpy() + kpts1 = data['offset_kpts1_f'+side][mask].cpu().numpy() + + epi_errs = data['epi_errs_offset'+side][mask].cpu().numpy() + correct_mask = epi_errs < conf_thr + + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) + recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + ] + + # make the figure + #import pdb;pdb.set_trace() + figure = make_matching_figure(deepcopy(img0), deepcopy(img1) , kpts0, kpts1, + color, text=text) + figure_list.append(figure) + return figure + +def _make_confidence_figure(data, b_id): + # TODO: Implement confidence figure + raise NotImplementedError() + + +def make_matching_figures(data, config, mode='evaluation'): + """ Make matching figures for a batch. + + Args: + data (Dict): a batch updated by PL_LoFTR. + config (Dict): matcher config + Returns: + figures (Dict[str, List[plt.figure]] + """ + assert mode in ['evaluation', 'confidence'] # 'confidence' + figures = {mode: []} + for b_id in range(data['image0'].size(0)): + if mode == 'evaluation': + fig = _make_evaluation_figure( + data, b_id, + alpha=config.TRAINER.PLOT_MATCHES_ALPHA) + elif mode == 'confidence': + fig = _make_confidence_figure(data, b_id) + else: + raise ValueError(f'Unknown plot mode: {mode}') + figures[mode].append(fig) + return figures + +def make_matching_figures_offset(data, config, mode='evaluation',side=''): + """ Make matching figures for a batch. + + Args: + data (Dict): a batch updated by PL_LoFTR. 
+ config (Dict): matcher config + Returns: + figures (Dict[str, List[plt.figure]] + """ + assert mode in ['evaluation', 'confidence'] # 'confidence' + figures = {mode: []} + for b_id in range(data['image0'].size(0)): + if mode == 'evaluation': + fig = _make_evaluation_figure_offset( + data, b_id, + alpha=config.TRAINER.PLOT_MATCHES_ALPHA,side=side) + elif mode == 'confidence': + fig = _make_evaluation_figure_offset(data, b_id) + else: + raise ValueError(f'Unknown plot mode: {mode}') + figures[mode].append(fig) + return figures + +def dynamic_alpha(n_matches, + milestones=[0, 300, 1000, 2000], + alphas=[1.0, 0.8, 0.4, 0.2]): + if n_matches == 0: + return 1.0 + ranges = list(zip(alphas, alphas[1:] + [None])) + loc = bisect.bisect_right(milestones, n_matches) - 1 + _range = ranges[loc] + if _range[1] is None: + return _range[0] + return _range[1] + (milestones[loc + 1] - n_matches) / ( + milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) + + +def error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) diff --git a/imcui/third_party/ASpanFormer/src/utils/profiler.py b/imcui/third_party/ASpanFormer/src/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..6d21ed79fb506ef09c75483355402c48a195aaa9 --- /dev/null +++ b/imcui/third_party/ASpanFormer/src/utils/profiler.py @@ -0,0 +1,39 @@ +import torch +from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler +from contextlib import contextmanager +from pytorch_lightning.utilities import rank_zero_only + + +class InferenceProfiler(SimpleProfiler): + """ + This profiler records duration of actions with cuda.synchronize() + Use this in test time. 
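For intuition, a small usage sketch of the two plotting helpers `dynamic_alpha` and `error_colormap` defined above (assumes both functions are in scope; the values follow directly from their definitions):

```python
# Usage sketch for the two plotting helpers above (functions assumed in scope).
import numpy as np

assert dynamic_alpha(0) == 1.0       # no matches -> fully opaque lines
assert dynamic_alpha(5000) == 0.2    # past the last milestone -> smallest alpha
alpha = dynamic_alpha(650)           # between 300 and 1000 -> 0.6 by linear interpolation

colors = error_colormap(np.array([0.0, 1e-3, 4e-3]), thr=1e-3, alpha=alpha)
print(colors.shape)                  # (3, 4): one RGBA row per match, green -> red with error
```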
+ """ + + def __init__(self): + super().__init__() + self.start = rank_zero_only(self.start) + self.stop = rank_zero_only(self.stop) + self.summary = rank_zero_only(self.summary) + + @contextmanager + def profile(self, action_name: str) -> None: + try: + torch.cuda.synchronize() + self.start(action_name) + yield action_name + finally: + torch.cuda.synchronize() + self.stop(action_name) + + +def build_profiler(name): + if name == 'inference': + return InferenceProfiler() + elif name == 'pytorch': + from pytorch_lightning.profiler import PyTorchProfiler + return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) + elif name is None: + return PassThroughProfiler() + else: + raise ValueError(f'Invalid profiler: {name}') diff --git a/imcui/third_party/ASpanFormer/test.py b/imcui/third_party/ASpanFormer/test.py new file mode 100644 index 0000000000000000000000000000000000000000..541ce84662ab4888c6fece30403c5c9983118637 --- /dev/null +++ b/imcui/third_party/ASpanFormer/test.py @@ -0,0 +1,69 @@ +import pytorch_lightning as pl +import argparse +import pprint +from loguru import logger as loguru_logger + +from src.config.default import get_cfg_defaults +from src.utils.profiler import build_profiler + +from src.lightning.data import MultiSceneDataModule +from src.lightning.lightning_aspanformer import PL_ASpanFormer +import torch + +def parse_args(): + # init a costum parser which will be added into pl.Trainer parser + # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'data_cfg_path', type=str, help='data config path') + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint') + parser.add_argument( + '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir") + parser.add_argument( + '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset') + parser.add_argument( + '--batch_size', type=int, default=1, help='batch_size per gpu') + parser.add_argument( + '--num_workers', type=int, default=2) + parser.add_argument( + '--thr', type=float, default=None, help='modify the coarse-level matching threshold.') + parser.add_argument( + '--mode', type=str, default='vanilla', help='modify the coarse-level matching threshold.') + parser = pl.Trainer.add_argparse_args(parser) + return parser.parse_args() + + +if __name__ == '__main__': + # parse arguments + args = parse_args() + pprint.pprint(vars(args)) + + # init default-cfg and merge it with the main- and data-cfg + config = get_cfg_defaults() + config.merge_from_file(args.main_cfg_path) + config.merge_from_file(args.data_cfg_path) + pl.seed_everything(config.TRAINER.SEED) # reproducibility + + # tune when testing + if args.thr is not None: + config.ASPAN.MATCH_COARSE.THR = args.thr + + loguru_logger.info(f"Args and config initialized!") + + # lightning module + profiler = build_profiler(args.profiler_name) + model = PL_ASpanFormer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir) + loguru_logger.info(f"ASpanFormer-lightning initialized!") + + # lightning data + data_module = MultiSceneDataModule(args, config) + loguru_logger.info(f"DataModule initialized!") + + # lightning trainer + trainer = pl.Trainer.from_argparse_args(args, 
replace_sampler_ddp=False, logger=False) + + loguru_logger.info(f"Start testing!") + trainer.test(model, datamodule=data_module, verbose=False) diff --git a/imcui/third_party/ASpanFormer/tools/SensorData.py b/imcui/third_party/ASpanFormer/tools/SensorData.py new file mode 100644 index 0000000000000000000000000000000000000000..a3ec2644bf8b3b988ef0f36851cd3317c00511b2 --- /dev/null +++ b/imcui/third_party/ASpanFormer/tools/SensorData.py @@ -0,0 +1,125 @@ + +import os, struct +import numpy as np +import zlib +import imageio +import cv2 +import png + +COMPRESSION_TYPE_COLOR = {-1:'unknown', 0:'raw', 1:'png', 2:'jpeg'} +COMPRESSION_TYPE_DEPTH = {-1:'unknown', 0:'raw_ushort', 1:'zlib_ushort', 2:'occi_ushort'} + +class RGBDFrame(): + + def load(self, file_handle): + self.camera_to_world = np.asarray(struct.unpack('f'*16, file_handle.read(16*4)), dtype=np.float32).reshape(4, 4) + self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0] + self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0] + self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0] + self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0] + self.color_data = ''.join(struct.unpack('c'*self.color_size_bytes, file_handle.read(self.color_size_bytes))) + self.depth_data = ''.join(struct.unpack('c'*self.depth_size_bytes, file_handle.read(self.depth_size_bytes))) + + + def decompress_depth(self, compression_type): + if compression_type == 'zlib_ushort': + return self.decompress_depth_zlib() + else: + raise + + + def decompress_depth_zlib(self): + return zlib.decompress(self.depth_data) + + + def decompress_color(self, compression_type): + if compression_type == 'jpeg': + return self.decompress_color_jpeg() + else: + raise + + + def decompress_color_jpeg(self): + return imageio.imread(self.color_data) + + +class SensorData: + + def __init__(self, filename): + self.version = 4 + self.load(filename) + + + def load(self, filename): + with open(filename, 'rb') as f: + version = struct.unpack('I', f.read(4))[0] + assert self.version == version + strlen = struct.unpack('Q', f.read(8))[0] + self.sensor_name = ''.join(struct.unpack('c'*strlen, f.read(strlen))) + self.intrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) + self.extrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) + self.intrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) + self.extrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) + self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack('i', f.read(4))[0]] + self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack('i', f.read(4))[0]] + self.color_width = struct.unpack('I', f.read(4))[0] + self.color_height = struct.unpack('I', f.read(4))[0] + self.depth_width = struct.unpack('I', f.read(4))[0] + self.depth_height = struct.unpack('I', f.read(4))[0] + self.depth_shift = struct.unpack('f', f.read(4))[0] + num_frames = struct.unpack('Q', f.read(8))[0] + self.frames = [] + for i in range(num_frames): + frame = RGBDFrame() + frame.load(f) + self.frames.append(frame) + + + def export_depth_images(self, output_path, image_size=None, frame_skip=1): + if not os.path.exists(output_path): + os.makedirs(output_path) + print 'exporting', len(self.frames)//frame_skip, ' depth frames to', output_path + for f in range(0, len(self.frames), frame_skip): + depth_data = 
self.frames[f].decompress_depth(self.depth_compression_type) + depth = np.fromstring(depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width) + if image_size is not None: + depth = cv2.resize(depth, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) + #imageio.imwrite(os.path.join(output_path, str(f) + '.png'), depth) + with open(os.path.join(output_path, str(f) + '.png'), 'wb') as f: # write 16-bit + writer = png.Writer(width=depth.shape[1], height=depth.shape[0], bitdepth=16) + depth = depth.reshape(-1, depth.shape[1]).tolist() + writer.write(f, depth) + + def export_color_images(self, output_path, image_size=None, frame_skip=1): + if not os.path.exists(output_path): + os.makedirs(output_path) + print 'exporting', len(self.frames)//frame_skip, 'color frames to', output_path + for f in range(0, len(self.frames), frame_skip): + color = self.frames[f].decompress_color(self.color_compression_type) + if image_size is not None: + color = cv2.resize(color, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) + imageio.imwrite(os.path.join(output_path, str(f) + '.jpg'), color) + + + def save_mat_to_file(self, matrix, filename): + with open(filename, 'w') as f: + for line in matrix: + np.savetxt(f, line[np.newaxis], fmt='%f') + + + def export_poses(self, output_path, frame_skip=1): + if not os.path.exists(output_path): + os.makedirs(output_path) + print 'exporting', len(self.frames)//frame_skip, 'camera poses to', output_path + for f in range(0, len(self.frames), frame_skip): + self.save_mat_to_file(self.frames[f].camera_to_world, os.path.join(output_path, str(f) + '.txt')) + + + def export_intrinsics(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + print 'exporting camera intrinsics to', output_path + self.save_mat_to_file(self.intrinsic_color, os.path.join(output_path, 'intrinsic_color.txt')) + self.save_mat_to_file(self.extrinsic_color, os.path.join(output_path, 'extrinsic_color.txt')) + self.save_mat_to_file(self.intrinsic_depth, os.path.join(output_path, 'intrinsic_depth.txt')) + self.save_mat_to_file(self.extrinsic_depth, os.path.join(output_path, 'extrinsic_depth.txt')) \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/tools/extract.py b/imcui/third_party/ASpanFormer/tools/extract.py new file mode 100644 index 0000000000000000000000000000000000000000..12f55e2f94120d5765f124f8eec867f1d82e0aa7 --- /dev/null +++ b/imcui/third_party/ASpanFormer/tools/extract.py @@ -0,0 +1,47 @@ +import os +import glob +from re import split +from tqdm import tqdm +from multiprocessing import Pool +from functools import partial + +scannet_dir='/root/data/ScanNet-v2-1.0.0/data/raw' +dump_dir='/root/data/scannet_dump' +num_process=32 + +def extract(seq,scannet_dir,split,dump_dir): + assert split=='train' or split=='test' + if not os.path.exists(os.path.join(dump_dir,split,seq)): + os.mkdir(os.path.join(dump_dir,split,seq)) + cmd='python reader.py --filename '+os.path.join(scannet_dir,'scans' if split=='train' else 'scans_test',seq,seq+'.sens')+' --output_path '+os.path.join(dump_dir,split,seq)+\ + ' --export_depth_images --export_color_images --export_poses --export_intrinsics' + os.system(cmd) + +if __name__=='__main__': + if not os.path.exists(dump_dir): + os.mkdir(dump_dir) + os.mkdir(os.path.join(dump_dir,'train')) + os.mkdir(os.path.join(dump_dir,'test')) + + train_seq_list=[seq.split('/')[-1] for seq in glob.glob(os.path.join(scannet_dir,'scans','scene*'))] + test_seq_list=[seq.split('/')[-1] for seq in 
glob.glob(os.path.join(scannet_dir,'scans_test','scene*'))] + + extract_train=partial(extract,scannet_dir=scannet_dir,split='train',dump_dir=dump_dir) + extract_test=partial(extract,scannet_dir=scannet_dir,split='test',dump_dir=dump_dir) + + num_train_iter=len(train_seq_list)//num_process if len(train_seq_list)%num_process==0 else len(train_seq_list)//num_process+1 + num_test_iter=len(test_seq_list)//num_process if len(test_seq_list)%num_process==0 else len(test_seq_list)//num_process+1 + + pool = Pool(num_process) + for index in tqdm(range(num_train_iter)): + seq_list=train_seq_list[index*num_process:min((index+1)*num_process,len(train_seq_list))] + pool.map(extract_train,seq_list) + pool.close() + pool.join() + + pool = Pool(num_process) + for index in tqdm(range(num_test_iter)): + seq_list=test_seq_list[index*num_process:min((index+1)*num_process,len(test_seq_list))] + pool.map(extract_test,seq_list) + pool.close() + pool.join() \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/tools/preprocess_scene.py b/imcui/third_party/ASpanFormer/tools/preprocess_scene.py new file mode 100644 index 0000000000000000000000000000000000000000..d20c0d070243519d67bbd25668ff5eb1657474be --- /dev/null +++ b/imcui/third_party/ASpanFormer/tools/preprocess_scene.py @@ -0,0 +1,242 @@ +import argparse + +import imagesize + +import numpy as np + +import os + +parser = argparse.ArgumentParser(description='MegaDepth preprocessing script') + +parser.add_argument( + '--base_path', type=str, required=True, + help='path to MegaDepth' +) +parser.add_argument( + '--scene_id', type=str, required=True, + help='scene ID' +) + +parser.add_argument( + '--output_path', type=str, required=True, + help='path to the output directory' +) + +args = parser.parse_args() + +base_path = args.base_path +# Remove the trailing / if need be. 
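The per-sequence extraction in `extract.py` above relies on `functools.partial` to freeze everything except the sequence name before handing the job to `multiprocessing.Pool`; a self-contained toy version of that pattern (names and paths made up):

```python
# Toy version of the partial + Pool.map pattern used in extract.py above.
from functools import partial
from multiprocessing import Pool

def work(seq, split, dump_dir):
    return f'{dump_dir}/{split}/{seq}'

if __name__ == '__main__':
    seqs = [f'scene{i:04d}_00' for i in range(8)]
    job = partial(work, split='train', dump_dir='/tmp/dump')  # freeze keyword args
    with Pool(4) as pool:
        print(pool.map(job, seqs)[:2])  # Pool supplies the remaining positional arg (seq)
```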
+if base_path[-1] in ['/', '\\']: + base_path = base_path[: - 1] +scene_id = args.scene_id + +base_depth_path = os.path.join( + base_path, 'phoenix/S6/zl548/MegaDepth_v1' +) +base_undistorted_sfm_path = os.path.join( + base_path, 'Undistorted_SfM' +) + +undistorted_sparse_path = os.path.join( + base_undistorted_sfm_path, scene_id, 'sparse-txt' +) +if not os.path.exists(undistorted_sparse_path): + exit() + +depths_path = os.path.join( + base_depth_path, scene_id, 'dense0', 'depths' +) +if not os.path.exists(depths_path): + exit() + +images_path = os.path.join( + base_undistorted_sfm_path, scene_id, 'images' +) +if not os.path.exists(images_path): + exit() + +# Process cameras.txt +with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f: + raw = f.readlines()[3 :] # skip the header + +camera_intrinsics = {} +for camera in raw: + camera = camera.split(' ') + camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]] + +# Process points3D.txt +with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f: + raw = f.readlines()[3 :] # skip the header + +points3D = {} +for point3D in raw: + point3D = point3D.split(' ') + points3D[int(point3D[0])] = np.array([ + float(point3D[1]), float(point3D[2]), float(point3D[3]) + ]) + +# Process images.txt +with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f: + raw = f.readlines()[4 :] # skip the header + +image_id_to_idx = {} +image_names = [] +raw_pose = [] +camera = [] +points3D_id_to_2D = [] +n_points3D = [] +for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])): + image = image.split(' ') + points = points.split(' ') + + image_id_to_idx[int(image[0])] = idx + + image_name = image[-1].strip('\n') + image_names.append(image_name) + + raw_pose.append([float(elem) for elem in image[1 : -2]]) + camera.append(int(image[-2])) + current_points3D_id_to_2D = {} + for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]): + if int(point3D_id) == -1: + continue + current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)] + points3D_id_to_2D.append(current_points3D_id_to_2D) + n_points3D.append(len(current_points3D_id_to_2D)) +n_images = len(image_names) + +# Image and depthmaps paths +image_paths = [] +depth_paths = [] +for image_name in image_names: + image_path = os.path.join(images_path, image_name) + + # Path to the depth file + depth_path = os.path.join( + depths_path, '%s.h5' % os.path.splitext(image_name)[0] + ) + + if os.path.exists(depth_path): + # Check if depth map or background / foreground mask + file_size = os.stat(depth_path).st_size + # Rough estimate - 75KB might work as well + if file_size < 100 * 1024: + depth_paths.append(None) + image_paths.append(None) + else: + depth_paths.append(depth_path[len(base_path) + 1 :]) + image_paths.append(image_path[len(base_path) + 1 :]) + else: + depth_paths.append(None) + image_paths.append(None) + +# Camera configuration +intrinsics = [] +poses = [] +principal_axis = [] +points3D_id_to_ndepth = [] +for idx, image_name in enumerate(image_names): + if image_paths[idx] is None: + intrinsics.append(None) + poses.append(None) + principal_axis.append([0, 0, 0]) + points3D_id_to_ndepth.append({}) + continue + image_intrinsics = camera_intrinsics[camera[idx]] + K = np.zeros([3, 3]) + K[0, 0] = image_intrinsics[2] + K[0, 2] = image_intrinsics[4] + K[1, 1] = image_intrinsics[3] + K[1, 2] = image_intrinsics[5] + K[2, 2] = 1 + intrinsics.append(K) + + image_pose = raw_pose[idx] + qvec = image_pose[: 4] + qvec = qvec / 
np.linalg.norm(qvec) + w, x, y, z = qvec + R = np.array([ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ] + ]) + principal_axis.append(R[2, :]) + t = image_pose[4 : 7] + # World-to-Camera pose + current_pose = np.zeros([4, 4]) + current_pose[: 3, : 3] = R + current_pose[: 3, 3] = t + current_pose[3, 3] = 1 + # Camera-to-World pose + # pose = np.zeros([4, 4]) + # pose[: 3, : 3] = np.transpose(R) + # pose[: 3, 3] = -np.matmul(np.transpose(R), t) + # pose[3, 3] = 1 + poses.append(current_pose) + + current_points3D_id_to_ndepth = {} + for point3D_id in points3D_id_to_2D[idx].keys(): + p3d = points3D[point3D_id] + current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1])) + points3D_id_to_ndepth.append(current_points3D_id_to_ndepth) +principal_axis = np.array(principal_axis) +angles = np.rad2deg(np.arccos( + np.clip( + np.dot(principal_axis, np.transpose(principal_axis)), + -1, 1 + ) +)) + +# Compute overlap score +overlap_matrix = np.full([n_images, n_images], -1.) +scale_ratio_matrix = np.full([n_images, n_images], -1.) +for idx1 in range(n_images): + if image_paths[idx1] is None or depth_paths[idx1] is None: + continue + for idx2 in range(idx1 + 1, n_images): + if image_paths[idx2] is None or depth_paths[idx2] is None: + continue + matches = ( + points3D_id_to_2D[idx1].keys() & + points3D_id_to_2D[idx2].keys() + ) + min_num_points3D = min( + len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2]) + ) + overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1]) # min_num_points3D + overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2]) # min_num_points3D + if len(matches) == 0: + continue + points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1] + points3D_id_to_ndepth2 = points3D_id_to_ndepth[idx2] + nd1 = np.array([points3D_id_to_ndepth1[match] for match in matches]) + nd2 = np.array([points3D_id_to_ndepth2[match] for match in matches]) + min_scale_ratio = np.min(np.maximum(nd1 / nd2, nd2 / nd1)) + scale_ratio_matrix[idx1, idx2] = min_scale_ratio + scale_ratio_matrix[idx2, idx1] = min_scale_ratio + +np.savez( + os.path.join(args.output_path, '%s.npz' % scene_id), + image_paths=image_paths, + depth_paths=depth_paths, + intrinsics=intrinsics, + poses=poses, + overlap_matrix=overlap_matrix, + scale_ratio_matrix=scale_ratio_matrix, + angles=angles, + n_points3D=n_points3D, + points3D_id_to_2D=points3D_id_to_2D, + points3D_id_to_ndepth=points3D_id_to_ndepth +) \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/tools/reader.py b/imcui/third_party/ASpanFormer/tools/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..f419fbaa8a099fcfede1cea51fcf95a2c1589160 --- /dev/null +++ b/imcui/third_party/ASpanFormer/tools/reader.py @@ -0,0 +1,39 @@ +import argparse +import os, sys + +from SensorData import SensorData + +# params +parser = argparse.ArgumentParser() +# data paths +parser.add_argument('--filename', required=True, help='path to sens file to read') +parser.add_argument('--output_path', required=True, help='path to output folder') +parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true') +parser.add_argument('--export_color_images', dest='export_color_images', action='store_true') +parser.add_argument('--export_poses', 
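For intuition, the overlap score filled into `overlap_matrix` above is simply the fraction of one image's observed 3D points that are also observed in the other image (a toy example with made-up point IDs):

```python
# Toy illustration of the (asymmetric) overlap score computed above.
pts_in_1 = {10, 11, 12, 13}      # 3D point IDs observed in image 1
pts_in_2 = {12, 13, 14}          # 3D point IDs observed in image 2
shared = pts_in_1 & pts_in_2     # {12, 13}

overlap_1_2 = len(shared) / len(pts_in_1)   # 0.50
overlap_2_1 = len(shared) / len(pts_in_2)   # ~0.67 -> overlap_matrix is asymmetric
```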
dest='export_poses', action='store_true') +parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true') +parser.set_defaults(export_depth_images=False, export_color_images=False, export_poses=False, export_intrinsics=False) + +opt = parser.parse_args() +print(opt) + + +def main(): + if not os.path.exists(opt.output_path): + os.makedirs(opt.output_path) + # load the data + sys.stdout.write('loading %s...' % opt.filename) + sd = SensorData(opt.filename) + sys.stdout.write('loaded!\n') + if opt.export_depth_images: + sd.export_depth_images(os.path.join(opt.output_path, 'depth')) + if opt.export_color_images: + sd.export_color_images(os.path.join(opt.output_path, 'color')) + if opt.export_poses: + sd.export_poses(os.path.join(opt.output_path, 'pose')) + if opt.export_intrinsics: + sd.export_intrinsics(os.path.join(opt.output_path, 'intrinsic')) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/tools/undistort_mega.py b/imcui/third_party/ASpanFormer/tools/undistort_mega.py new file mode 100644 index 0000000000000000000000000000000000000000..68798ff30e6afa37a0f98571ecfd3f05751868c8 --- /dev/null +++ b/imcui/third_party/ASpanFormer/tools/undistort_mega.py @@ -0,0 +1,69 @@ +import argparse + +import imagesize + +import os + +import subprocess + +parser = argparse.ArgumentParser(description='MegaDepth Undistortion') + +parser.add_argument( + '--colmap_path', type=str,default='/usr/bin/', + help='path to colmap executable' +) +parser.add_argument( + '--base_path', type=str,default='/root/MegaDepth', + help='path to MegaDepth' +) + +args = parser.parse_args() + +sfm_path = os.path.join( + args.base_path, 'MegaDepth_v1_SfM' +) +base_depth_path = os.path.join( + args.base_path, 'phoenix/S6/zl548/MegaDepth_v1' +) +output_path = os.path.join( + args.base_path, 'Undistorted_SfM' +) + +os.mkdir(output_path) + +for scene_name in os.listdir(base_depth_path): + current_output_path = os.path.join(output_path, scene_name) + os.mkdir(current_output_path) + + image_path = os.path.join( + base_depth_path, scene_name, 'dense0', 'imgs' + ) + if not os.path.exists(image_path): + continue + + # Find the maximum image size in scene. + max_image_size = 0 + for image_name in os.listdir(image_path): + max_image_size = max( + max_image_size, + max(imagesize.get(os.path.join(image_path, image_name))) + ) + + # Undistort the images and update the reconstruction. + subprocess.call([ + os.path.join(args.colmap_path, 'colmap'), 'image_undistorter', + '--image_path', os.path.join(sfm_path, scene_name, 'images'), + '--input_path', os.path.join(sfm_path, scene_name, 'sparse', 'manhattan', '0'), + '--output_path', current_output_path, + '--max_image_size', str(max_image_size) + ]) + + # Transform the reconstruction to raw text format. 
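To run the exporter in `reader.py` above for a single `.sens` file (which is what `extract.py` shells out to per sequence), an invocation along these lines should work; all paths are placeholders:

```python
# One-off invocation of reader.py above (placeholder paths).
import subprocess

subprocess.check_call([
    'python', 'reader.py',
    '--filename', '/path/to/scene0000_00.sens',
    '--output_path', '/path/to/dump/scene0000_00',
    '--export_depth_images', '--export_color_images',
    '--export_poses', '--export_intrinsics',
])
```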
+ sparse_txt_path = os.path.join(current_output_path, 'sparse-txt') + os.mkdir(sparse_txt_path) + subprocess.call([ + os.path.join(args.colmap_path, 'colmap'), 'model_converter', + '--input_path', os.path.join(current_output_path, 'sparse'), + '--output_path', sparse_txt_path, + '--output_type', 'TXT' + ]) \ No newline at end of file diff --git a/imcui/third_party/ASpanFormer/train.py b/imcui/third_party/ASpanFormer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..21f644763711481e84863ed5d861ec57d95f2d5c --- /dev/null +++ b/imcui/third_party/ASpanFormer/train.py @@ -0,0 +1,134 @@ +import math +import argparse +import pprint +from distutils.util import strtobool +from pathlib import Path +from loguru import logger as loguru_logger + +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.plugins import DDPPlugin + +from src.config.default import get_cfg_defaults +from src.utils.misc import get_rank_zero_only_logger, setup_gpus +from src.utils.profiler import build_profiler +from src.lightning.data import MultiSceneDataModule +from src.lightning.lightning_aspanformer import PL_ASpanFormer + +loguru_logger = get_rank_zero_only_logger(loguru_logger) + + +def parse_args(): + def str2bool(v): + return v.lower() in ("true", "1") + # init a costum parser which will be added into pl.Trainer parser + # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'data_cfg_path', type=str, help='data config path') + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--exp_name', type=str, default='default_exp_name') + parser.add_argument( + '--batch_size', type=int, default=4, help='batch_size per gpu') + parser.add_argument( + '--num_workers', type=int, default=4) + parser.add_argument( + '--pin_memory', type=lambda x: bool(strtobool(x)), + nargs='?', default=True, help='whether loading data to pinned memory or not') + parser.add_argument( + '--ckpt_path', type=str, default=None, + help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer') + parser.add_argument( + '--disable_ckpt', action='store_true', + help='disable checkpoint saving (useful for debugging).') + parser.add_argument( + '--profiler_name', type=str, default=None, + help='options: [inference, pytorch], or leave it unset') + parser.add_argument( + '--parallel_load_data', action='store_true', + help='load datasets in with multiple processes.') + parser.add_argument( + '--mode', type=str, default='vanilla', + help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer') + parser.add_argument( + '--ini', type=str2bool, default=False, + help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer') + + parser = pl.Trainer.add_argparse_args(parser) + return parser.parse_args() + + +def main(): + # parse arguments + args = parse_args() + rank_zero_only(pprint.pprint)(vars(args)) + + # init default-cfg and merge it with the main- and data-cfg + config = get_cfg_defaults() + config.merge_from_file(args.main_cfg_path) + config.merge_from_file(args.data_cfg_path) + pl.seed_everything(config.TRAINER.SEED) # reproducibility + # 
TODO: Use different seeds for each dataloader workers + # This is needed for data augmentation + + # scale lr and warmup-step automatically + args.gpus = _n_gpus = setup_gpus(args.gpus) + config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes + config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size + _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS + config.TRAINER.SCALING = _scaling + config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling + config.TRAINER.WARMUP_STEP = math.floor( + config.TRAINER.WARMUP_STEP / _scaling) + + # lightning module + profiler = build_profiler(args.profiler_name) + model = PL_ASpanFormer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler) + loguru_logger.info(f"ASpanFormer LightningModule initialized!") + + # lightning data + data_module = MultiSceneDataModule(args, config) + loguru_logger.info(f"ASpanFormer DataModule initialized!") + + # TensorBoard Logger + logger = TensorBoardLogger( + save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False) + ckpt_dir = Path(logger.log_dir) / 'checkpoints' + + # Callbacks + # TODO: update ModelCheckpoint to monitor multiple metrics + ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max', + save_last=True, + dirpath=str(ckpt_dir), + filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}') + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks = [lr_monitor] + if not args.disable_ckpt: + callbacks.append(ckpt_callback) + + # Lightning Trainer + trainer = pl.Trainer.from_argparse_args( + args, + plugins=DDPPlugin(find_unused_parameters=False, + num_nodes=args.num_nodes, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), + gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING, + callbacks=callbacks, + logger=logger, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, + replace_sampler_ddp=False, # use custom sampler + reload_dataloaders_every_epoch=False, # avoid repeated samples! + weights_summary='full', + profiler=profiler) + loguru_logger.info(f"Trainer initialized!") + loguru_logger.info(f"Start training!") + trainer.fit(model, datamodule=data_module) + + +if __name__ == '__main__': + main() diff --git a/imcui/third_party/COTR/COTR/cameras/camera_pose.py b/imcui/third_party/COTR/COTR/cameras/camera_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..1fd2263cbaace94036c3ef85eb082ca749ba51eb --- /dev/null +++ b/imcui/third_party/COTR/COTR/cameras/camera_pose.py @@ -0,0 +1,164 @@ +''' +Extrinsic camera pose +''' +import math +import copy + +import numpy as np + +from COTR.transformations import transformations +from COTR.transformations.transform_basics import Translation, Rotation, UnstableRotation + + +class CameraPose(): + def __init__(self, t: Translation, r: Rotation): + ''' + WARN: World 2 cam + Translation and rotation are world to camera + translation_vector is not the coordinate of the camera in world space. 
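The batch-size/learning-rate bookkeeping in `main()` above follows the usual linear scaling rule; a numeric sketch with made-up canonical values (the real ones come from `config.TRAINER`):

```python
# Linear LR / warm-up scaling as in main() above (illustrative numbers only).
import math

canonical_bs, canonical_lr, warmup = 64, 8e-3, 4800   # placeholders for config values
gpus, num_nodes, batch_size = 4, 1, 4

world_size = gpus * num_nodes                         # 4
true_bs = world_size * batch_size                     # 16
scaling = true_bs / canonical_bs                      # 0.25
true_lr = canonical_lr * scaling                      # 2e-3
warmup_steps = math.floor(warmup / scaling)           # 19200: smaller batches -> longer warm-up
```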
+ ''' + assert isinstance(t, Translation) + assert isinstance(r, Rotation) or isinstance(r, UnstableRotation) + self.t = t + self.r = r + + def __str__(self): + string = f'center in world: {self.camera_center_in_world}, translation(w2c): {self.t}, rotation(w2c): {self.r}' + return string + + @classmethod + def from_world_to_camera(cls, world_to_camera, unstable=False): + assert isinstance(world_to_camera, np.ndarray) + assert world_to_camera.shape == (4, 4) + vec = transformations.translation_from_matrix(world_to_camera).astype(np.float32) + t = Translation(vec) + if unstable: + r = UnstableRotation(world_to_camera) + else: + quat = transformations.quaternion_from_matrix(world_to_camera).astype(np.float32) + r = Rotation(quat) + return cls(t, r) + + @classmethod + def from_camera_to_world(cls, camera_to_world, unstable=False): + assert isinstance(camera_to_world, np.ndarray) + assert camera_to_world.shape == (4, 4) + world_to_camera = np.linalg.inv(camera_to_world) + world_to_camera /= world_to_camera[3, 3] + return cls.from_world_to_camera(world_to_camera, unstable) + + @classmethod + def from_pose_vector(cls, pose_vector): + t = Translation(pose_vector[:3]) + r = Rotation(pose_vector[3:]) + return cls(t, r) + + @property + def translation_vector(self): + return self.t.translation_vector + + @property + def translation_matrix(self): + return self.t.translation_matrix + + @property + def quaternion(self): + ''' + quaternion format (w, x, y, z) + ''' + return self.r.quaternion + + @property + def rotation_matrix(self): + return self.r.rotation_matrix + + @property + def pose_vector(self): + ''' + Pose vector is a concat of translation vector and quaternion vector + (X, Y, Z, w, x, y, z) + w2c + ''' + return np.concatenate([self.translation_vector, self.quaternion]) + + @property + def inv_pose_vector(self): + inv_quat = transformations.quaternion_inverse(self.quaternion) + return np.concatenate([self.camera_center_in_world, inv_quat]) + + @property + def pose_vector_6_dof(self): + ''' + Here we assuming the quaternion is normalized and we remove the W component + (X, Y, Z, x, y, z) + ''' + return np.concatenate([self.translation_vector, self.quaternion[1:]]) + + @property + def world_to_camera(self): + M = np.matmul(self.translation_matrix, self.rotation_matrix) + M /= M[3, 3] + return M + + @property + def world_to_camera_3x4(self): + M = self.world_to_camera + M = M[0:3, 0:4] + return M + + @property + def extrinsic_mat(self): + return self.world_to_camera_3x4 + + @property + def camera_to_world(self): + M = np.linalg.inv(self.world_to_camera) + M /= M[3, 3] + return M + + @property + def camera_to_world_3x4(self): + M = self.camera_to_world + M = M[0:3, 0:4] + return M + + @property + def camera_center_in_world(self): + return self.camera_to_world[:3, 3] + + @property + def forward(self): + return self.camera_to_world[:3, 2] + + @property + def up(self): + return self.camera_to_world[:3, 1] + + @property + def right(self): + return self.camera_to_world[:3, 0] + + @property + def essential_matrix(self): + E = np.cross(self.rotation_matrix[:3, :3], self.camera_center_in_world) + return E / np.linalg.norm(E) + + +def inverse_camera_pose(cam_pose: CameraPose): + return CameraPose.from_world_to_camera(np.linalg.inv(cam_pose.world_to_camera)) + + +def rotate_camera_pose(cam_pose, rot): + if rot == 0: + return copy.deepcopy(cam_pose) + else: + rot = rot / 180 * np.pi + sin_rot = np.sin(rot) + cos_rot = np.cos(rot) + + rot_mat = np.stack([np.stack([cos_rot, -sin_rot, 0, 0], axis=-1), + 
np.stack([sin_rot, cos_rot, 0, 0], axis=-1), + np.stack([0, 0, 1, 0], axis=-1), + np.stack([0, 0, 0, 1], axis=-1)], axis=1) + new_world2cam = np.matmul(rot_mat, cam_pose.world_to_camera) + return CameraPose.from_world_to_camera(new_world2cam) diff --git a/imcui/third_party/COTR/COTR/cameras/capture.py b/imcui/third_party/COTR/COTR/cameras/capture.py new file mode 100644 index 0000000000000000000000000000000000000000..c09180def5fcb030e5cde0e14ab85b71831642ae --- /dev/null +++ b/imcui/third_party/COTR/COTR/cameras/capture.py @@ -0,0 +1,432 @@ +''' +Capture from a pinhole camera +Separate the captured content and the camera... +''' + +import os +import time +import abc +import copy + +import cv2 +import torch +import numpy as np +import imageio +import PIL +from PIL import Image + +from COTR.cameras.camera_pose import CameraPose, rotate_camera_pose +from COTR.cameras.pinhole_camera import PinholeCamera, rotate_pinhole_camera, crop_pinhole_camera +from COTR.utils import debug_utils, utils, constants +from COTR.utils.utils import Point2D +from COTR.projector import pcd_projector +from COTR.utils.constants import MAX_SIZE +from COTR.utils.utils import CropCamConfig + + +def crop_center_max_xy(p2d, shape): + h, w = shape + crop_x = min(h, w) + crop_y = crop_x + start_x = w // 2 - crop_x // 2 + start_y = h // 2 - crop_y // 2 + mask = (p2d.xy[:, 0] > start_x) & (p2d.xy[:, 0] < start_x + crop_x) & (p2d.xy[:, 1] > start_y) & (p2d.xy[:, 1] < start_y + crop_y) + out_xy = (p2d.xy - [start_x, start_y])[mask] + out = Point2D(p2d.id_3d[mask], out_xy) + return out + + +def crop_center_max(img): + if isinstance(img, torch.Tensor): + return crop_center_max_torch(img) + elif isinstance(img, np.ndarray): + return crop_center_max_np(img) + else: + raise ValueError + + +def crop_center_max_torch(img): + if len(img.shape) == 2: + h, w = img.shape + elif len(img.shape) == 3: + c, h, w = img.shape + elif len(img.shape) == 4: + b, c, h, w = img.shape + else: + raise ValueError + crop_x = min(h, w) + crop_y = crop_x + start_x = w // 2 - crop_x // 2 + start_y = h // 2 - crop_y // 2 + if len(img.shape) == 2: + return img[start_y:start_y + crop_y, start_x:start_x + crop_x] + elif len(img.shape) in [3, 4]: + return img[..., start_y:start_y + crop_y, start_x:start_x + crop_x] + + +def crop_center_max_np(img, return_starts=False): + if len(img.shape) == 2: + h, w = img.shape + elif len(img.shape) == 3: + h, w, c = img.shape + elif len(img.shape) == 4: + b, h, w, c = img.shape + else: + raise ValueError + crop_x = min(h, w) + crop_y = crop_x + start_x = w // 2 - crop_x // 2 + start_y = h // 2 - crop_y // 2 + if len(img.shape) == 2: + canvas = img[start_y:start_y + crop_y, start_x:start_x + crop_x] + elif len(img.shape) == 3: + canvas = img[start_y:start_y + crop_y, start_x:start_x + crop_x, :] + elif len(img.shape) == 4: + canvas = img[:, start_y:start_y + crop_y, start_x:start_x + crop_x, :] + if return_starts: + return canvas, -start_x, -start_y + else: + return canvas + + +def pad_to_square_np(img, till_divisible_by=1, return_starts=False): + if len(img.shape) == 2: + h, w = img.shape + elif len(img.shape) == 3: + h, w, c = img.shape + elif len(img.shape) == 4: + b, h, w, c = img.shape + else: + raise ValueError + if till_divisible_by == 1: + size = max(h, w) + else: + size = (max(h, w) + till_divisible_by) - (max(h, w) % till_divisible_by) + start_x = size // 2 - w // 2 + start_y = size // 2 - h // 2 + if len(img.shape) == 2: + canvas = np.zeros([size, size], dtype=img.dtype) + canvas[start_y:start_y + h, start_x:start_x 
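A small NumPy-only sanity check of the world-to-camera convention documented in `CameraPose` above (no COTR imports; the rotation and translation are arbitrary):

```python
# NumPy-only sanity check of the w2c convention used by CameraPose above.
import numpy as np

R = np.array([[0., -1., 0.],     # world-to-camera rotation (90 deg about z)
              [1.,  0., 0.],
              [0.,  0., 1.]])
t = np.array([1., 2., 3.])       # world-to-camera translation (not the camera centre)

w2c = np.eye(4); w2c[:3, :3] = R; w2c[:3, 3] = t
c2w = np.linalg.inv(w2c)

# camera_center_in_world == c2w[:3, 3] == -R.T @ t
assert np.allclose(c2w[:3, 3], -R.T @ t)     # [-2., 1., -3.]
```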
+ w] = img + elif len(img.shape) == 3: + canvas = np.zeros([size, size, c], dtype=img.dtype) + canvas[start_y:start_y + h, start_x:start_x + w, :] = img + elif len(img.shape) == 4: + canvas = np.zeros([b, size, size, c], dtype=img.dtype) + canvas[:, start_y:start_y + h, start_x:start_x + w, :] = img + if return_starts: + return canvas, start_x, start_y + else: + return canvas + + +def stretch_to_square_np(img): + size = max(*img.shape[:2]) + return np.array(PIL.Image.fromarray(img).resize((size, size), resample=PIL.Image.BILINEAR)) + + +def rotate_image(image, angle, interpolation=cv2.INTER_LINEAR): + image_center = tuple(np.array(image.shape[1::-1]) / 2) + rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) + result = cv2.warpAffine(image, rot_mat, image.shape[1::-1], flags=interpolation) + return result + + +def read_array(path): + ''' + https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py + ''' + with open(path, "rb") as fid: + width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, + usecols=(0, 1, 2), dtype=int) + fid.seek(0) + num_delimiter = 0 + byte = fid.read(1) + while True: + if byte == b"&": + num_delimiter += 1 + if num_delimiter >= 3: + break + byte = fid.read(1) + array = np.fromfile(fid, np.float32) + array = array.reshape((width, height, channels), order="F") + return np.transpose(array, (1, 0, 2)).squeeze() + + +################ Content ################ + + +class CapturedContent(abc.ABC): + def __init__(self): + self._rotation = 0 + + @property + def rotation(self): + return self._rotation + + @rotation.setter + def rotation(self, rot): + self._rotation = rot + + +class CapturedImage(CapturedContent): + def __init__(self, img_path, crop_cam, pinhole_cam_before=None): + super(CapturedImage, self).__init__() + assert os.path.isfile(img_path), 'file does not exist: {0}'.format(img_path) + self.crop_cam = crop_cam + self._image = None + self.img_path = img_path + self.pinhole_cam_before = pinhole_cam_before + self._p2d = None + + def read_image_to_ram(self) -> int: + # raise NotImplementedError + assert self._image is None + _image = self.image + self._image = _image + return self._image.nbytes + + @property + def image(self): + if self._image is not None: + _image = self._image + else: + _image = imageio.imread(self.img_path, pilmode='RGB') + if self.rotation != 0: + _image = rotate_image(_image, self.rotation) + if _image.shape[:2] != self.pinhole_cam_before.shape: + _image = np.array(PIL.Image.fromarray(_image).resize(self.pinhole_cam_before.shape[::-1], resample=PIL.Image.BILINEAR)) + assert _image.shape[:2] == self.pinhole_cam_before.shape + if self.crop_cam == 'no_crop': + pass + elif self.crop_cam == 'crop_center': + _image = crop_center_max(_image) + elif self.crop_cam == 'crop_center_and_resize': + _image = crop_center_max(_image) + _image = np.array(PIL.Image.fromarray(_image).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR)) + elif isinstance(self.crop_cam, CropCamConfig): + assert _image.shape[0] == self.crop_cam.orig_h + assert _image.shape[1] == self.crop_cam.orig_w + _image = _image[self.crop_cam.y:self.crop_cam.y + self.crop_cam.h, + self.crop_cam.x:self.crop_cam.x + self.crop_cam.w, ] + _image = np.array(PIL.Image.fromarray(_image).resize((self.crop_cam.out_w, self.crop_cam.out_h), resample=PIL.Image.BILINEAR)) + assert _image.shape[:2] == (self.crop_cam.out_h, self.crop_cam.out_w) + else: + raise ValueError() + return _image + + @property + def p2d(self): + if self._p2d is None: + return self._p2d + else: + 
_p2d = self._p2d + if self.crop_cam == 'no_crop': + pass + elif self.crop_cam == 'crop_center': + _p2d = crop_center_max_xy(_p2d, self.pinhole_cam_before.shape) + else: + raise ValueError() + return _p2d + + @p2d.setter + def p2d(self, value): + if value is not None: + assert isinstance(value, Point2D) + self._p2d = value + + +class CapturedDepth(CapturedContent): + def __init__(self, depth_path, crop_cam, pinhole_cam_before=None): + super(CapturedDepth, self).__init__() + if not depth_path.endswith('dummy'): + assert os.path.isfile(depth_path), 'file does not exist: {0}'.format(depth_path) + self.crop_cam = crop_cam + self._depth = None + self.depth_path = depth_path + self.pinhole_cam_before = pinhole_cam_before + + def read_depth(self): + import tables + if self.depth_path.endswith('dummy'): + image_path = self.depth_path[:-5] + w, h = Image.open(image_path).size + _depth = np.zeros([h, w], dtype=np.float32) + elif self.depth_path.endswith('.h5'): + depth_h5 = tables.open_file(self.depth_path, mode='r') + _depth = np.array(depth_h5.root.depth) + depth_h5.close() + else: + raise ValueError + return _depth.astype(np.float32) + + def read_depth_to_ram(self) -> int: + # raise NotImplementedError + assert self._depth is None + _depth = self.depth_map + self._depth = _depth + return self._depth.nbytes + + @property + def depth_map(self): + if self._depth is not None: + _depth = self._depth + else: + _depth = self.read_depth() + if self.rotation != 0: + _depth = rotate_image(_depth, self.rotation, interpolation=cv2.INTER_NEAREST) + if _depth.shape != self.pinhole_cam_before.shape: + _depth = np.array(PIL.Image.fromarray(_depth).resize(self.pinhole_cam_before.shape[::-1], resample=PIL.Image.NEAREST)) + assert _depth.shape[:2] == self.pinhole_cam_before.shape + if self.crop_cam == 'no_crop': + pass + elif self.crop_cam == 'crop_center': + _depth = crop_center_max(_depth) + elif self.crop_cam == 'crop_center_and_resize': + _depth = crop_center_max(_depth) + _depth = np.array(PIL.Image.fromarray(_depth).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.NEAREST)) + elif isinstance(self.crop_cam, CropCamConfig): + assert _depth.shape[0] == self.crop_cam.orig_h + assert _depth.shape[1] == self.crop_cam.orig_w + _depth = _depth[self.crop_cam.y:self.crop_cam.y + self.crop_cam.h, + self.crop_cam.x:self.crop_cam.x + self.crop_cam.w, ] + _depth = np.array(PIL.Image.fromarray(_depth).resize((self.crop_cam.out_w, self.crop_cam.out_h), resample=PIL.Image.NEAREST)) + assert _depth.shape[:2] == (self.crop_cam.out_h, self.crop_cam.out_w) + else: + raise ValueError() + assert (_depth >= 0).all() + return _depth + + +################ Pinhole Capture ################ +class BasePinholeCapture(): + def __init__(self, pinhole_cam, cam_pose, crop_cam): + self.crop_cam = crop_cam + self.cam_pose = cam_pose + # modify the camera instrinsics + self.pinhole_cam = crop_pinhole_camera(pinhole_cam, crop_cam) + self.pinhole_cam_before = pinhole_cam + + def __str__(self): + string = 'pinhole camera: {0}\ncamera pose: {1}'.format(self.pinhole_cam, self.cam_pose) + return string + + @property + def intrinsic_mat(self): + return self.pinhole_cam.intrinsic_mat + + @property + def extrinsic_mat(self): + return self.cam_pose.extrinsic_mat + + @property + def shape(self): + return self.pinhole_cam.shape + + @property + def size(self): + return self.shape + + @property + def mvp_mat(self): + ''' + model-view-projection matrix (naming from opengl) + ''' + return np.matmul(self.pinhole_cam.intrinsic_mat, 
self.cam_pose.world_to_camera_3x4) + + +class RGBPinholeCapture(BasePinholeCapture): + def __init__(self, img_path, pinhole_cam, cam_pose, crop_cam): + BasePinholeCapture.__init__(self, pinhole_cam, cam_pose, crop_cam) + self.captured_image = CapturedImage(img_path, crop_cam, self.pinhole_cam_before) + + def read_image_to_ram(self) -> int: + return self.captured_image.read_image_to_ram() + + @property + def img_path(self): + return self.captured_image.img_path + + @property + def image(self): + _image = self.captured_image.image + assert _image.shape[0:2] == self.pinhole_cam.shape, 'image shape: {0}, pinhole camera: {1}'.format(_image.shape, self.pinhole_cam) + return _image + + @property + def seq_id(self): + return os.path.dirname(self.captured_image.img_path) + + @property + def p2d(self): + return self.captured_image.p2d + + @p2d.setter + def p2d(self, value): + self.captured_image.p2d = value + + +class DepthPinholeCapture(BasePinholeCapture): + def __init__(self, depth_path, pinhole_cam, cam_pose, crop_cam): + BasePinholeCapture.__init__(self, pinhole_cam, cam_pose, crop_cam) + self.captured_depth = CapturedDepth(depth_path, crop_cam, self.pinhole_cam_before) + + def read_depth_to_ram(self) -> int: + return self.captured_depth.read_depth_to_ram() + + @property + def depth_path(self): + return self.captured_depth.depth_path + + @property + def depth_map(self): + _depth = self.captured_depth.depth_map + # if self.pinhole_cam.shape != _depth.shape: + # _depth = misc.imresize(_depth, self.pinhole_cam.shape, interp='nearest', mode='F') + assert (_depth >= 0).all() + return _depth + + @property + def point_cloud_world(self): + return self.get_point_cloud_world_from_depth(feat_map=None) + + def get_point_cloud_world_from_depth(self, feat_map=None): + _pcd = pcd_projector.PointCloudProjector.img_2d_to_pcd_3d_np(self.depth_map, self.pinhole_cam.intrinsic_mat, img=feat_map, motion=self.cam_pose.camera_to_world).astype(constants.DEFAULT_PRECISION) + return _pcd + + +class RGBDPinholeCapture(RGBPinholeCapture, DepthPinholeCapture): + def __init__(self, img_path, depth_path, pinhole_cam, cam_pose, crop_cam): + RGBPinholeCapture.__init__(self, img_path, pinhole_cam, cam_pose, crop_cam) + DepthPinholeCapture.__init__(self, depth_path, pinhole_cam, cam_pose, crop_cam) + + @property + def point_cloud_w_rgb_world(self): + return self.get_point_cloud_world_from_depth(feat_map=self.image) + + +def rotate_capture(cap, rot): + if rot == 0: + return copy.deepcopy(cap) + else: + rot_pose = rotate_camera_pose(cap.cam_pose, rot) + rot_cap = copy.deepcopy(cap) + rot_cap.cam_pose = rot_pose + if hasattr(rot_cap, 'captured_image'): + rot_cap.captured_image.rotation = rot + if hasattr(rot_cap, 'captured_depth'): + rot_cap.captured_depth.rotation = rot + return rot_cap + + +def crop_capture(cap, crop_cam): + if isinstance(cap, RGBDPinholeCapture): + cropped_cap = RGBDPinholeCapture(cap.img_path, cap.depth_path, cap.pinhole_cam, cap.cam_pose, crop_cam) + elif isinstance(cap, RGBPinholeCapture): + cropped_cap = RGBPinholeCapture(cap.img_path, cap.pinhole_cam, cap.cam_pose, crop_cam) + else: + raise ValueError + if hasattr(cropped_cap, 'captured_image'): + cropped_cap.captured_image.rotation = cap.captured_image.rotation + if hasattr(cropped_cap, 'captured_depth'): + cropped_cap.captured_depth.rotation = cap.captured_depth.rotation + return cropped_cap diff --git a/imcui/third_party/COTR/COTR/cameras/pinhole_camera.py b/imcui/third_party/COTR/COTR/cameras/pinhole_camera.py new file mode 100644 index 
0000000000000000000000000000000000000000..2f06cc167ba4a56092c9bdc7e15dc77f245f5d79 --- /dev/null +++ b/imcui/third_party/COTR/COTR/cameras/pinhole_camera.py @@ -0,0 +1,73 @@ +""" +Static pinhole camera +""" + +import copy + +import numpy as np + +from COTR.utils import constants +from COTR.utils.constants import MAX_SIZE +from COTR.utils.utils import CropCamConfig + + +class PinholeCamera(): + def __init__(self, width, height, fx, fy, cx, cy): + self.width = int(width) + self.height = int(height) + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + + def __str__(self): + string = 'width: {0}, height: {1}, fx: {2}, fy: {3}, cx: {4}, cy: {5}'.format(self.width, self.height, self.fx, self.fy, self.cx, self.cy) + return string + + @property + def shape(self): + return (self.height, self.width) + + @property + def intrinsic_mat(self): + mat = np.array([[self.fx, 0.0, self.cx], + [0.0, self.fy, self.cy], + [0.0, 0.0, 1.0]], dtype=constants.DEFAULT_PRECISION) + return mat + + +def rotate_pinhole_camera(cam, rot): + assert 0, 'TODO: Camera should stay the same while rotation' + assert rot in [0, 90, 180, 270], 'only support 0/90/180/270 degrees rotation' + if rot in [0, 180]: + return copy.deepcopy(cam) + elif rot in [90, 270]: + return PinholeCamera(width=cam.height, height=cam.width, fx=cam.fy, fy=cam.fx, cx=cam.cy, cy=cam.cx) + else: + raise NotImplementedError + + +def crop_pinhole_camera(pinhole_cam, crop_cam): + if crop_cam == 'no_crop': + cropped_pinhole_cam = pinhole_cam + elif crop_cam == 'crop_center': + _h = _w = min(*pinhole_cam.shape) + _cx = _cy = _h / 2 + cropped_pinhole_cam = PinholeCamera(_w, _h, pinhole_cam.fx, pinhole_cam.fy, _cx, _cy) + elif crop_cam == 'crop_center_and_resize': + _h = _w = MAX_SIZE + _cx = _cy = MAX_SIZE / 2 + scale = MAX_SIZE / min(*pinhole_cam.shape) + cropped_pinhole_cam = PinholeCamera(_w, _h, pinhole_cam.fx * scale, pinhole_cam.fy * scale, _cx, _cy) + elif isinstance(crop_cam, CropCamConfig): + scale = crop_cam.out_h / crop_cam.h + cropped_pinhole_cam = PinholeCamera(crop_cam.out_w, + crop_cam.out_h, + pinhole_cam.fx * scale, + pinhole_cam.fy * scale, + (pinhole_cam.cx - crop_cam.x) * scale, + (pinhole_cam.cy - crop_cam.y) * scale + ) + else: + raise ValueError + return cropped_pinhole_cam diff --git a/imcui/third_party/COTR/COTR/datasets/colmap_helper.py b/imcui/third_party/COTR/COTR/datasets/colmap_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..d84ed8645ba156fce7069361609d8cd4f577cfe4 --- /dev/null +++ b/imcui/third_party/COTR/COTR/datasets/colmap_helper.py @@ -0,0 +1,312 @@ +import sys +assert sys.version_info >= (3, 7), 'ordered dict is required' +import os +import re +from collections import namedtuple +import json + +import numpy as np +from tqdm import tqdm + +from COTR.utils import debug_utils +from COTR.cameras.pinhole_camera import PinholeCamera +from COTR.cameras.camera_pose import CameraPose +from COTR.cameras.capture import RGBPinholeCapture, RGBDPinholeCapture +from COTR.cameras import capture +from COTR.transformations import transformations +from COTR.transformations.transform_basics import Translation, Rotation +from COTR.sfm_scenes import sfm_scenes +from COTR.global_configs import dataset_config +from COTR.utils.utils import Point2D, Point3D + +ImageMeta = namedtuple('ImageMeta', ['image_id', 'r', 't', 'camera_id', 'image_path', 'point3d_id', 'p2d']) +COVISIBILITY_CHECK = False +LOAD_PCD = False + + +class ColmapAsciiReader(): + def __init__(self): + pass + + @classmethod + def 
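For reference, projecting a camera-space point with an intrinsic matrix like the one assembled by `PinholeCamera.intrinsic_mat` above (toy focal length and principal point):

```python
# Pinhole projection with a K matrix like the one built above (toy numbers).
import numpy as np

K = np.array([[500.,   0., 320.],
              [  0., 500., 240.],
              [  0.,   0.,   1.]])
X_cam = np.array([0.2, -0.1, 2.0])          # point in camera coordinates
uv = (K @ X_cam)[:2] / X_cam[2]             # -> [370., 215.] pixel coordinates
```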
read_sfm_scene(cls, scene_dir, images_dir, crop_cam): + point_cloud_path = os.path.join(scene_dir, 'points3D.txt') + cameras_path = os.path.join(scene_dir, 'cameras.txt') + images_path = os.path.join(scene_dir, 'images.txt') + captures = cls.read_captures(images_path, cameras_path, images_dir, crop_cam) + if LOAD_PCD: + point_cloud = cls.read_point_cloud(point_cloud_path) + else: + point_cloud = None + sfm_scene = sfm_scenes.SfmScene(captures, point_cloud) + return sfm_scene + + @staticmethod + def read_point_cloud(points_txt_path): + with open(points_txt_path, "r") as fid: + line = fid.readline() + assert line == '# 3D point list with one line of data per point:\n' + line = fid.readline() + assert line == '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n' + line = fid.readline() + assert re.search('^# Number of points: \d+, mean track length: [-+]?\d*\.\d+|\d+\n$', line) + num_points, mean_track_length = re.findall(r"[-+]?\d*\.\d+|\d+", line) + num_points = int(num_points) + mean_track_length = float(mean_track_length) + + xyz = np.zeros((num_points, 3), dtype=np.float32) + rgb = np.zeros((num_points, 3), dtype=np.float32) + if COVISIBILITY_CHECK: + point_meta = {} + + for i in tqdm(range(num_points), desc='reading point cloud'): + elems = fid.readline().split() + xyz[i] = list(map(float, elems[1:4])) + rgb[i] = list(map(int, elems[4:7])) + if COVISIBILITY_CHECK: + point_id = int(elems[0]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point_meta[point_id] = Point3D(id=point_id, + arr_idx=i, + image_ids=image_ids) + pcd = np.concatenate([xyz, rgb], axis=1) + if COVISIBILITY_CHECK: + return pcd, point_meta + else: + return pcd + + @classmethod + def read_captures(cls, images_txt_path, cameras_txt_path, images_dir, crop_cam): + captures = [] + cameras = cls.read_cameras(cameras_txt_path) + images_meta = cls.read_images_meta(images_txt_path, images_dir) + for key in images_meta.keys(): + cur_cam_id = images_meta[key].camera_id + cur_cam = cameras[cur_cam_id] + cur_camera_pose = CameraPose(images_meta[key].t, images_meta[key].r) + cur_image_path = images_meta[key].image_path + cap = RGBPinholeCapture(cur_image_path, cur_cam, cur_camera_pose, crop_cam) + captures.append(cap) + return captures + + @classmethod + def read_cameras(cls, cameras_txt_path): + cameras = {} + with open(cameras_txt_path, "r") as fid: + line = fid.readline() + assert line == '# Camera list with one line of data per camera:\n' + line = fid.readline() + assert line == '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n' + line = fid.readline() + assert re.search('^# Number of cameras: \d+\n$', line) + num_cams = int(re.findall(r"[-+]?\d*\.\d+|\d+", line)[0]) + + for _ in tqdm(range(num_cams), desc='reading cameras'): + elems = fid.readline().split() + camera_id = int(elems[0]) + camera_type = elems[1] + if camera_type == "PINHOLE": + width, height, focal_length_x, focal_length_y, cx, cy = list(map(float, elems[2:8])) + else: + raise ValueError('Please rectify the 3D model to pinhole camera.') + cur_cam = PinholeCamera(width, height, focal_length_x, focal_length_y, cx, cy) + assert camera_id not in cameras + cameras[camera_id] = cur_cam + return cameras + + @classmethod + def read_images_meta(cls, images_txt_path, images_dir): + images_meta = {} + with open(images_txt_path, "r") as fid: + line = fid.readline() + assert line == '# Image list with two lines of data per image:\n' + line = fid.readline() + assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n' + line = 
fid.readline() + assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n' + line = fid.readline() + assert re.search('^# Number of images: \d+, mean observations per image: [-+]?\d*\.\d+|\d+\n$', line) + num_images, mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line) + num_images = int(num_images) + mean_ob_per_img = float(mean_ob_per_img) + + for _ in tqdm(range(num_images), desc='reading images meta'): + elems = fid.readline().split() + assert len(elems) == 10 + + image_path = os.path.join(images_dir, elems[9]) + assert os.path.isfile(image_path) + image_id = int(elems[0]) + qw, qx, qy, qz, tx, ty, tz = list(map(float, elems[1:8])) + t = Translation(np.array([tx, ty, tz], dtype=np.float32)) + r = Rotation(np.array([qw, qx, qy, qz], dtype=np.float32)) + camera_id = int(elems[8]) + assert image_id not in images_meta + + line = fid.readline() + if COVISIBILITY_CHECK: + elems = line.split() + elems = list(map(float, elems)) + elems = np.array(elems).reshape(-1, 3) + point3d_id = set(elems[elems[:, 2] != -1][:, 2].astype(np.int)) + point3d_id = np.sort(np.array(list(point3d_id))) + xyi = elems[elems[:, 2] != -1] + xy = xyi[:, :2] + idx = xyi[:, 2].astype(np.int) + p2d = Point2D(idx, xy) + else: + point3d_id = None + p2d = None + + images_meta[image_id] = ImageMeta(image_id, r, t, camera_id, image_path, point3d_id, p2d) + return images_meta + + +class ColmapWithDepthAsciiReader(ColmapAsciiReader): + ''' + Not all images have usable depth estimate from colmap. + A valid list is needed. + ''' + + @classmethod + def read_sfm_scene(cls, scene_dir, images_dir, depth_dir, crop_cam): + point_cloud_path = os.path.join(scene_dir, 'points3D.txt') + cameras_path = os.path.join(scene_dir, 'cameras.txt') + images_path = os.path.join(scene_dir, 'images.txt') + captures = cls.read_captures(images_path, cameras_path, images_dir, depth_dir, crop_cam) + if LOAD_PCD: + point_cloud = cls.read_point_cloud(point_cloud_path) + else: + point_cloud = None + sfm_scene = sfm_scenes.SfmScene(captures, point_cloud) + return sfm_scene + + @classmethod + def read_sfm_scene_given_valid_list_path(cls, scene_dir, images_dir, depth_dir, valid_list_json_path, crop_cam): + point_cloud_path = os.path.join(scene_dir, 'points3D.txt') + cameras_path = os.path.join(scene_dir, 'cameras.txt') + images_path = os.path.join(scene_dir, 'images.txt') + valid_list = cls.read_valid_list(valid_list_json_path) + captures = cls.read_captures_with_depth_given_valid_list(images_path, cameras_path, images_dir, depth_dir, valid_list, crop_cam) + if LOAD_PCD: + point_cloud = cls.read_point_cloud(point_cloud_path) + else: + point_cloud = None + sfm_scene = sfm_scenes.SfmScene(captures, point_cloud) + return sfm_scene + + @classmethod + def read_captures(cls, images_txt_path, cameras_txt_path, images_dir, depth_dir, crop_cam): + captures = [] + cameras = cls.read_cameras(cameras_txt_path) + images_meta = cls.read_images_meta(images_txt_path, images_dir) + for key in images_meta.keys(): + cur_cam_id = images_meta[key].camera_id + cur_cam = cameras[cur_cam_id] + cur_camera_pose = CameraPose(images_meta[key].t, images_meta[key].r) + cur_image_path = images_meta[key].image_path + try: + cur_depth_path = cls.image_path_2_depth_path(cur_image_path[len(images_dir) + 1:], depth_dir) + except: + print('{0} does not have depth at {1}'.format(cur_image_path, depth_dir)) + # TODO + # continue + # exec(debug_utils.embed_breakpoint()) + cur_depth_path = f'{cur_image_path}dummy' + + cap = RGBDPinholeCapture(cur_image_path, cur_depth_path, cur_cam, cur_camera_pose, 
crop_cam) + cap.point3d_id = images_meta[key].point3d_id + cap.p2d = images_meta[key].p2d + cap.image_id = key + captures.append(cap) + return captures + + @classmethod + def read_captures_with_depth_given_valid_list(cls, images_txt_path, cameras_txt_path, images_dir, depth_dir, valid_list, crop_cam): + captures = [] + cameras = cls.read_cameras(cameras_txt_path) + images_meta = cls.read_images_meta_given_valid_list(images_txt_path, images_dir, valid_list) + for key in images_meta.keys(): + cur_cam_id = images_meta[key].camera_id + cur_cam = cameras[cur_cam_id] + cur_camera_pose = CameraPose(images_meta[key].t, images_meta[key].r) + cur_image_path = images_meta[key].image_path + try: + cur_depth_path = cls.image_path_2_depth_path(cur_image_path, depth_dir) + except: + print('{0} does not have depth at {1}'.format(cur_image_path, depth_dir)) + continue + cap = RGBDPinholeCapture(cur_image_path, cur_depth_path, cur_cam, cur_camera_pose, crop_cam) + cap.point3d_id = images_meta[key].point3d_id + cap.p2d = images_meta[key].p2d + cap.image_id = key + captures.append(cap) + return captures + + @classmethod + def read_images_meta_given_valid_list(cls, images_txt_path, images_dir, valid_list): + images_meta = {} + with open(images_txt_path, "r") as fid: + line = fid.readline() + assert line == '# Image list with two lines of data per image:\n' + line = fid.readline() + assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n' + line = fid.readline() + assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n' + line = fid.readline() + assert re.search('^# Number of images: \d+, mean observations per image:[-+]?\d*\.\d+|\d+\n$', line), line + num_images, mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line) + num_images = int(num_images) + mean_ob_per_img = float(mean_ob_per_img) + + for _ in tqdm(range(num_images), desc='reading images meta'): + elems = fid.readline().split() + assert len(elems) == 10 + line = fid.readline() + image_path = os.path.join(images_dir, elems[9]) + prefix = os.path.abspath(os.path.join(image_path, '../../../../')) + '/' + rel_image_path = image_path.replace(prefix, '') + if rel_image_path not in valid_list: + continue + assert os.path.isfile(image_path), '{0} is not existing'.format(image_path) + image_id = int(elems[0]) + qw, qx, qy, qz, tx, ty, tz = list(map(float, elems[1:8])) + t = Translation(np.array([tx, ty, tz], dtype=np.float32)) + r = Rotation(np.array([qw, qx, qy, qz], dtype=np.float32)) + camera_id = int(elems[8]) + assert image_id not in images_meta + + if COVISIBILITY_CHECK: + elems = line.split() + elems = list(map(float, elems)) + elems = np.array(elems).reshape(-1, 3) + point3d_id = set(elems[elems[:, 2] != -1][:, 2].astype(np.int)) + point3d_id = np.sort(np.array(list(point3d_id))) + xyi = elems[elems[:, 2] != -1] + xy = xyi[:, :2] + idx = xyi[:, 2].astype(np.int) + p2d = Point2D(idx, xy) + else: + point3d_id = None + p2d = None + images_meta[image_id] = ImageMeta(image_id, r, t, camera_id, image_path, point3d_id, p2d) + return images_meta + + @classmethod + def read_valid_list(cls, valid_list_json_path): + assert os.path.isfile(valid_list_json_path), valid_list_json_path + with open(valid_list_json_path, 'r') as f: + valid_list = json.load(f) + assert len(valid_list) == len(set(valid_list)) + return set(valid_list) + + @classmethod + def image_path_2_depth_path(cls, image_path, depth_dir): + depth_file = os.path.splitext(os.path.basename(image_path))[0] + '.h5' + depth_path = os.path.join(depth_dir, depth_file) + if not 
os.path.isfile(depth_path): + # depth_file = image_path + '.photometric.bin' + depth_file = image_path + '.geometric.bin' + depth_path = os.path.join(depth_dir, depth_file) + assert os.path.isfile(depth_path), '{0} is not file'.format(depth_path) + return depth_path diff --git a/imcui/third_party/COTR/COTR/datasets/cotr_dataset.py b/imcui/third_party/COTR/COTR/datasets/cotr_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f99cfa54fc30fcad0ed91674912cbc22f4cb4564 --- /dev/null +++ b/imcui/third_party/COTR/COTR/datasets/cotr_dataset.py @@ -0,0 +1,243 @@ +''' +COTR dataset +''' + +import random + +import numpy as np +import torch +from torchvision.transforms import functional as tvtf +from torch.utils import data + +from COTR.datasets import megadepth_dataset +from COTR.utils import debug_utils, utils, constants +from COTR.projector import pcd_projector +from COTR.cameras import capture +from COTR.utils.utils import CropCamConfig +from COTR.inference import inference_helper +from COTR.inference.inference_helper import two_images_side_by_side + + +class COTRDataset(data.Dataset): + def __init__(self, opt, dataset_type: str): + assert dataset_type in ['train', 'val', 'test'] + assert len(opt.scenes_name_list) > 0 + self.opt = opt + self.dataset_type = dataset_type + self.sfm_dataset = megadepth_dataset.MegadepthDataset(opt, dataset_type) + + self.kp_pool = opt.kp_pool + self.num_kp = opt.num_kp + self.bidirectional = opt.bidirectional + self.need_rotation = opt.need_rotation + self.max_rotation = opt.max_rotation + self.rotation_chance = opt.rotation_chance + + def _trim_corrs(self, in_corrs): + length = in_corrs.shape[0] + if length >= self.num_kp: + mask = np.random.choice(length, self.num_kp) + return in_corrs[mask] + else: + mask = np.random.choice(length, self.num_kp - length) + return np.concatenate([in_corrs, in_corrs[mask]], axis=0) + + def __len__(self): + if self.dataset_type == 'val': + return min(1000, self.sfm_dataset.num_queries) + else: + return self.sfm_dataset.num_queries + + def augment_with_rotation(self, query_cap, nn_cap): + if random.random() < self.rotation_chance: + theta = np.random.uniform(low=-1, high=1) * self.max_rotation + query_cap = capture.rotate_capture(query_cap, theta) + if random.random() < self.rotation_chance: + theta = np.random.uniform(low=-1, high=1) * self.max_rotation + nn_cap = capture.rotate_capture(nn_cap, theta) + return query_cap, nn_cap + + def __getitem__(self, index): + assert self.opt.k_size == 1 + query_cap, nn_caps = self.sfm_dataset.get_query_with_knn(index) + nn_cap = nn_caps[0] + + if self.need_rotation: + query_cap, nn_cap = self.augment_with_rotation(query_cap, nn_cap) + + nn_keypoints_y, nn_keypoints_x = np.where(nn_cap.depth_map > 0) + nn_keypoints_y = nn_keypoints_y[..., None] + nn_keypoints_x = nn_keypoints_x[..., None] + nn_keypoints_z = nn_cap.depth_map[np.floor(nn_keypoints_y).astype('int'), np.floor(nn_keypoints_x).astype('int')] + nn_keypoints_xy = np.concatenate([nn_keypoints_x, nn_keypoints_y], axis=1) + nn_keypoints_3d_world, valid_index_1 = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(nn_keypoints_xy, nn_keypoints_z, nn_cap.pinhole_cam.intrinsic_mat, motion=nn_cap.cam_pose.camera_to_world, return_index=True) + + query_keypoints_xyz, valid_index_2 = pcd_projector.PointCloudProjector.pcd_3d_to_pcd_2d_np( + nn_keypoints_3d_world, + query_cap.pinhole_cam.intrinsic_mat, + query_cap.cam_pose.world_to_camera[0:3, :], + query_cap.image.shape[:2], + keep_z=True, + crop=True, + 
filter_neg=True, + norm_coord=False, + return_index=True, + ) + query_keypoints_xy = query_keypoints_xyz[:, 0:2] + query_keypoints_z_proj = query_keypoints_xyz[:, 2:3] + query_keypoints_z = query_cap.depth_map[np.floor(query_keypoints_xy[:, 1:2]).astype('int'), np.floor(query_keypoints_xy[:, 0:1]).astype('int')] + mask = (abs(query_keypoints_z - query_keypoints_z_proj) < 0.5)[:, 0] + query_keypoints_xy = query_keypoints_xy[mask] + + if query_keypoints_xy.shape[0] < self.num_kp: + return self.__getitem__(random.randint(0, self.__len__() - 1)) + + nn_keypoints_xy = nn_keypoints_xy[valid_index_1][valid_index_2][mask] + assert nn_keypoints_xy.shape == query_keypoints_xy.shape + corrs = np.concatenate([query_keypoints_xy, nn_keypoints_xy], axis=1) + corrs = self._trim_corrs(corrs) + # flip augmentation + if np.random.uniform() < 0.5: + corrs[:, 0] = constants.MAX_SIZE - 1 - corrs[:, 0] + corrs[:, 2] = constants.MAX_SIZE - 1 - corrs[:, 2] + sbs_img = two_images_side_by_side(np.fliplr(query_cap.image), np.fliplr(nn_cap.image)) + else: + sbs_img = two_images_side_by_side(query_cap.image, nn_cap.image) + corrs[:, 2] += constants.MAX_SIZE + corrs /= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE]) + assert (0.0 <= corrs[:, 0]).all() and (corrs[:, 0] <= 0.5).all() + assert (0.0 <= corrs[:, 1]).all() and (corrs[:, 1] <= 1.0).all() + assert (0.5 <= corrs[:, 2]).all() and (corrs[:, 2] <= 1.0).all() + assert (0.0 <= corrs[:, 3]).all() and (corrs[:, 3] <= 1.0).all() + out = { + 'image': tvtf.normalize(tvtf.to_tensor(sbs_img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + 'corrs': torch.from_numpy(corrs).float(), + } + if self.bidirectional: + out['queries'] = torch.from_numpy(np.concatenate([corrs[:, :2], corrs[:, 2:]], axis=0)).float() + out['targets'] = torch.from_numpy(np.concatenate([corrs[:, 2:], corrs[:, :2]], axis=0)).float() + else: + out['queries'] = torch.from_numpy(corrs[:, :2]).float() + out['targets'] = torch.from_numpy(corrs[:, 2:]).float() + return out + + +class COTRZoomDataset(COTRDataset): + def __init__(self, opt, dataset_type: str): + assert opt.crop_cam in ['no_crop', 'crop_center'] + assert opt.use_ram == False + super().__init__(opt, dataset_type) + self.zoom_start = opt.zoom_start + self.zoom_end = opt.zoom_end + self.zoom_levels = opt.zoom_levels + self.zoom_jitter = opt.zoom_jitter + self.zooms = np.logspace(np.log10(opt.zoom_start), + np.log10(opt.zoom_end), + num=opt.zoom_levels) + + def get_corrs(self, from_cap, to_cap, reduced_size=None): + from_y, from_x = np.where(from_cap.depth_map > 0) + from_y, from_x = from_y[..., None], from_x[..., None] + if reduced_size is not None: + filter_idx = np.random.choice(from_y.shape[0], reduced_size, replace=False) + from_y, from_x = from_y[filter_idx], from_x[filter_idx] + from_z = from_cap.depth_map[np.floor(from_y).astype('int'), np.floor(from_x).astype('int')] + from_xy = np.concatenate([from_x, from_y], axis=1) + from_3d_world, valid_index_1 = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(from_xy, from_z, from_cap.pinhole_cam.intrinsic_mat, motion=from_cap.cam_pose.camera_to_world, return_index=True) + + to_xyz, valid_index_2 = pcd_projector.PointCloudProjector.pcd_3d_to_pcd_2d_np( + from_3d_world, + to_cap.pinhole_cam.intrinsic_mat, + to_cap.cam_pose.world_to_camera[0:3, :], + to_cap.image.shape[:2], + keep_z=True, + crop=True, + filter_neg=True, + norm_coord=False, + return_index=True, + ) + + to_xy = to_xyz[:, 0:2] + to_z_proj = to_xyz[:, 2:3] + to_z = 
to_cap.depth_map[np.floor(to_xy[:, 1:2]).astype('int'), np.floor(to_xy[:, 0:1]).astype('int')] + mask = (abs(to_z - to_z_proj) < 0.5)[:, 0] + if mask.sum() > 0: + return np.concatenate([from_xy[valid_index_1][valid_index_2][mask], to_xy[mask]], axis=1) + else: + return None + + def get_seed_corr(self, from_cap, to_cap, max_try=100): + seed_corr = self.get_corrs(from_cap, to_cap, reduced_size=max_try) + if seed_corr is None: + return None + shuffle = np.random.permutation(seed_corr.shape[0]) + seed_corr = np.take(seed_corr, shuffle, axis=0) + return seed_corr[0] + + def get_zoomed_cap(self, cap, pos, scale, jitter): + patch = inference_helper.get_patch_centered_at(cap.image, pos, scale=scale, return_content=False) + patch = inference_helper.get_patch_centered_at(cap.image, + pos + np.array([patch.w, patch.h]) * np.random.uniform(-jitter, jitter, 2), + scale=scale, + return_content=False) + zoom_config = CropCamConfig(x=patch.x, + y=patch.y, + w=patch.w, + h=patch.h, + out_w=constants.MAX_SIZE, + out_h=constants.MAX_SIZE, + orig_w=cap.shape[1], + orig_h=cap.shape[0]) + zoom_cap = capture.crop_capture(cap, zoom_config) + return zoom_cap + + def __getitem__(self, index): + assert self.opt.k_size == 1 + query_cap, nn_caps = self.sfm_dataset.get_query_with_knn(index) + nn_cap = nn_caps[0] + if self.need_rotation: + query_cap, nn_cap = self.augment_with_rotation(query_cap, nn_cap) + + # find seed + seed_corr = self.get_seed_corr(nn_cap, query_cap) + if seed_corr is None: + return self.__getitem__(random.randint(0, self.__len__() - 1)) + + # crop cap + s = np.random.choice(self.zooms) + nn_zoom_cap = self.get_zoomed_cap(nn_cap, seed_corr[:2], s, 0) + query_zoom_cap = self.get_zoomed_cap(query_cap, seed_corr[2:], s, self.zoom_jitter) + assert nn_zoom_cap.shape == query_zoom_cap.shape == (constants.MAX_SIZE, constants.MAX_SIZE) + corrs = self.get_corrs(query_zoom_cap, nn_zoom_cap) + if corrs is None or corrs.shape[0] < self.num_kp: + return self.__getitem__(random.randint(0, self.__len__() - 1)) + shuffle = np.random.permutation(corrs.shape[0]) + corrs = np.take(corrs, shuffle, axis=0) + corrs = self._trim_corrs(corrs) + + # flip augmentation + if np.random.uniform() < 0.5: + corrs[:, 0] = constants.MAX_SIZE - 1 - corrs[:, 0] + corrs[:, 2] = constants.MAX_SIZE - 1 - corrs[:, 2] + sbs_img = two_images_side_by_side(np.fliplr(query_zoom_cap.image), np.fliplr(nn_zoom_cap.image)) + else: + sbs_img = two_images_side_by_side(query_zoom_cap.image, nn_zoom_cap.image) + + corrs[:, 2] += constants.MAX_SIZE + corrs /= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE]) + assert (0.0 <= corrs[:, 0]).all() and (corrs[:, 0] <= 0.5).all() + assert (0.0 <= corrs[:, 1]).all() and (corrs[:, 1] <= 1.0).all() + assert (0.5 <= corrs[:, 2]).all() and (corrs[:, 2] <= 1.0).all() + assert (0.0 <= corrs[:, 3]).all() and (corrs[:, 3] <= 1.0).all() + out = { + 'image': tvtf.normalize(tvtf.to_tensor(sbs_img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + 'corrs': torch.from_numpy(corrs).float(), + } + if self.bidirectional: + out['queries'] = torch.from_numpy(np.concatenate([corrs[:, :2], corrs[:, 2:]], axis=0)).float() + out['targets'] = torch.from_numpy(np.concatenate([corrs[:, 2:], corrs[:, :2]], axis=0)).float() + else: + out['queries'] = torch.from_numpy(corrs[:, :2]).float() + out['targets'] = torch.from_numpy(corrs[:, 2:]).float() + + return out diff --git a/imcui/third_party/COTR/COTR/datasets/megadepth_dataset.py 
b/imcui/third_party/COTR/COTR/datasets/megadepth_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cc1ac93e369f6ee9998b693f383581403d3369c4 --- /dev/null +++ b/imcui/third_party/COTR/COTR/datasets/megadepth_dataset.py @@ -0,0 +1,140 @@ +''' +dataset specific layer for megadepth +''' + +import os +import json +import random +from collections import namedtuple + +import numpy as np + +from COTR.datasets import colmap_helper +from COTR.global_configs import dataset_config +from COTR.sfm_scenes import knn_search +from COTR.utils import debug_utils, utils, constants + +SceneCapIndex = namedtuple('SceneCapIndex', ['scene_index', 'capture_index']) + + +def prefix_of_img_path_for_magedepth(img_path): + ''' + get the prefix for image of megadepth dataset + ''' + prefix = os.path.abspath(os.path.join(img_path, '../../../..')) + '/' + return prefix + + +class MegadepthSceneDataBase(): + scenes = {} + knn_engine_dict = {} + + @classmethod + def _load_scene(cls, opt, scene_dir_dict): + if scene_dir_dict['scene_dir'] not in cls.scenes: + if opt.info_level == 'rgb': + assert 0 + elif opt.info_level == 'rgbd': + scene_dir = scene_dir_dict['scene_dir'] + images_dir = scene_dir_dict['image_dir'] + depth_dir = scene_dir_dict['depth_dir'] + scene = colmap_helper.ColmapWithDepthAsciiReader.read_sfm_scene_given_valid_list_path(scene_dir, images_dir, depth_dir, dataset_config[opt.dataset_name]['valid_list_json'], opt.crop_cam) + if opt.use_ram: + scene.read_data_to_ram(['image', 'depth']) + else: + raise ValueError() + knn_engine = knn_search.ReprojRatioKnnSearch(scene) + cls.scenes[scene_dir_dict['scene_dir']] = scene + cls.knn_engine_dict[scene_dir_dict['scene_dir']] = knn_engine + else: + pass + + +class MegadepthDataset(): + + def __init__(self, opt, dataset_type): + assert dataset_type in ['train', 'val', 'test'] + assert len(opt.scenes_name_list) > 0 + self.opt = opt + self.dataset_type = dataset_type + self.use_ram = opt.use_ram + self.scenes_name_list = opt.scenes_name_list + self.scenes = None + self.knn_engine_list = None + self.total_caps_set = None + self.query_caps_set = None + self.db_caps_set = None + self.img_path_to_scene_cap_index_dict = {} + self.scene_index_to_db_caps_mask_dict = {} + self._load_scenes() + + @property + def num_scenes(self): + return len(self.scenes) + + @property + def num_queries(self): + return len(self.query_caps_set) + + @property + def num_db(self): + return len(self.db_caps_set) + + def get_scene_cap_index_by_index(self, index): + assert index < len(self.query_caps_set) + img_path = sorted(list(self.query_caps_set))[index] + scene_cap_index = self.img_path_to_scene_cap_index_dict[img_path] + return scene_cap_index + + def _get_common_subset_caps_from_json(self, json_path, total_caps): + prefix = prefix_of_img_path_for_magedepth(list(total_caps)[0]) + with open(json_path, 'r') as f: + common_caps = [prefix + cap for cap in json.load(f)] + common_caps = set(total_caps) & set(common_caps) + return common_caps + + def _extend_img_path_to_scene_cap_index_dict(self, img_path_to_cap_index_dict, scene_id): + for key in img_path_to_cap_index_dict.keys(): + self.img_path_to_scene_cap_index_dict[key] = SceneCapIndex(scene_id, img_path_to_cap_index_dict[key]) + + def _create_scene_index_to_db_caps_mask_dict(self, db_caps_set): + scene_index_to_db_caps_mask_dict = {} + for cap in db_caps_set: + scene_id, cap_id = self.img_path_to_scene_cap_index_dict[cap] + if scene_id not in scene_index_to_db_caps_mask_dict: + 
scene_index_to_db_caps_mask_dict[scene_id] = [] + scene_index_to_db_caps_mask_dict[scene_id].append(cap_id) + for _k, _v in scene_index_to_db_caps_mask_dict.items(): + scene_index_to_db_caps_mask_dict[_k] = np.array(sorted(_v)) + return scene_index_to_db_caps_mask_dict + + def _load_scenes(self): + scenes = [] + knn_engine_list = [] + total_caps_set = set() + for scene_id, scene_dir_dict in enumerate(self.scenes_name_list): + MegadepthSceneDataBase._load_scene(self.opt, scene_dir_dict) + scene = MegadepthSceneDataBase.scenes[scene_dir_dict['scene_dir']] + knn_engine = MegadepthSceneDataBase.knn_engine_dict[scene_dir_dict['scene_dir']] + total_caps_set = total_caps_set | set(scene.img_path_to_index_dict.keys()) + self._extend_img_path_to_scene_cap_index_dict(scene.img_path_to_index_dict, scene_id) + scenes.append(scene) + knn_engine_list.append(knn_engine) + self.scenes = scenes + self.knn_engine_list = knn_engine_list + self.total_caps_set = total_caps_set + self.query_caps_set = self._get_common_subset_caps_from_json(dataset_config[self.opt.dataset_name][f'{self.dataset_type}_json'], total_caps_set) + self.db_caps_set = self._get_common_subset_caps_from_json(dataset_config[self.opt.dataset_name]['train_json'], total_caps_set) + self.scene_index_to_db_caps_mask_dict = self._create_scene_index_to_db_caps_mask_dict(self.db_caps_set) + + def get_query_with_knn(self, index): + scene_index, cap_index = self.get_scene_cap_index_by_index(index) + query_cap = self.scenes[scene_index].captures[cap_index] + knn_engine = self.knn_engine_list[scene_index] + if scene_index in self.scene_index_to_db_caps_mask_dict: + db_mask = self.scene_index_to_db_caps_mask_dict[scene_index] + else: + db_mask = None + pool = knn_engine.get_knn(query_cap, self.opt.pool_size, db_mask=db_mask) + nn_caps = random.sample(pool, min(len(pool), self.opt.k_size)) + return query_cap, nn_caps diff --git a/imcui/third_party/COTR/COTR/global_configs/__init__.py b/imcui/third_party/COTR/COTR/global_configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6db45547cfe941e399571c30b8645a5aa3931368 --- /dev/null +++ b/imcui/third_party/COTR/COTR/global_configs/__init__.py @@ -0,0 +1,10 @@ +import os +import json + +__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) +with open(os.path.join(__location__, 'dataset_config.json'), 'r') as f: + dataset_config = json.load(f) +with open(os.path.join(__location__, 'commons.json'), 'r') as f: + general_config = json.load(f) +# assert os.path.isdir(general_config['out']), f'Please create {general_config["out"]}' +# assert os.path.isdir(general_config['tb_out']), f'Please create {general_config["tb_out"]}' diff --git a/imcui/third_party/COTR/COTR/inference/inference_helper.py b/imcui/third_party/COTR/COTR/inference/inference_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..4c477d1f6c60ff0e00a491c1daa3e105c7a59cdf --- /dev/null +++ b/imcui/third_party/COTR/COTR/inference/inference_helper.py @@ -0,0 +1,311 @@ +import warnings + +import cv2 +import numpy as np +import torch +from torchvision.transforms import functional as tvtf +from tqdm import tqdm +import PIL + +from COTR.utils import utils, debug_utils +from COTR.utils.constants import MAX_SIZE +from COTR.cameras.capture import crop_center_max_np, pad_to_square_np +from COTR.utils.utils import ImagePatch + +THRESHOLD_SPARSE = 0.02 +THRESHOLD_PIXELS_RELATIVE = 0.02 +BASE_ZOOM = 1.0 +THRESHOLD_AREA = 0.02 +LARGE_GPU = True + + +def 
find_prediction_loop(arr): + ''' + loop ends at last element + ''' + assert arr.shape[1] == 2, 'requires shape (N, 2)' + start_index = np.where(np.prod(arr[:-1] == arr[-1], axis=1))[0][0] + return arr[start_index:-1] + + +def two_images_side_by_side(img_a, img_b): + assert img_a.shape == img_b.shape, f'{img_a.shape} vs {img_b.shape}' + assert img_a.dtype == img_b.dtype + h, w, c = img_a.shape + canvas = np.zeros((h, 2 * w, c), dtype=img_a.dtype) + canvas[:, 0 * w:1 * w, :] = img_a + canvas[:, 1 * w:2 * w, :] = img_b + return canvas + + +def to_square_patches(img): + patches = [] + h, w, _ = img.shape + short = size = min(h, w) + long = max(h, w) + if long == short: + patch_0 = ImagePatch(img[:size, :size], 0, 0, size, size, w, h) + patches = [patch_0] + elif long <= size * 2: + warnings.warn('Spatial smoothness in dense optical flow is lost, but sparse matching and triangulation should be fine') + patch_0 = ImagePatch(img[:size, :size], 0, 0, size, size, w, h) + patch_1 = ImagePatch(img[-size:, -size:], w - size, h - size, size, size, w, h) + patches = [patch_0, patch_1] + # patches += subdivide_patch(patch_0) + # patches += subdivide_patch(patch_1) + else: + raise NotImplementedError + return patches + + +def merge_flow_patches(corrs): + confidence = np.ones([corrs[0].oh, corrs[0].ow]) * 100 + flow = np.zeros([corrs[0].oh, corrs[0].ow, 2]) + cmap = np.ones([corrs[0].oh, corrs[0].ow]) * -1 + for i, c in enumerate(corrs): + temp = np.ones([c.oh, c.ow]) * 100 + temp[c.y:c.y + c.h, c.x:c.x + c.w] = c.patch[..., 2] + tempf = np.zeros([c.oh, c.ow, 2]) + tempf[c.y:c.y + c.h, c.x:c.x + c.w] = c.patch[..., :2] + min_ind = np.stack([temp, confidence], axis=-1).argmin(axis=-1) + min_ind = min_ind == 0 + confidence[min_ind] = temp[min_ind] + flow[min_ind] = tempf[min_ind] + cmap[min_ind] = i + return flow, confidence, cmap + + +def get_patch_centered_at(img, pos, scale=1.0, return_content=True, img_shape=None): + ''' + pos - [x, y] + ''' + if img_shape is None: + img_shape = img.shape + h, w, _ = img_shape + short = min(h, w) + scale = np.clip(scale, 0.0, 1.0) + size = short * scale + size = int((size // 2) * 2) + lu_y = int(pos[1] - size // 2) + lu_x = int(pos[0] - size // 2) + if lu_y < 0: + lu_y -= lu_y + if lu_x < 0: + lu_x -= lu_x + if lu_y + size > h: + lu_y -= (lu_y + size) - (h) + if lu_x + size > w: + lu_x -= (lu_x + size) - (w) + if return_content: + return ImagePatch(img[lu_y:lu_y + size, lu_x:lu_x + size], lu_x, lu_y, size, size, w, h) + else: + return ImagePatch(None, lu_x, lu_y, size, size, w, h) + + +def cotr_patch_flow_exhaustive(model, patches_a, patches_b): + def one_pass(model, img_a, img_b): + device = next(model.parameters()).device + assert img_a.shape[0] == img_a.shape[1] + assert img_b.shape[0] == img_b.shape[1] + img_a = np.array(PIL.Image.fromarray(img_a).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR)) + img_b = np.array(PIL.Image.fromarray(img_b).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR)) + img = two_images_side_by_side(img_a, img_b) + img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()[None] + img = img.to(device) + + q_list = [] + for i in range(MAX_SIZE): + queries = [] + for j in range(MAX_SIZE * 2): + queries.append([(j) / (MAX_SIZE * 2), i / MAX_SIZE]) + queries = np.array(queries) + q_list.append(queries) + if LARGE_GPU: + try: + queries = torch.from_numpy(np.concatenate(q_list))[None].float().to(device) + out = model.forward(img, queries)['pred_corrs'].detach().cpu().numpy()[0] + out_list = 
out.reshape(MAX_SIZE, MAX_SIZE * 2, -1) + except: + assert 0, 'set LARGE_GPU to False' + else: + out_list = [] + for q in q_list: + queries = torch.from_numpy(q)[None].float().to(device) + out = model.forward(img, queries)['pred_corrs'].detach().cpu().numpy()[0] + out_list.append(out) + out_list = np.array(out_list) + in_grid = torch.from_numpy(np.array(q_list)).float()[None] * 2 - 1 + out_grid = torch.from_numpy(out_list).float()[None] * 2 - 1 + cycle_grid = torch.nn.functional.grid_sample(out_grid.permute(0, 3, 1, 2), out_grid).permute(0, 2, 3, 1) + confidence = torch.norm(cycle_grid[0, ...] - in_grid[0, ...], dim=-1) + corr = out_grid[0].clone() + corr[:, :MAX_SIZE, 0] = corr[:, :MAX_SIZE, 0] * 2 - 1 + corr[:, MAX_SIZE:, 0] = corr[:, MAX_SIZE:, 0] * 2 + 1 + corr = torch.cat([corr, confidence[..., None]], dim=-1).numpy() + return corr[:, :MAX_SIZE, :], corr[:, MAX_SIZE:, :] + corrs_a = [] + corrs_b = [] + + for p_i in patches_a: + for p_j in patches_b: + c_i, c_j = one_pass(model, p_i.patch, p_j.patch) + base_corners = np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]]) + real_corners_j = (np.array([[p_j.x, p_j.y], [p_j.x + p_j.w, p_j.y], [p_j.x + p_j.w, p_j.y + p_j.h], [p_j.x, p_j.y + p_j.h]]) / np.array([p_j.ow, p_j.oh])) * 2 + np.array([-1, -1]) + real_corners_i = (np.array([[p_i.x, p_i.y], [p_i.x + p_i.w, p_i.y], [p_i.x + p_i.w, p_i.y + p_i.h], [p_i.x, p_i.y + p_i.h]]) / np.array([p_i.ow, p_i.oh])) * 2 + np.array([-1, -1]) + T_i = cv2.getAffineTransform(base_corners[:3].astype(np.float32), real_corners_j[:3].astype(np.float32)) + T_j = cv2.getAffineTransform(base_corners[:3].astype(np.float32), real_corners_i[:3].astype(np.float32)) + c_i[..., :2] = c_i[..., :2] @ T_i[:2, :2] + T_i[:, 2] + c_j[..., :2] = c_j[..., :2] @ T_j[:2, :2] + T_j[:, 2] + c_i = utils.float_image_resize(c_i, (p_i.h, p_i.w)) + c_j = utils.float_image_resize(c_j, (p_j.h, p_j.w)) + c_i = ImagePatch(c_i, p_i.x, p_i.y, p_i.w, p_i.h, p_i.ow, p_i.oh) + c_j = ImagePatch(c_j, p_j.x, p_j.y, p_j.w, p_j.h, p_j.ow, p_j.oh) + corrs_a.append(c_i) + corrs_b.append(c_j) + return corrs_a, corrs_b + + +def cotr_flow(model, img_a, img_b): + # assert img_a.shape[0] == img_a.shape[1] + # assert img_b.shape[0] == img_b.shape[1] + patches_a = to_square_patches(img_a) + patches_b = to_square_patches(img_b) + + corrs_a, corrs_b = cotr_patch_flow_exhaustive(model, patches_a, patches_b) + corr_a, con_a, cmap_a = merge_flow_patches(corrs_a) + corr_b, con_b, cmap_b = merge_flow_patches(corrs_b) + + resample_a = utils.torch_img_to_np_img(torch.nn.functional.grid_sample(utils.np_img_to_torch_img(img_b)[None].float(), + torch.from_numpy(corr_a)[None].float())[0]) + resample_b = utils.torch_img_to_np_img(torch.nn.functional.grid_sample(utils.np_img_to_torch_img(img_a)[None].float(), + torch.from_numpy(corr_b)[None].float())[0]) + return corr_a, con_a, resample_a, corr_b, con_b, resample_b + + +def cotr_corr_base(model, img_a, img_b, queries_a): + def one_pass(model, img_a, img_b, queries): + device = next(model.parameters()).device + assert img_a.shape[0] == img_a.shape[1] + assert img_b.shape[0] == img_b.shape[1] + img_a = np.array(PIL.Image.fromarray(img_a).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR)) + img_b = np.array(PIL.Image.fromarray(img_b).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR)) + img = two_images_side_by_side(img_a, img_b) + img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()[None] + img = img.to(device) + + queries = 
torch.from_numpy(queries)[None].float().to(device) + out = model.forward(img, queries)['pred_corrs'].clone().detach() + cycle = model.forward(img, out)['pred_corrs'].clone().detach() + + queries = queries.cpu().numpy()[0] + out = out.cpu().numpy()[0] + cycle = cycle.cpu().numpy()[0] + conf = np.linalg.norm(queries - cycle, axis=1, keepdims=True) + return np.concatenate([out, conf], axis=1) + + patches_a = to_square_patches(img_a) + patches_b = to_square_patches(img_b) + pred_list = [] + + for p_i in patches_a: + for p_j in patches_b: + normalized_queries_a = queries_a.copy() + mask = (normalized_queries_a[:, 0] >= p_i.x) & (normalized_queries_a[:, 1] >= p_i.y) & (normalized_queries_a[:, 0] <= p_i.x + p_i.w) & (normalized_queries_a[:, 1] <= p_i.y + p_i.h) + normalized_queries_a[:, 0] -= p_i.x + normalized_queries_a[:, 1] -= p_i.y + normalized_queries_a[:, 0] /= 2 * p_i.w + normalized_queries_a[:, 1] /= p_i.h + pred = one_pass(model, p_i.patch, p_j.patch, normalized_queries_a) + pred[~mask, 2] = np.inf + pred[:, 0] -= 0.5 + pred[:, 0] *= 2 * p_j.w + pred[:, 0] += p_j.x + pred[:, 1] *= p_j.h + pred[:, 1] += p_j.y + pred_list.append(pred) + + pred_list = np.stack(pred_list).transpose(1, 0, 2) + out = [] + for item in pred_list: + out.append(item[np.argmin(item[..., 2], axis=0)]) + out = np.array(out)[..., :2] + return np.concatenate([queries_a, out], axis=1) + + +try: + from vispy import gloo + from vispy import app + from vispy.util.ptime import time + from scipy.spatial import Delaunay + from vispy.gloo.wrappers import read_pixels + + app.use_app('glfw') + + + vertex_shader = """ + attribute vec4 color; + attribute vec2 position; + varying vec4 v_color; + void main() + { + gl_Position = vec4(position, 0.0, 1.0); + v_color = color; + } """ + + fragment_shader = """ + varying vec4 v_color; + void main() + { + gl_FragColor = v_color; + } """ + + + class Canvas(app.Canvas): + def __init__(self, mesh, color, size): + # We hide the canvas upon creation. + app.Canvas.__init__(self, show=False, size=size) + self._t0 = time() + # Texture where we render the scene. + self._rendertex = gloo.Texture2D(shape=self.size[::-1] + (4,), internalformat='rgba32f') + # FBO. + self._fbo = gloo.FrameBuffer(self._rendertex, + gloo.RenderBuffer(self.size[::-1])) + # Regular program that will be rendered to the FBO. + self.program = gloo.Program(vertex_shader, fragment_shader) + self.program["position"] = mesh + self.program['color'] = color + # We manually draw the hidden canvas. + self.update() + + def on_draw(self, event): + # Render in the FBO. + with self._fbo: + gloo.clear('black') + gloo.set_viewport(0, 0, *self.size) + self.program.draw() + # Retrieve the contents of the FBO texture. + self.im = read_pixels((0, 0, self.size[0], self.size[1]), True, out_type='float') + self._time = time() - self._t0 + # Immediately exit the application. 
+ app.quit() + + + def triangulate_corr(corr, from_shape, to_shape): + corr = corr.copy() + to_shape = to_shape[:2] + from_shape = from_shape[:2] + corr = corr / np.concatenate([from_shape[::-1], to_shape[::-1]]) + tri = Delaunay(corr[:, :2]) + mesh = corr[:, :2][tri.simplices].astype(np.float32) * 2 - 1 + mesh[..., 1] *= -1 + color = corr[:, 2:][tri.simplices].astype(np.float32) + color = np.concatenate([color, np.ones_like(color[..., 0:2])], axis=-1) + c = Canvas(mesh.reshape(-1, 2), color.reshape(-1, 4), size=(from_shape[::-1])) + app.run() + render = c.im.copy() + render = render[..., :2] + render *= np.array(to_shape[::-1]) + return render +except: + print('cannot use vispy, setting triangulate_corr as None') + triangulate_corr = None diff --git a/imcui/third_party/COTR/COTR/inference/refinement_task.py b/imcui/third_party/COTR/COTR/inference/refinement_task.py new file mode 100644 index 0000000000000000000000000000000000000000..0381b91f548ff71839d363e13002d59a0618c0b6 --- /dev/null +++ b/imcui/third_party/COTR/COTR/inference/refinement_task.py @@ -0,0 +1,191 @@ +import time + +import numpy as np +import torch +from torchvision.transforms import functional as tvtf +import imageio +import PIL + +from COTR.inference.inference_helper import BASE_ZOOM, THRESHOLD_PIXELS_RELATIVE, get_patch_centered_at, two_images_side_by_side, find_prediction_loop +from COTR.utils import debug_utils, utils +from COTR.utils.constants import MAX_SIZE +from COTR.utils.utils import ImagePatch + + +class RefinementTask(): + def __init__(self, image_from, image_to, loc_from, loc_to, area_from, area_to, converge_iters, zoom_ins, identifier=None): + self.identifier = identifier + self.image_from = image_from + self.image_to = image_to + self.loc_from = loc_from + self.best_loc_to = loc_to + self.cur_loc_to = loc_to + self.area_from = area_from + self.area_to = area_to + if self.area_from < self.area_to: + self.s_from = BASE_ZOOM + self.s_to = BASE_ZOOM * np.sqrt(self.area_to / self.area_from) + else: + self.s_to = BASE_ZOOM + self.s_from = BASE_ZOOM * np.sqrt(self.area_from / self.area_to) + + self.cur_job = {} + self.status = 'unfinished' + self.result = 'unknown' + + self.converge_iters = converge_iters + self.zoom_ins = zoom_ins + self.cur_zoom_idx = 0 + self.cur_iter = 0 + self.total_iter = 0 + + self.loc_to_at_zoom = [] + self.loc_history = [loc_to] + self.all_loc_to_dict = {} + self.job_history = [] + self.submitted = False + + @property + def cur_zoom(self): + return self.zoom_ins[self.cur_zoom_idx] + + @property + def confidence_scaling_factor(self): + if self.cur_zoom_idx > 0: + conf_scaling = float(self.cur_zoom) / float(self.zoom_ins[0]) + else: + conf_scaling = 1.0 + return conf_scaling + + def peek(self): + assert self.status == 'unfinished' + patch_from = get_patch_centered_at(None, self.loc_from, scale=self.s_from * self.cur_zoom, return_content=False, img_shape=self.image_from.shape) + patch_to = get_patch_centered_at(None, self.cur_loc_to, scale=self.s_to * self.cur_zoom, return_content=False, img_shape=self.image_to.shape) + top_job = {'patch_from': patch_from, + 'patch_to': patch_to, + 'loc_from': self.loc_from, + 'loc_to': self.cur_loc_to, + } + return top_job + + def get_task_pilot(self, pilot): + assert self.status == 'unfinished' + patch_from = ImagePatch(None, pilot.cur_job['patch_from'].x, pilot.cur_job['patch_from'].y, pilot.cur_job['patch_from'].w, pilot.cur_job['patch_from'].h, pilot.cur_job['patch_from'].ow, pilot.cur_job['patch_from'].oh) + patch_to = ImagePatch(None, 
pilot.cur_job['patch_to'].x, pilot.cur_job['patch_to'].y, pilot.cur_job['patch_to'].w, pilot.cur_job['patch_to'].h, pilot.cur_job['patch_to'].ow, pilot.cur_job['patch_to'].oh) + query = torch.from_numpy((np.array(self.loc_from) - np.array([patch_from.x, patch_from.y])) / np.array([patch_from.w * 2, patch_from.h]))[None].float() + self.cur_job = {'patch_from': patch_from, + 'patch_to': patch_to, + 'loc_from': self.loc_from, + 'loc_to': self.cur_loc_to, + 'img': None, + } + self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w)) + assert self.submitted == False + self.submitted = True + return None, query + + def get_task_fast(self): + assert self.status == 'unfinished' + patch_from = get_patch_centered_at(self.image_from, self.loc_from, scale=self.s_from * self.cur_zoom, return_content=False) + patch_to = get_patch_centered_at(self.image_to, self.cur_loc_to, scale=self.s_to * self.cur_zoom, return_content=False) + query = torch.from_numpy((np.array(self.loc_from) - np.array([patch_from.x, patch_from.y])) / np.array([patch_from.w * 2, patch_from.h]))[None].float() + self.cur_job = {'patch_from': patch_from, + 'patch_to': patch_to, + 'loc_from': self.loc_from, + 'loc_to': self.cur_loc_to, + 'img': None, + } + + self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w)) + assert self.submitted == False + self.submitted = True + + return None, query + + def get_task(self): + assert self.status == 'unfinished' + patch_from = get_patch_centered_at(self.image_from, self.loc_from, scale=self.s_from * self.cur_zoom) + patch_to = get_patch_centered_at(self.image_to, self.cur_loc_to, scale=self.s_to * self.cur_zoom) + + query = torch.from_numpy((np.array(self.loc_from) - np.array([patch_from.x, patch_from.y])) / np.array([patch_from.w * 2, patch_from.h]))[None].float() + + img_from = patch_from.patch + img_to = patch_to.patch + assert img_from.shape[0] == img_from.shape[1] + assert img_to.shape[0] == img_to.shape[1] + + img_from = np.array(PIL.Image.fromarray(img_from).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR)) + img_to = np.array(PIL.Image.fromarray(img_to).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR)) + img = two_images_side_by_side(img_from, img_to) + img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float() + + self.cur_job = {'patch_from': ImagePatch(None, patch_from.x, patch_from.y, patch_from.w, patch_from.h, patch_from.ow, patch_from.oh), + 'patch_to': ImagePatch(None, patch_to.x, patch_to.y, patch_to.w, patch_to.h, patch_to.ow, patch_to.oh), + 'loc_from': self.loc_from, + 'loc_to': self.cur_loc_to, + } + + self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w)) + assert self.submitted == False + self.submitted = True + + return img, query + + def next_zoom(self): + if self.cur_zoom_idx >= len(self.zoom_ins) - 1: + self.status = 'finished' + if self.conclude() is None: + self.result = 'bad' + else: + self.result = 'good' + self.cur_zoom_idx += 1 + self.cur_iter = 0 + self.loc_to_at_zoom = [] + + def scale_to_loc(self, raw_to_loc): + raw_to_loc = raw_to_loc.copy() + patch_b = self.cur_job['patch_to'] + raw_to_loc[0] = (raw_to_loc[0] - 0.5) * 2 + loc_to = raw_to_loc * np.array([patch_b.w, patch_b.h]) + loc_to = loc_to + np.array([patch_b.x, patch_b.y]) + return loc_to + + def step(self, raw_to_loc): + assert self.submitted == True + self.submitted = False + loc_to = self.scale_to_loc(raw_to_loc) + self.total_iter += 1 + self.loc_to_at_zoom.append(loc_to) + 
self.cur_loc_to = loc_to + zoom_finished = False + if self.cur_zoom_idx == len(self.zoom_ins) - 1: + # converge at the last level + if len(self.loc_to_at_zoom) >= 2: + zoom_finished = np.prod(self.loc_to_at_zoom[:-1] == loc_to, axis=1, keepdims=True).any() + if self.cur_iter >= self.converge_iters - 1: + zoom_finished = True + self.cur_iter += 1 + else: + # finish immediately for other levels + zoom_finished = True + if zoom_finished: + self.all_loc_to_dict[self.cur_zoom] = np.array(self.loc_to_at_zoom).copy() + last_level_loc_to = self.all_loc_to_dict[self.cur_zoom] + if len(last_level_loc_to) >= 2: + has_loop = np.prod(last_level_loc_to[:-1] == last_level_loc_to[-1], axis=1, keepdims=True).any() + if has_loop: + loop = find_prediction_loop(last_level_loc_to) + loc_to = loop.mean(axis=0) + self.loc_history.append(loc_to) + self.best_loc_to = loc_to + self.cur_loc_to = self.best_loc_to + self.next_zoom() + + def conclude(self, force=False): + loc_history = np.array(self.loc_history) + if (force == False) and (max(loc_history.std(axis=0)) >= THRESHOLD_PIXELS_RELATIVE * max(*self.image_to.shape)): + return None + return np.concatenate([self.loc_from, self.best_loc_to]) + + def conclude_intermedia(self): + return np.concatenate([np.array(self.loc_history), np.array(self.job_history)], axis=1) diff --git a/imcui/third_party/COTR/COTR/inference/sparse_engine.py b/imcui/third_party/COTR/COTR/inference/sparse_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..d814e09bfbccc78a8263fec9477ba08b5ed12e34 --- /dev/null +++ b/imcui/third_party/COTR/COTR/inference/sparse_engine.py @@ -0,0 +1,427 @@ +''' +Inference engine for sparse image pair correspondences +''' + +import time +import random + +import numpy as np +import torch + +from COTR.inference.inference_helper import THRESHOLD_SPARSE, THRESHOLD_AREA, cotr_flow, cotr_corr_base +from COTR.inference.refinement_task import RefinementTask +from COTR.utils import debug_utils, utils +from COTR.cameras.capture import stretch_to_square_np + + +class SparseEngine(): + def __init__(self, model, batch_size, mode='stretching'): + assert mode in ['stretching', 'tile'] + self.model = model + self.batch_size = batch_size + self.total_tasks = 0 + self.mode = mode + + def form_batch(self, tasks, zoom=None): + counter = 0 + task_ref = [] + img_batch = [] + query_batch = [] + for t in tasks: + if t.status == 'unfinished' and t.submitted == False: + if zoom is not None and t.cur_zoom != zoom: + continue + task_ref.append(t) + img, query = t.get_task() + img_batch.append(img) + query_batch.append(query) + counter += 1 + if counter >= self.batch_size: + break + if len(task_ref) == 0: + return [], [], [] + img_batch = torch.stack(img_batch) + query_batch = torch.stack(query_batch) + return task_ref, img_batch, query_batch + + def infer_batch(self, img_batch, query_batch): + self.total_tasks += img_batch.shape[0] + device = next(self.model.parameters()).device + img_batch = img_batch.to(device) + query_batch = query_batch.to(device) + out = self.model(img_batch, query_batch)['pred_corrs'].clone().detach() + out = out.cpu().numpy()[:, 0, :] + if utils.has_nan(out): + raise ValueError('NaN in prediction') + return out + + def conclude_tasks(self, tasks, return_idx=False, force=False, + offset_x_from=0, + offset_y_from=0, + offset_x_to=0, + offset_y_to=0, + img_a_shape=None, + img_b_shape=None): + corrs = [] + idx = [] + for t in tasks: + if t.status == 'finished': + out = t.conclude(force) + if out is not None: + corrs.append(np.array(out)) + 
idx.append(t.identifier) + corrs = np.array(corrs) + idx = np.array(idx) + if corrs.shape[0] > 0: + corrs -= np.array([offset_x_from, offset_y_from, offset_x_to, offset_y_to]) + if img_a_shape is not None and img_b_shape is not None and not force: + border_mask = np.prod(corrs < np.concatenate([img_a_shape[::-1], img_b_shape[::-1]]), axis=1) + border_mask = (np.prod(corrs > np.array([0, 0, 0, 0]), axis=1) * border_mask).astype(np.bool) + corrs = corrs[border_mask] + idx = idx[border_mask] + if return_idx: + return corrs, idx + return corrs + + def num_finished_tasks(self, tasks): + counter = 0 + for t in tasks: + if t.status == 'finished': + counter += 1 + return counter + + def num_good_tasks(self, tasks): + counter = 0 + for t in tasks: + if t.result == 'good': + counter += 1 + return counter + + def gen_tasks_w_known_scale(self, img_a, img_b, queries_a, areas, zoom_ins=[1.0], converge_iters=1, max_corrs=1000): + assert self.mode == 'tile' + corr_a = cotr_corr_base(self.model, img_a, img_b, queries_a) + tasks = [] + for c in corr_a: + tasks.append(RefinementTask(img_a, img_b, c[:2], c[2:], areas[0], areas[1], converge_iters, zoom_ins)) + return tasks + + def gen_tasks(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, force=False, areas=None): + if areas is not None: + assert queries_a is not None + assert force == True + assert max_corrs >= queries_a.shape[0] + return self.gen_tasks_w_known_scale(img_a, img_b, queries_a, areas, zoom_ins=zoom_ins, converge_iters=converge_iters, max_corrs=max_corrs) + if self.mode == 'stretching': + if img_a.shape[0] != img_a.shape[1] or img_b.shape[0] != img_b.shape[1]: + img_a_shape = img_a.shape + img_b_shape = img_b.shape + img_a_sq = stretch_to_square_np(img_a.copy()) + img_b_sq = stretch_to_square_np(img_b.copy()) + corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model, + img_a_sq, + img_b_sq + ) + corr_a = utils.float_image_resize(corr_a, img_a_shape[:2]) + con_a = utils.float_image_resize(con_a, img_a_shape[:2]) + resample_a = utils.float_image_resize(resample_a, img_a_shape[:2]) + corr_b = utils.float_image_resize(corr_b, img_b_shape[:2]) + con_b = utils.float_image_resize(con_b, img_b_shape[:2]) + resample_b = utils.float_image_resize(resample_b, img_b_shape[:2]) + else: + corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model, + img_a, + img_b + ) + elif self.mode == 'tile': + corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model, + img_a, + img_b + ) + else: + raise ValueError(f'unsupported mode: {self.mode}') + mask_a = con_a < THRESHOLD_SPARSE + mask_b = con_b < THRESHOLD_SPARSE + area_a = (con_a < THRESHOLD_AREA).sum() / mask_a.size + area_b = (con_b < THRESHOLD_AREA).sum() / mask_b.size + tasks = [] + + if queries_a is None: + index_a = np.where(mask_a) + index_a = np.array(index_a).T + index_a = index_a[np.random.choice(len(index_a), min(max_corrs, len(index_a)))] + index_b = np.where(mask_b) + index_b = np.array(index_b).T + index_b = index_b[np.random.choice(len(index_b), min(max_corrs, len(index_b)))] + for pos in index_a: + loc_from = pos[::-1] + loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1] + tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins)) + for pos in index_b: + ''' + trick: suppose to fix the query point location(loc_from), + but here it fixes the first guess(loc_to). 
+ ''' + loc_from = pos[::-1] + loc_to = (corr_b[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_a.shape[:2][::-1] + tasks.append(RefinementTask(img_a, img_b, loc_to, loc_from, area_a, area_b, converge_iters, zoom_ins)) + else: + if force: + for i, loc_from in enumerate(queries_a): + pos = loc_from[::-1] + pos = np.array([np.clip(pos[0], 0, corr_a.shape[0] - 1), np.clip(pos[1], 0, corr_a.shape[1] - 1)], dtype=np.int) + loc_to = (corr_a[tuple(pos)].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1] + tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i)) + else: + for i, loc_from in enumerate(queries_a): + pos = loc_from[::-1] + if (pos > np.array(img_a.shape[:2]) - 1).any() or (pos < 0).any(): + continue + if mask_a[tuple(np.floor(pos).astype('int'))]: + loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1] + tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i)) + if len(tasks) < max_corrs: + extra = max_corrs - len(tasks) + counter = 0 + for i, loc_from in enumerate(queries_a): + if counter >= extra: + break + pos = loc_from[::-1] + if (pos > np.array(img_a.shape[:2]) - 1).any() or (pos < 0).any(): + continue + if mask_a[tuple(np.floor(pos).astype('int'))] == False: + loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1] + tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i)) + counter += 1 + return tasks + + def cotr_corr_multiscale(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, force=False, return_tasks_only=False, areas=None): + ''' + currently only support fixed queries_a + ''' + img_a = img_a.copy() + img_b = img_b.copy() + img_a_shape = img_a.shape[:2] + img_b_shape = img_b.shape[:2] + if queries_a is not None: + queries_a = queries_a.copy() + tasks = self.gen_tasks(img_a, img_b, zoom_ins, converge_iters, max_corrs, queries_a, force, areas) + while True: + num_g = self.num_good_tasks(tasks) + print(f'{num_g} / {max_corrs} | {self.num_finished_tasks(tasks)} / {len(tasks)}') + task_ref, img_batch, query_batch = self.form_batch(tasks) + if len(task_ref) == 0: + break + if num_g >= max_corrs: + break + out = self.infer_batch(img_batch, query_batch) + for t, o in zip(task_ref, out): + t.step(o) + if return_tasks_only: + return tasks + if return_idx: + corrs, idx = self.conclude_tasks(tasks, return_idx=True, force=force, + img_a_shape=img_a_shape, + img_b_shape=img_b_shape,) + corrs = corrs[:max_corrs] + idx = idx[:max_corrs] + return corrs, idx + else: + corrs = self.conclude_tasks(tasks, force=force, + img_a_shape=img_a_shape, + img_b_shape=img_b_shape,) + corrs = corrs[:max_corrs] + return corrs + + def cotr_corr_multiscale_with_cycle_consistency(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, return_cycle_error=False): + EXTRACTION_RATE = 0.3 + temp_max_corrs = int(max_corrs / EXTRACTION_RATE) + if queries_a is not None: + temp_max_corrs = min(temp_max_corrs, queries_a.shape[0]) + queries_a = queries_a.copy() + corr_f, idx_f = self.cotr_corr_multiscale(img_a.copy(), img_b.copy(), + zoom_ins=zoom_ins, + converge_iters=converge_iters, + max_corrs=temp_max_corrs, + queries_a=queries_a, + return_idx=True) + assert corr_f.shape[0] > 0 + corr_b, idx_b = self.cotr_corr_multiscale(img_b.copy(), 
img_a.copy(), + zoom_ins=zoom_ins, + converge_iters=converge_iters, + max_corrs=corr_f.shape[0], + queries_a=corr_f[:, 2:].copy(), + return_idx=True) + assert corr_b.shape[0] > 0 + cycle_errors = np.linalg.norm(corr_f[idx_b][:, :2] - corr_b[:, 2:], axis=1) + order = np.argsort(cycle_errors) + out = [corr_f[idx_b][order][:max_corrs]] + if return_idx: + out.append(idx_f[idx_b][order][:max_corrs]) + if return_cycle_error: + out.append(cycle_errors[order][:max_corrs]) + if len(out) == 1: + out = out[0] + return out + + +class FasterSparseEngine(SparseEngine): + ''' + search and merge nearby tasks to accelerate inference speed. + It will make spatial accuracy slightly worse. + ''' + + def __init__(self, model, batch_size, mode='stretching', max_load=256): + super().__init__(model, batch_size, mode=mode) + self.max_load = max_load + + def infer_batch_grouped(self, img_batch, query_batch): + device = next(self.model.parameters()).device + img_batch = img_batch.to(device) + query_batch = query_batch.to(device) + out = self.model(img_batch, query_batch)['pred_corrs'].clone().detach().cpu().numpy() + return out + + def get_tasks_map(self, zoom, tasks): + maps = [] + ids = [] + for i, t in enumerate(tasks): + if t.status == 'unfinished' and t.submitted == False and t.cur_zoom == zoom: + t_info = t.peek() + point = np.concatenate([t_info['loc_from'], t_info['loc_to']]) + maps.append(point) + ids.append(i) + return np.array(maps), np.array(ids) + + def form_squad(self, zoom, pilot, pilot_id, tasks, tasks_map, task_ids, bookkeeping): + assert pilot.status == 'unfinished' and pilot.submitted == False and pilot.cur_zoom == zoom + SAFE_AREA = 0.5 + pilot_info = pilot.peek() + pilot_from_center_x = pilot_info['patch_from'].x + pilot_info['patch_from'].w/2 + pilot_from_center_y = pilot_info['patch_from'].y + pilot_info['patch_from'].h/2 + pilot_from_left = pilot_from_center_x - pilot_info['patch_from'].w/2 * SAFE_AREA + pilot_from_right = pilot_from_center_x + pilot_info['patch_from'].w/2 * SAFE_AREA + pilot_from_upper = pilot_from_center_y - pilot_info['patch_from'].h/2 * SAFE_AREA + pilot_from_lower = pilot_from_center_y + pilot_info['patch_from'].h/2 * SAFE_AREA + + pilot_to_center_x = pilot_info['patch_to'].x + pilot_info['patch_to'].w/2 + pilot_to_center_y = pilot_info['patch_to'].y + pilot_info['patch_to'].h/2 + pilot_to_left = pilot_to_center_x - pilot_info['patch_to'].w/2 * SAFE_AREA + pilot_to_right = pilot_to_center_x + pilot_info['patch_to'].w/2 * SAFE_AREA + pilot_to_upper = pilot_to_center_y - pilot_info['patch_to'].h/2 * SAFE_AREA + pilot_to_lower = pilot_to_center_y + pilot_info['patch_to'].h/2 * SAFE_AREA + + img, query = pilot.get_task() + assert pilot.submitted == True + members = [pilot] + queries = [query] + bookkeeping[pilot_id] = False + + loads = np.where(((tasks_map[:, 0] > pilot_from_left) & + (tasks_map[:, 0] < pilot_from_right) & + (tasks_map[:, 1] > pilot_from_upper) & + (tasks_map[:, 1] < pilot_from_lower) & + (tasks_map[:, 2] > pilot_to_left) & + (tasks_map[:, 2] < pilot_to_right) & + (tasks_map[:, 3] > pilot_to_upper) & + (tasks_map[:, 3] < pilot_to_lower)) * + bookkeeping)[0][: self.max_load] + + for ti in task_ids[loads]: + t = tasks[ti] + assert t.status == 'unfinished' and t.submitted == False and t.cur_zoom == zoom + _, query = t.get_task_pilot(pilot) + members.append(t) + queries.append(query) + queries = torch.stack(queries, axis=1) + bookkeeping[loads] = False + return members, img, queries, bookkeeping + + def form_grouped_batch(self, zoom, tasks): + counter = 0 + 
task_ref = [] + img_batch = [] + query_batch = [] + tasks_map, task_ids = self.get_tasks_map(zoom, tasks) + shuffle = np.random.permutation(tasks_map.shape[0]) + tasks_map = np.take(tasks_map, shuffle, axis=0) + task_ids = np.take(task_ids, shuffle, axis=0) + bookkeeping = np.ones_like(task_ids).astype(bool) + + for i, ti in enumerate(task_ids): + t = tasks[ti] + if t.status == 'unfinished' and t.submitted == False and t.cur_zoom == zoom: + members, img, queries, bookkeeping = self.form_squad(zoom, t, i, tasks, tasks_map, task_ids, bookkeeping) + task_ref.append(members) + img_batch.append(img) + query_batch.append(queries) + counter += 1 + if counter >= self.batch_size: + break + if len(task_ref) == 0: + return [], [], [] + + max_len = max([q.shape[1] for q in query_batch]) + for i in range(len(query_batch)): + q = query_batch[i] + query_batch[i] = torch.cat([q, torch.zeros([1, max_len - q.shape[1], 2])], axis=1) + img_batch = torch.stack(img_batch) + query_batch = torch.cat(query_batch) + return task_ref, img_batch, query_batch + + def cotr_corr_multiscale(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, force=False, return_tasks_only=False, areas=None): + ''' + currently only support fixed queries_a + ''' + img_a = img_a.copy() + img_b = img_b.copy() + img_a_shape = img_a.shape[:2] + img_b_shape = img_b.shape[:2] + if queries_a is not None: + queries_a = queries_a.copy() + tasks = self.gen_tasks(img_a, img_b, zoom_ins, converge_iters, max_corrs, queries_a, force, areas) + for zm in zoom_ins: + print(f'======= Zoom: {zm} ======') + while True: + num_g = self.num_good_tasks(tasks) + task_ref, img_batch, query_batch = self.form_grouped_batch(zm, tasks) + if len(task_ref) == 0: + break + if num_g >= max_corrs: + break + out = self.infer_batch_grouped(img_batch, query_batch) + num_steps = 0 + for i, temp in enumerate(task_ref): + for j, t in enumerate(temp): + t.step(out[i, j]) + num_steps += 1 + print(f'solved {num_steps} sub-tasks in one invocation with {img_batch.shape[0]} image pairs') + if num_steps <= self.batch_size: + break + # Rollback to default inference, because of too few valid tasks can be grouped together. + while True: + num_g = self.num_good_tasks(tasks) + print(f'{num_g} / {max_corrs} | {self.num_finished_tasks(tasks)} / {len(tasks)}') + task_ref, img_batch, query_batch = self.form_batch(tasks, zm) + if len(task_ref) == 0: + break + if num_g >= max_corrs: + break + out = self.infer_batch(img_batch, query_batch) + for t, o in zip(task_ref, out): + t.step(o) + + if return_tasks_only: + return tasks + if return_idx: + corrs, idx = self.conclude_tasks(tasks, return_idx=True, force=force, + img_a_shape=img_a_shape, + img_b_shape=img_b_shape,) + corrs = corrs[:max_corrs] + idx = idx[:max_corrs] + return corrs, idx + else: + corrs = self.conclude_tasks(tasks, force=force, + img_a_shape=img_a_shape, + img_b_shape=img_b_shape,) + corrs = corrs[:max_corrs] + return corrs diff --git a/imcui/third_party/COTR/COTR/models/__init__.py b/imcui/third_party/COTR/COTR/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..263568f9eb947e24801c8e27242c66a12b4a97ed --- /dev/null +++ b/imcui/third_party/COTR/COTR/models/__init__.py @@ -0,0 +1,10 @@ +''' +The COTR model is modified from DETR code base. +https://github.com/facebookresearch/detr +''' +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from .cotr_model import build + + +def build_model(args): + return build(args) diff --git a/imcui/third_party/COTR/COTR/models/backbone.py b/imcui/third_party/COTR/COTR/models/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5b5def2bfe403d3d12077b48ad76d8d97f2f84 --- /dev/null +++ b/imcui/third_party/COTR/COTR/models/backbone.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from .misc import NestedTensor + +from .position_encoding import build_position_encoding +from COTR.utils import debug_utils, constants + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool, layer='layer3'): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + # print(f'freeze {name}') + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {layer: "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward_raw(self, x): + y = self.body(x) + assert len(y.keys()) == 1 + return y['0'] + + def forward(self, tensor_list: NestedTensor): + assert tensor_list.tensors.shape[-2:] == (constants.MAX_SIZE, constants.MAX_SIZE * 2) + left = self.body(tensor_list.tensors[..., 0:constants.MAX_SIZE]) + right = self.body(tensor_list.tensors[..., constants.MAX_SIZE:2 * constants.MAX_SIZE]) + xs = {} + for k in left.keys(): + xs[k] = torch.cat([left[k], right[k]], dim=-1) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = 
NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool, + layer='layer3', + num_channels=1024): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=True, norm_layer=FrozenBatchNorm2d) + super().__init__(backbone, train_backbone, num_channels, return_interm_layers, layer) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + if hasattr(args, 'lr_backbone'): + train_backbone = args.lr_backbone > 0 + else: + train_backbone = False + backbone = Backbone(args.backbone, train_backbone, False, args.dilation, layer=args.layer, num_channels=args.dim_feedforward) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/imcui/third_party/COTR/COTR/models/cotr_model.py b/imcui/third_party/COTR/COTR/models/cotr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..72eced8ada79d04898acfeacd7cf61516b52f8a3 --- /dev/null +++ b/imcui/third_party/COTR/COTR/models/cotr_model.py @@ -0,0 +1,51 @@ +import math + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from COTR.utils import debug_utils, constants, utils +from .misc import (NestedTensor, nested_tensor_from_tensor_list) +from .backbone import build_backbone +from .transformer import build_transformer +from .position_encoding import NerfPositionalEncoding, MLP + + +class COTR(nn.Module): + + def __init__(self, backbone, transformer, sine_type='lin_sine'): + super().__init__() + self.transformer = transformer + hidden_dim = transformer.d_model + self.corr_embed = MLP(hidden_dim, hidden_dim, 2, 3) + self.query_proj = NerfPositionalEncoding(hidden_dim // 4, sine_type) + self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) + self.backbone = backbone + + def forward(self, samples: NestedTensor, queries): + if isinstance(samples, (list, torch.Tensor)): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + src, mask = features[-1].decompose() + assert mask is not None + _b, _q, _ = queries.shape + queries = queries.reshape(-1, 2) + queries = self.query_proj(queries).reshape(_b, _q, -1) + queries = queries.permute(1, 0, 2) + hs = self.transformer(self.input_proj(src), mask, queries, pos[-1])[0] + outputs_corr = self.corr_embed(hs) + out = {'pred_corrs': outputs_corr[-1]} + return out + + +def build(args): + backbone = build_backbone(args) + transformer = build_transformer(args) + model = COTR( + backbone, + transformer, + sine_type=args.position_embedding, + ) + return model diff --git a/imcui/third_party/COTR/COTR/models/misc.py b/imcui/third_party/COTR/COTR/models/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..f6093194fe1b68a75f819cd90a233d7cb00867a7 --- /dev/null +++ b/imcui/third_party/COTR/COTR/models/misc.py @@ -0,0 +1,112 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
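
`BackboneBase.forward` above expects the two images of a pair concatenated side by side along the width axis, runs the shared CNN on each half, and concatenates the resulting feature maps back along width. A toy illustration of that layout (the small conv stands in for the ResNet body, and `MAX_SIZE = 256` is assumed here only for the example):

```python
import torch

MAX_SIZE = 256
cnn = torch.nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)  # stand-in for the backbone

pair = torch.randn(1, 3, MAX_SIZE, MAX_SIZE * 2)   # [B, C, H, 2W] side-by-side pair
left = cnn(pair[..., :MAX_SIZE])                   # features of image A
right = cnn(pair[..., MAX_SIZE:])                  # features of image B
feat = torch.cat([left, right], dim=-1)            # re-concatenated feature map
print(feat.shape)                                  # torch.Size([1, 8, 128, 256])
```
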
All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if float(torchvision.__version__.split('.')[1]) < 7: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
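
`nested_tensor_from_tensor_list` above pads every image in a list to the largest (C, H, W) in the batch and records a boolean mask that is True over the padded region, which downstream attention layers use as a key padding mask. A standalone version of the same padding logic:

```python
import torch

def pad_to_batch(imgs):
    """Mirror of the padding logic: pad to the max shape, mask marks padding."""
    c, h, w = (max(s) for s in zip(*[img.shape for img in imgs]))
    tensor = torch.zeros(len(imgs), c, h, w, dtype=imgs[0].dtype)
    mask = torch.ones(len(imgs), h, w, dtype=torch.bool)    # True == padding
    for img, pad_img, m in zip(imgs, tensor, mask):
        pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
        m[:img.shape[1], :img.shape[2]] = False              # real pixels
    return tensor, mask

imgs = [torch.rand(3, 200, 300), torch.rand(3, 180, 320)]
tensor, mask = pad_to_batch(imgs)
print(tensor.shape, mask.shape)  # torch.Size([2, 3, 200, 320]) torch.Size([2, 200, 320])
```
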
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + diff --git a/imcui/third_party/COTR/COTR/models/position_encoding.py b/imcui/third_party/COTR/COTR/models/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..9207015d2671b87ec80c58b1f711e185af5db8de --- /dev/null +++ b/imcui/third_party/COTR/COTR/models/position_encoding.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn +import torch.nn.functional as F + +from .misc import NestedTensor +from COTR.utils import debug_utils + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class NerfPositionalEncoding(nn.Module): + def __init__(self, depth=10, sine_type='lin_sine'): + ''' + out_dim = in_dim * depth * 2 + ''' + super().__init__() + if sine_type == 'lin_sine': + self.bases = [i+1 for i in range(depth)] + elif sine_type == 'exp_sine': + self.bases = [2**i for i in range(depth)] + print(f'using {sine_type} as positional encoding') + + @torch.no_grad() + def forward(self, inputs): + out = torch.cat([torch.sin(i * math.pi * inputs) for i in self.bases] + [torch.cos(i * math.pi * inputs) for i in self.bases], axis=-1) + assert torch.isnan(out).any() == False + return out + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
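
`NerfPositionalEncoding` above maps each input coordinate through `sin(b * pi * x)` and `cos(b * pi * x)` over `depth` frequency bases, so the output size is `in_dim * depth * 2`; with the default `hidden_dim=256` the model uses `depth = hidden_dim // 4 = 64`, turning each 2-D query into a 256-D embedding. A standalone sketch:

```python
import math
import torch

def nerf_encode(x, depth=4, sine_type='lin_sine'):
    # Linear bases 1..depth, or exponential bases 2**0..2**(depth-1).
    bases = [i + 1 for i in range(depth)] if sine_type == 'lin_sine' \
        else [2 ** i for i in range(depth)]
    sins = [torch.sin(b * math.pi * x) for b in bases]
    coss = [torch.cos(b * math.pi * x) for b in bases]
    return torch.cat(sins + coss, dim=-1)

q = torch.rand(5, 2)                    # 5 normalized (x, y) queries
print(nerf_encode(q, depth=4).shape)    # torch.Size([5, 16]) == 2 * 4 * 2
```
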
+ """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None, sine_type='lin_sine'): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.sine = NerfPositionalEncoding(num_pos_feats//2, sine_type) + + @torch.no_grad() + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + eps = 1e-6 + y_embed = (y_embed-0.5) / (y_embed[:, -1:, :] + eps) + x_embed = (x_embed-0.5) / (x_embed[:, :, -1:] + eps) + pos = torch.stack([x_embed, y_embed], dim=-1) + return self.sine(pos).permute(0, 3, 1, 2) + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('lin_sine', 'exp_sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True, sine_type=args.position_embedding) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/imcui/third_party/COTR/COTR/models/transformer.py b/imcui/third_party/COTR/COTR/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..69ba7e318dd4c0ba80042742e19360ec6dfad683 --- /dev/null +++ b/imcui/third_party/COTR/COTR/models/transformer.py @@ -0,0 +1,228 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +COTR/DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from COTR.utils import debug_utils + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation) + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + 
self.num_layers = num_layers + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(query=q, + key=k, + value=src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu"): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else 
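
The encoder layer above follows the DETR convention of adding the positional embedding to the attention queries and keys but not to the values (`with_pos_embed`), so position steers the attention pattern while only content features get aggregated. A standalone illustration with `nn.MultiheadAttention` in the same sequence-first layout:

```python
import torch
from torch import nn

d_model, nhead, hw, bs = 256, 8, 50, 2
attn = nn.MultiheadAttention(d_model, nhead, dropout=0.1)

src = torch.randn(hw, bs, d_model)   # flattened H*W feature tokens
pos = torch.randn(hw, bs, d_model)   # positional embedding, same shape

q = k = src + pos                    # with_pos_embed(src, pos)
out, _ = attn(query=q, key=k, value=src)   # values carry no positional term
print(out.shape)                     # torch.Size([50, 2, 256])
```
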
tensor + pos + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/imcui/third_party/COTR/COTR/options/options.py b/imcui/third_party/COTR/COTR/options/options.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b7d8b739b4de554d7c3c0bb3679a5f10434659 --- /dev/null +++ b/imcui/third_party/COTR/COTR/options/options.py @@ -0,0 +1,52 @@ +import sys +import argparse +import json +import os + + +from COTR.options.options_utils import str2bool +from COTR.options import options_utils +from COTR.global_configs import general_config, dataset_config +from COTR.utils import debug_utils + + +def set_general_arguments(parser): + general_arg = parser.add_argument_group('General') + general_arg.add_argument('--confirm', type=str2bool, + default=True, help='promote confirmation for user') + general_arg.add_argument('--use_cuda', type=str2bool, + default=True, help='use cuda') + general_arg.add_argument('--use_cc', type=str2bool, + default=False, help='use computecanada') + + +def set_dataset_arguments(parser): + data_arg = parser.add_argument_group('Data') + data_arg.add_argument('--dataset_name', type=str, default='megadepth', help='dataset name') + data_arg.add_argument('--shuffle_data', type=str2bool, default=True, help='use sequence dataset or shuffled dataset') + data_arg.add_argument('--use_ram', type=str2bool, default=False, help='load image/depth/pcd to ram') + data_arg.add_argument('--info_level', choices=['rgb', 'rgbd'], type=str, default='rgbd', help='the information level of dataset') + data_arg.add_argument('--scene_file', type=str, default=None, required=False, help='what scene/seq want to use') + data_arg.add_argument('--workers', type=int, default=0, help='worker for loading data') + data_arg.add_argument('--crop_cam', choices=['no_crop', 'crop_center', 'crop_center_and_resize'], type=str, default='crop_center_and_resize', help='crop the center of image to avoid changing aspect ratio, resize to make the operations batch-able.') + + +def set_nn_arguments(parser): + nn_arg = parser.add_argument_group('Nearest neighbors') + nn_arg.add_argument('--nn_method', choices=['netvlad', 'overlapping'], type=str, default='overlapping', help='how to select nearest neighbors') + 
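
`build_transformer` only reads a handful of hyper-parameters off the options namespace. A minimal sketch of calling it directly, assuming the repository root is on `PYTHONPATH` so the `COTR` package and its dependencies import cleanly; the `dim_feedforward` value is an assumption, since its default is defined outside this excerpt:

```python
from argparse import Namespace

from COTR.models.transformer import build_transformer

args = Namespace(
    hidden_dim=256,        # d_model
    dropout=0.1,
    nheads=8,
    dim_feedforward=1024,  # assumed default; also reused as the backbone channel count
    enc_layers=6,
    dec_layers=6,
)
transformer = build_transformer(args)
print(transformer.d_model, transformer.nhead)  # 256 8
```
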
nn_arg.add_argument('--pool_size', type=int, default=20, help='a pool of sorted nn candidates') + nn_arg.add_argument('--k_size', type=int, default=1, help='select the nn randomly from pool') + + +def set_COTR_arguments(parser): + cotr_arg = parser.add_argument_group('COTR model') + cotr_arg.add_argument('--backbone', type=str, default='resnet50') + cotr_arg.add_argument('--hidden_dim', type=int, default=256) + cotr_arg.add_argument('--dilation', type=str2bool, default=False) + cotr_arg.add_argument('--dropout', type=float, default=0.1) + cotr_arg.add_argument('--nheads', type=int, default=8) + cotr_arg.add_argument('--layer', type=str, default='layer3', help='which layer from resnet') + cotr_arg.add_argument('--enc_layers', type=int, default=6) + cotr_arg.add_argument('--dec_layers', type=int, default=6) + cotr_arg.add_argument('--position_embedding', type=str, default='lin_sine', help='sine wave type') + diff --git a/imcui/third_party/COTR/COTR/options/options_utils.py b/imcui/third_party/COTR/COTR/options/options_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..12fc1f1ae98ec89ad249fd99c7f9415e350f93de --- /dev/null +++ b/imcui/third_party/COTR/COTR/options/options_utils.py @@ -0,0 +1,108 @@ +'''utils for argparse +''' + +import sys +import os +from os import path +import time +import json + +from COTR.utils import utils, debug_utils +from COTR.global_configs import general_config, dataset_config + + +def str2bool(v: str) -> bool: + return v.lower() in ('true', '1', 'yes', 'y', 't') + + +def get_compact_naming_cotr(opt) -> str: + base_str = 'model:cotr_{0}_{1}_{2}_dset:{3}_bs:{4}_pe:{5}_lrbackbone:{6}' + result = base_str.format(opt.backbone, + opt.layer, + opt.dim_feedforward, + opt.dataset_name, + opt.batch_size, + opt.position_embedding, + opt.lr_backbone, + ) + if opt.suffix: + result = result + '_suffix:{0}'.format(opt.suffix) + return result + + +def print_opt(opt): + content_list = [] + args = list(vars(opt)) + args.sort() + for arg in args: + content_list += [arg.rjust(25, ' ') + ' ' + str(getattr(opt, arg))] + utils.print_notification(content_list, 'OPTIONS') + + +def confirm_opt(opt): + print_opt(opt) + if opt.use_cc == False: + if not utils.confirm(): + exit(1) + + +def opt_to_string(opt) -> str: + string = '\n\n' + string += 'python ' + ' '.join(sys.argv) + string += '\n\n' + # string += '---------------------- CONFIG ----------------------\n' + args = list(vars(opt)) + args.sort() + for arg in args: + string += arg.rjust(25, ' ') + ' ' + str(getattr(opt, arg)) + '\n\n' + # string += '----------------------------------------------------\n' + return string + + +def save_opt(opt): + '''save options to a json file + ''' + if not os.path.exists(opt.out): + os.makedirs(opt.out) + json_path = os.path.join(opt.out, 'params.json') + if 'debug' not in opt.suffix and path.isfile(json_path): + assert opt.resume, 'You are trying to modify a model without resuming: {0}'.format(opt.out) + old_dict = json.load(open(json_path)) + new_dict = vars(opt) + # assert old_dict.keys() == new_dict.keys(), 'New configuration keys is different from old one.\nold: {0}\nnew: {1}'.format(old_dict.keys(), new_dict.keys()) + if new_dict != old_dict: + exception_keys = ['command'] + for key in set(old_dict.keys()).union(set(new_dict.keys())): + if key not in exception_keys: + old_val = old_dict[key] if key in old_dict else 'not exists(old)' + new_val = new_dict[key] if key in old_dict else 'not exists(new)' + if old_val != new_val: + print('key: {0}, old_val: {1}, new_val: 
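
The `str2bool` helper above exists because argparse's plain `type=bool` would convert any non-empty string, including `'False'`, to `True`. A standalone illustration of how the options modules use it:

```python
import argparse

def str2bool(v: str) -> bool:
    # Same rule as options_utils: these spellings parse to True, everything else to False.
    return v.lower() in ('true', '1', 'yes', 'y', 't')

parser = argparse.ArgumentParser()
parser.add_argument('--use_cuda', type=str2bool, default=True)
print(parser.parse_args(['--use_cuda', 'False']).use_cuda)  # False
print(parser.parse_args(['--use_cuda', 'yes']).use_cuda)    # True
```
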
{2}'.format(key, old_val, new_val)) + if opt.use_cc == False: + if not utils.confirm('Please manually confirm'): + exit(1) + with open(json_path, 'w') as fp: + json.dump(vars(opt), fp, indent=0, sort_keys=True) + + +def build_scenes_name_list_from_opt(opt): + if hasattr(opt, 'scene_file') and opt.scene_file is not None: + assert os.path.isfile(opt.scene_file), opt.scene_file + with open(opt.scene_file, 'r') as f: + scenes_list = json.load(f) + else: + scenes_list = [{'scene': opt.scene, 'seq': opt.seq}] + if 'megadepth' in opt.dataset_name: + assert opt.info_level in ['rgb', 'rgbd'] + scenes_name_list = [] + if opt.info_level == 'rgb': + dir_list = ['scene_dir', 'image_dir'] + elif opt.info_level == 'rgbd': + dir_list = ['scene_dir', 'image_dir', 'depth_dir'] + dir_list = {dir_name: dataset_config[opt.dataset_name][dir_name] for dir_name in dir_list} + for item in scenes_list: + cur_scene = {key: val.format(item['scene'], item['seq']) for key, val in dir_list.items()} + scenes_name_list.append(cur_scene) + else: + raise NotImplementedError() + return scenes_name_list diff --git a/imcui/third_party/COTR/COTR/projector/pcd_projector.py b/imcui/third_party/COTR/COTR/projector/pcd_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..4847bd50920c84066fa7b36ea1c434521d7f12d9 --- /dev/null +++ b/imcui/third_party/COTR/COTR/projector/pcd_projector.py @@ -0,0 +1,210 @@ +''' +a point cloud projector based on np +''' + +import numpy as np + +from COTR.utils import debug_utils, utils + + +def render_point_cloud_at_capture(point_cloud, capture, render_type='rgb', return_pcd=False): + assert render_type in ['rgb', 'bw', 'depth'] + if render_type == 'rgb': + assert point_cloud.shape[1] == 6 + else: + point_cloud = point_cloud[:, :3] + assert point_cloud.shape[1] == 3 + if render_type in ['bw', 'rgb']: + keep_z = False + else: + keep_z = True + + pcd_2d = PointCloudProjector.pcd_3d_to_pcd_2d_np(point_cloud, + capture.intrinsic_mat, + capture.extrinsic_mat, + capture.size, + keep_z=True, + crop=True, + filter_neg=True, + norm_coord=False, + return_index=False) + reproj = PointCloudProjector.pcd_2d_to_img_2d_np(pcd_2d, + capture.size, + has_z=True, + keep_z=keep_z) + if return_pcd: + return reproj, pcd_2d + else: + return reproj + + +def optical_flow_from_a_to_b(cap_a, cap_b): + cap_a_intrinsic = cap_a.pinhole_cam.intrinsic_mat + cap_a_img_size = cap_a.pinhole_cam.shape[:2] + _h, _w = cap_b.pinhole_cam.shape[:2] + x, y = np.meshgrid( + np.linspace(0, _w - 1, num=_w), + np.linspace(0, _h - 1, num=_h), + ) + coord_map = np.concatenate([np.expand_dims(x, 2), np.expand_dims(y, 2)], axis=2) + pcd_from_cap_b = cap_b.get_point_cloud_world_from_depth(coord_map) + # pcd_from_cap_b = cap_b.point_cloud_world_w_feat(['pos', 'coord']) + optical_flow = PointCloudProjector.pcd_2d_to_img_2d_np(PointCloudProjector.pcd_3d_to_pcd_2d_np(pcd_from_cap_b, cap_a_intrinsic, cap_a.cam_pose.world_to_camera[0:3, :], cap_a_img_size, keep_z=True, crop=True, filter_neg=True, norm_coord=False), cap_a_img_size, has_z=True, keep_z=False) + return optical_flow + + +class PointCloudProjector(): + def __init__(self): + pass + + @staticmethod + def pcd_2d_to_pcd_3d_np(pcd, depth, intrinsic, motion=None, return_index=False): + assert isinstance(pcd, np.ndarray), 'cannot process data type: {0}'.format(type(pcd)) + assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic)) + assert len(pcd.shape) == 2 and pcd.shape[1] >= 2 + assert len(depth.shape) == 2 and depth.shape[1] == 1 + 
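
`build_scenes_name_list_from_opt` above expects `--scene_file` to point at a JSON list of `{"scene": ..., "seq": ...}` records, each of which is substituted into the directory templates from `dataset_config`. A toy illustration (the template string here is hypothetical; the real ones live in `COTR.global_configs`):

```python
import json

scene_file = '[{"scene": "0001", "seq": "dense0"}, {"scene": "0002", "seq": "dense0"}]'
scenes_list = json.loads(scene_file)

image_dir_template = 'megadepth/{0}/{1}/imgs'      # hypothetical template for illustration
for item in scenes_list:
    print(image_dir_template.format(item['scene'], item['seq']))
# megadepth/0001/dense0/imgs
# megadepth/0002/dense0/imgs
```
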
assert intrinsic.shape == (3, 3) + if motion is not None: + assert isinstance(motion, np.ndarray), 'cannot process data type: {0}'.format(type(motion)) + assert motion.shape == (4, 4) + # exec(debug_utils.embed_breakpoint()) + x, y, z = pcd[:, 0], pcd[:, 1], depth[:, 0] + append_ones = np.ones_like(x) + xyz = np.stack([x, y, append_ones], axis=1) # shape: [num_points, 3] + inv_intrinsic_mat = np.linalg.inv(intrinsic) + xyz = np.matmul(inv_intrinsic_mat, xyz.T).T * z[..., None] + valid_mask_1 = np.where(xyz[:, 2] > 0) + xyz = xyz[valid_mask_1] + + if motion is not None: + append_ones = np.ones_like(xyz[:, 0:1]) + xyzw = np.concatenate([xyz, append_ones], axis=1) + xyzw = np.matmul(motion, xyzw.T).T + valid_mask_2 = np.where(xyzw[:, 3] != 0) + xyzw = xyzw[valid_mask_2] + xyzw /= xyzw[:, 3:4] + xyz = xyzw[:, 0:3] + + if pcd.shape[1] > 2: + features = pcd[:, 2:] + try: + features = features[valid_mask_1][valid_mask_2] + except UnboundLocalError: + features = features[valid_mask_1] + assert xyz.shape[0] == features.shape[0] + xyz = np.concatenate([xyz, features], axis=1) + if return_index: + points_index = np.arange(pcd.shape[0])[valid_mask_1][valid_mask_2] + return xyz, points_index + return xyz + + @staticmethod + def img_2d_to_pcd_3d_np(depth, intrinsic, img=None, motion=None): + ''' + the function signature is not fully correct, because img is an optional + if motion is None, the output pcd is in camera space + if motion is camera_to_world, the out pcd is in world space. + here the output is pure np array + ''' + + assert isinstance(depth, np.ndarray), 'cannot process data type: {0}'.format(type(depth)) + assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic)) + assert len(depth.shape) == 2 + assert intrinsic.shape == (3, 3) + if img is not None: + assert isinstance(img, np.ndarray), 'cannot process data type: {0}'.format(type(img)) + assert len(img.shape) == 3 + assert img.shape[:2] == depth.shape[:2], 'feature should have the same resolution as the depth' + if motion is not None: + assert isinstance(motion, np.ndarray), 'cannot process data type: {0}'.format(type(motion)) + assert motion.shape == (4, 4) + + pcd_image_space = PointCloudProjector.img_2d_to_pcd_2d_np(depth[..., None], norm_coord=False) + valid_mask_1 = np.where(pcd_image_space[:, 2] > 0) + pcd_image_space = pcd_image_space[valid_mask_1] + xy = pcd_image_space[:, :2] + z = pcd_image_space[:, 2:3] + if img is not None: + _c = img.shape[-1] + feat = img.reshape(-1, _c) + feat = feat[valid_mask_1] + xy = np.concatenate([xy, feat], axis=1) + pcd_3d = PointCloudProjector.pcd_2d_to_pcd_3d_np(xy, z, intrinsic, motion=motion) + return pcd_3d + + @staticmethod + def pcd_3d_to_pcd_2d_np(pcd, intrinsic, extrinsic, size, keep_z: bool, crop: bool = True, filter_neg: bool = True, norm_coord: bool = True, return_index: bool = False): + assert isinstance(pcd, np.ndarray), 'cannot process data type: {0}'.format(type(pcd)) + assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic)) + assert isinstance(extrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(extrinsic)) + assert len(pcd.shape) == 2 and pcd.shape[1] >= 3, 'seems the input pcd is not a valid 3d point cloud: {0}'.format(pcd.shape) + + xyzw = np.concatenate([pcd[:, 0:3], np.ones_like(pcd[:, 0:1])], axis=1) + mvp_mat = np.matmul(intrinsic, extrinsic) + camera_points = np.matmul(mvp_mat, xyzw.T).T + if filter_neg: + valid_mask_1 = camera_points[:, 2] > 0.0 + else: + valid_mask_1 = 
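
The two core projector routines here are inverses of each other: `pcd_2d_to_pcd_3d_np` lifts pixels with depth into camera space through the inverse intrinsic matrix, and `pcd_3d_to_pcd_2d_np` applies `K [R|t]`, drops points with non-positive depth, and divides by z. A standalone round-trip check with a toy pinhole camera and identity extrinsics:

```python
import numpy as np

K = np.array([[500., 0., 320.],
              [0., 500., 240.],
              [0., 0., 1.]])
uv = np.array([[320., 240.],
               [420., 260.]])
depth = np.array([[2.0], [3.0]])

# Back-project: X = z * K^-1 [u, v, 1]^T  (pcd_2d_to_pcd_3d_np with motion=None)
ones = np.ones((len(uv), 1))
xyz = (np.linalg.inv(K) @ np.concatenate([uv, ones], axis=1).T).T * depth

# Re-project: x = K [I | 0] X, then divide by z  (pcd_3d_to_pcd_2d_np)
extrinsic = np.eye(4)[:3, :]
xyzw = np.concatenate([xyz, np.ones((len(xyz), 1))], axis=1)
cam = (K @ extrinsic @ xyzw.T).T
uv_back = cam[:, :2] / cam[:, 2:3]

print(np.allclose(uv_back, uv))  # True
```
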
np.ones_like(camera_points[:, 2], dtype=bool) + camera_points = camera_points[valid_mask_1] + image_points = camera_points / camera_points[:, 2:3] + image_points = image_points[:, :2] + if crop: + valid_mask_2 = (image_points[:, 0] >= 0) * (image_points[:, 0] < size[1] - 1) * (image_points[:, 1] >= 0) * (image_points[:, 1] < size[0] - 1) + else: + valid_mask_2 = np.ones_like(image_points[:, 0], dtype=bool) + if norm_coord: + image_points = ((image_points / size[::-1]) * 2) - 1 + + if keep_z: + image_points = np.concatenate([image_points[valid_mask_2], camera_points[valid_mask_2][:, 2:3], pcd[valid_mask_1][:, 3:][valid_mask_2]], axis=1) + else: + image_points = np.concatenate([image_points[valid_mask_2], pcd[valid_mask_1][:, 3:][valid_mask_2]], axis=1) + # if filter_neg and crop: + # exec(debug_utils.embed_breakpoint('pcd_3d_to_pcd_2d_np')) + if return_index: + points_index = np.arange(pcd.shape[0])[valid_mask_1][valid_mask_2] + return image_points, points_index + return image_points + + @staticmethod + def pcd_2d_to_img_2d_np(pcd, size, has_z=False, keep_z=False): + assert len(pcd.shape) == 2 and pcd.shape[-1] >= 2, 'seems the input pcd is not a valid point cloud: {0}'.format(pcd.shape) + # assert 0, 'pass Z values in' + if has_z: + pcd = pcd[pcd[:, 2].argsort()[::-1]] + if not keep_z: + pcd = np.delete(pcd, [2], axis=1) + index_list = np.round(pcd[:, 0:2]).astype(np.int32) + index_list[:, 0] = np.clip(index_list[:, 0], 0, size[1] - 1) + index_list[:, 1] = np.clip(index_list[:, 1], 0, size[0] - 1) + _h, _w, _c = *size, pcd.shape[-1] - 2 + if _c == 0: + canvas = np.zeros((_h, _w, 1)) + canvas[index_list[:, 1], index_list[:, 0]] = 1.0 + else: + canvas = np.zeros((_h, _w, _c)) + canvas[index_list[:, 1], index_list[:, 0]] = pcd[:, 2:] + + return canvas + + @staticmethod + def img_2d_to_pcd_2d_np(img, norm_coord=True): + assert isinstance(img, np.ndarray), 'cannot process data type: {0}'.format(type(img)) + assert len(img.shape) == 3 + + _h, _w, _c = img.shape + if norm_coord: + x, y = np.meshgrid( + np.linspace(-1, 1, num=_w), + np.linspace(-1, 1, num=_h), + ) + else: + x, y = np.meshgrid( + np.linspace(0, _w - 1, num=_w), + np.linspace(0, _h - 1, num=_h), + ) + x, y = x.reshape(-1, 1), y.reshape(-1, 1) + feat = img.reshape(-1, _c) + pcd_2d = np.concatenate([x, y, feat], axis=1) + return pcd_2d diff --git a/imcui/third_party/COTR/COTR/sfm_scenes/knn_search.py b/imcui/third_party/COTR/COTR/sfm_scenes/knn_search.py new file mode 100644 index 0000000000000000000000000000000000000000..af2bda01cd48571c8641569f2f7d243288de4899 --- /dev/null +++ b/imcui/third_party/COTR/COTR/sfm_scenes/knn_search.py @@ -0,0 +1,56 @@ +''' +Given one capture in a scene, search for its KNN captures +''' + +import os + +import numpy as np + +from COTR.utils import debug_utils +from COTR.utils.constants import VALID_NN_OVERLAPPING_THRESH + + +class ReprojRatioKnnSearch(): + def __init__(self, scene): + self.scene = scene + self.distance_mat = None + self.nn_index = None + self._read_dist_mat() + self._build_nn_index() + + def _read_dist_mat(self): + dist_mat_path = os.path.join(os.path.dirname(os.path.dirname(self.scene.captures[0].depth_path)), 'dist_mat/dist_mat.npy') + self.distance_mat = np.load(dist_mat_path) + + def _build_nn_index(self): + # argsort is in ascending order, so we take negative + self.nn_index = (-1 * self.distance_mat).argsort(axis=1) + + def get_knn(self, query, k, db_mask=None): + query_index = self.scene.img_path_to_index_dict[query.img_path] + if db_mask is not None: + query_mask = 
np.setdiff1d(np.arange(self.distance_mat[query_index].shape[0]), db_mask) + num_pos = (self.distance_mat[query_index] > VALID_NN_OVERLAPPING_THRESH).sum() if db_mask is None else (self.distance_mat[query_index][db_mask] > VALID_NN_OVERLAPPING_THRESH).sum() + # we have enough valid NN or not + if num_pos > k: + if db_mask is None: + ind = self.nn_index[query_index][:k + 1] + else: + temp_dist = self.distance_mat[query_index].copy() + temp_dist[query_mask] = -1 + ind = (-1 * temp_dist).argsort(axis=0)[:k + 1] + # remove self + if query_index in ind: + ind = np.delete(ind, np.argwhere(ind == query_index)) + else: + ind = ind[:k] + assert ind.shape[0] <= k, ind.shape[0] > 0 + else: + k = num_pos + if db_mask is None: + ind = self.nn_index[query_index][:max(k, 1)] + else: + temp_dist = self.distance_mat[query_index].copy() + temp_dist[query_mask] = -1 + ind = (-1 * temp_dist).argsort(axis=0)[:max(k, 1)] + return self.scene.get_captures_given_index_list(ind) diff --git a/imcui/third_party/COTR/COTR/sfm_scenes/sfm_scenes.py b/imcui/third_party/COTR/COTR/sfm_scenes/sfm_scenes.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3cf0dece9e613073e445eae4e94a14e2145087 --- /dev/null +++ b/imcui/third_party/COTR/COTR/sfm_scenes/sfm_scenes.py @@ -0,0 +1,87 @@ +''' +Scene reconstructed from SFM, mainly colmap +''' +import os +import copy +import math + +import numpy as np +from numpy.linalg import inv +from tqdm import tqdm + +from COTR.transformations import transformations +from COTR.transformations.transform_basics import Translation, Rotation +from COTR.cameras.camera_pose import CameraPose +from COTR.utils import debug_utils + + +class SfmScene(): + def __init__(self, captures, point_cloud=None): + self.captures = captures + if isinstance(point_cloud, tuple): + self.point_cloud = point_cloud[0] + self.point_meta = point_cloud[1] + else: + self.point_cloud = point_cloud + self.img_path_to_index_dict = {} + self.img_id_to_index_dict = {} + self.fname_to_index_dict = {} + self._build_img_X_to_index_dict() + + def __str__(self): + string = 'Scene contains {0} captures'.format(len(self.captures)) + return string + + def __getitem__(self, x): + if isinstance(x, str): + try: + return self.captures[self.img_path_to_index_dict[x]] + except: + return self.captures[self.fname_to_index_dict[x]] + else: + return self.captures[x] + + def _build_img_X_to_index_dict(self): + assert self.captures is not None, 'There is no captures' + for i, cap in enumerate(self.captures): + assert cap.img_path not in self.img_path_to_index_dict, 'Image already exists' + self.img_path_to_index_dict[cap.img_path] = i + assert os.path.basename(cap.img_path) not in self.fname_to_index_dict, 'Image already exists' + self.fname_to_index_dict[os.path.basename(cap.img_path)] = i + if hasattr(cap, 'image_id'): + self.img_id_to_index_dict[cap.image_id] = i + + def get_captures_given_index_list(self, index_list): + captures_list = [] + for i in index_list: + captures_list.append(self.captures[i]) + return captures_list + + def get_covisible_caps(self, cap): + assert cap.img_path in self.img_path_to_index_dict + covis_img_id = set() + point_ids = cap.point3d_id + for i in point_ids: + covis_img_id = covis_img_id.union(set(self.point_meta[i].image_ids)) + covis_caps = [] + for i in covis_img_id: + if i in self.img_id_to_index_dict: + covis_caps.append(self.captures[self.img_id_to_index_dict[i]]) + else: + pass + return covis_caps + + def read_data_to_ram(self, data_list): + print('warning: you are going to use a lot of 
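
`ReprojRatioKnnSearch` treats its "distance" matrix as an overlap score (higher is better), so the neighbour index is simply a row-wise argsort of the negated matrix; `get_knn` then drops the query itself from the returned indices. The sorting trick in isolation:

```python
import numpy as np

overlap = np.array([[1.0, 0.7, 0.1, 0.4],
                    [0.7, 1.0, 0.5, 0.2]])
nn_index = (-overlap).argsort(axis=1)   # best match first (self included)
print(nn_index[0])                      # [0 1 3 2] -> self, then 0.7, 0.4, 0.1
```
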
RAM.') + sum_bytes = 0.0 + pbar = tqdm(self.captures, desc='reading data, memory usage {0:.2f} MB'.format(sum_bytes / (1024.0 * 1024.0))) + for cap in pbar: + if 'image' in data_list: + sum_bytes += cap.read_image_to_ram() + if 'depth' in data_list: + sum_bytes += cap.read_depth_to_ram() + if 'pcd' in data_list: + sum_bytes += cap.read_pcd_to_ram() + pbar.set_description('reading data, memory usage {0:.2f} MB'.format(sum_bytes / (1024.0 * 1024.0))) + print('----- total memory usage for images: {0} MB-----'.format(sum_bytes / (1024.0 * 1024.0))) + diff --git a/imcui/third_party/COTR/COTR/trainers/base_trainer.py b/imcui/third_party/COTR/COTR/trainers/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..83afea69bb3ec76f1344e62f121efcb8b4c00a8b --- /dev/null +++ b/imcui/third_party/COTR/COTR/trainers/base_trainer.py @@ -0,0 +1,111 @@ +import os +import math +import abc +import time + +import tqdm +import torch.nn as nn +import tensorboardX + +from COTR.trainers import tensorboard_helper +from COTR.utils import utils +from COTR.options import options_utils + + +class BaseTrainer(abc.ABC): + '''base trainer class. + contains methods for training, validation, and writing output. + ''' + + def __init__(self, opt, model, optimizer, criterion, + train_loader, val_loader): + self.opt = opt + self.use_cuda = opt.use_cuda + self.model = model + self.optim = optimizer + self.criterion = criterion + self.train_loader = train_loader + self.val_loader = val_loader + self.out = opt.out + if not os.path.exists(opt.out): + os.makedirs(opt.out) + self.epoch = 0 + self.iteration = 0 + self.max_iter = opt.max_iter + self.valid_iter = opt.valid_iter + self.tb_pusher = tensorboard_helper.TensorboardPusher(opt) + self.push_opt_to_tb() + self.need_resume = opt.resume + if self.need_resume: + self.resume() + if self.opt.load_weights: + self.load_pretrained_weights() + + def push_opt_to_tb(self): + opt_str = options_utils.opt_to_string(self.opt) + tb_datapack = tensorboard_helper.TensorboardDatapack() + tb_datapack.set_training(False) + tb_datapack.set_iteration(self.iteration) + tb_datapack.add_text({'options': opt_str}) + self.tb_pusher.push_to_tensorboard(tb_datapack) + + @abc.abstractmethod + def validate_batch(self, data_pack): + pass + + @abc.abstractmethod + def validate(self): + pass + + @abc.abstractmethod + def train_batch(self, data_pack): + '''train for one batch of data + ''' + pass + + def train_epoch(self): + '''train for one epoch + one epoch is iterating the whole training dataset once + ''' + self.model.train() + for batch_idx, data_pack in tqdm.tqdm(enumerate(self.train_loader), + initial=self.iteration % len( + self.train_loader), + total=len(self.train_loader), + desc='Train epoch={0}'.format( + self.epoch), + ncols=80, + leave=True, + ): + + # iteration = batch_idx + self.epoch * len(self.train_loader) + # if self.iteration != 0 and (iteration - 1) != self.iteration: + # continue # for resuming + # self.iteration = iteration + # self.iteration += 1 + if self.iteration % self.valid_iter == 0: + time.sleep(2) # Prevent possible deadlock during epoch transition + self.validate() + self.train_batch(data_pack) + + if self.iteration >= self.max_iter: + break + self.iteration += 1 + + def train(self): + '''entrance of the whole training process + ''' + max_epoch = int(math.ceil(1. 
* self.max_iter / len(self.train_loader))) + for epoch in tqdm.trange(self.epoch, + max_epoch, + desc='Train', + ncols=80): + self.epoch = epoch + time.sleep(2) # Prevent possible deadlock during epoch transition + self.train_epoch() + if self.iteration >= self.max_iter: + break + + @abc.abstractmethod + def resume(self): + pass diff --git a/imcui/third_party/COTR/COTR/trainers/cotr_trainer.py b/imcui/third_party/COTR/COTR/trainers/cotr_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..cf87677810b875ef5d29984857e11fc8c3cd7b17 --- /dev/null +++ b/imcui/third_party/COTR/COTR/trainers/cotr_trainer.py @@ -0,0 +1,200 @@ +import os +import math +import os.path as osp +import time + +import tqdm +import torch +import numpy as np +import torchvision.utils as vutils +from PIL import Image, ImageDraw + + +from COTR.utils import utils, debug_utils, constants +from COTR.trainers import base_trainer, tensorboard_helper +from COTR.projector import pcd_projector + + +class COTRTrainer(base_trainer.BaseTrainer): + def __init__(self, opt, model, optimizer, criterion, + train_loader, val_loader): + super().__init__(opt, model, optimizer, criterion, + train_loader, val_loader) + + def validate_batch(self, data_pack): + assert self.model.training is False + with torch.no_grad(): + img = data_pack['image'].cuda() + query = data_pack['queries'].cuda() + target = data_pack['targets'].cuda() + self.optim.zero_grad() + pred = self.model(img, query)['pred_corrs'] + loss = torch.nn.functional.mse_loss(pred, target) + if self.opt.cycle_consis and self.opt.bidirectional: + cycle = self.model(img, pred)['pred_corrs'] + mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE + if mask.sum() > 0: + cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask]) + loss += cycle_loss + elif self.opt.cycle_consis and not self.opt.bidirectional: + img_reverse = torch.cat([img[..., constants.MAX_SIZE:], img[..., :constants.MAX_SIZE]], axis=-1) + query_reverse = pred.clone() + query_reverse[..., 0] = query_reverse[..., 0] - 0.5 + cycle = self.model(img_reverse, query_reverse)['pred_corrs'] + cycle[..., 0] = cycle[..., 0] - 0.5 + mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE + if mask.sum() > 0: + cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask]) + loss += cycle_loss + loss_data = loss.data.item() + if np.isnan(loss_data): + print('loss is nan while validating') + return loss_data, pred + + def validate(self): + '''validate for whole validation dataset + ''' + training = self.model.training + self.model.eval() + val_loss_list = [] + for batch_idx, data_pack in tqdm.tqdm( + enumerate(self.val_loader), total=len(self.val_loader), + desc='Valid iteration=%d' % self.iteration, ncols=80, + leave=False): + loss_data, pred = self.validate_batch(data_pack) + val_loss_list.append(loss_data) + mean_loss = np.array(val_loss_list).mean() + validation_data = {'val_loss': mean_loss, + 'pred': pred, + } + self.push_validation_data(data_pack, validation_data) + self.save_model() + if training: + self.model.train() + + def save_model(self): + torch.save({ + 'epoch': self.epoch, + 'iteration': self.iteration, + 'optim_state_dict': self.optim.state_dict(), + 'model_state_dict': self.model.state_dict(), + }, osp.join(self.out, 'checkpoint.pth.tar')) + if self.iteration % (10 * self.valid_iter) == 0: + torch.save({ + 'epoch': self.epoch, + 'iteration': self.iteration, + 'optim_state_dict': self.optim.state_dict(), + 'model_state_dict': self.model.state_dict(), + }, 
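
The optional cycle-consistency term in `validate_batch` and `train_batch` feeds the predicted correspondences back through the model and adds a second MSE loss for queries whose round trip returns within `10 / MAX_SIZE` of where it started. A standalone sketch of the bidirectional variant, with a stand-in model so the snippet runs on its own:

```python
import torch
import torch.nn.functional as F

MAX_SIZE = 256  # assumed value, used only for the threshold here

def model(img, queries):                     # stand-in for COTR(img, queries)
    return {'pred_corrs': queries + 0.01 * torch.randn_like(queries)}

img = torch.zeros(2, 3, MAX_SIZE, MAX_SIZE * 2)
query = torch.rand(2, 100, 2)
target = torch.rand(2, 100, 2)

pred = model(img, query)['pred_corrs']
loss = F.mse_loss(pred, target)

cycle = model(img, pred)['pred_corrs']       # map predictions back to the queries
mask = torch.norm(cycle - query, dim=-1) < 10 / MAX_SIZE
if mask.sum() > 0:                           # only well-behaved round trips contribute
    loss = loss + F.mse_loss(cycle[mask], query[mask])
print(float(loss))
```
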
osp.join(self.out, f'{self.iteration}_checkpoint.pth.tar')) + + def draw_corrs(self, imgs, corrs, col=(255, 0, 0)): + imgs = utils.torch_img_to_np_img(imgs) + out = [] + for img, corr in zip(imgs, corrs): + img = np.interp(img, [img.min(), img.max()], [0, 255]).astype(np.uint8) + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + corr *= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE]) + for c in corr: + draw.line(c, fill=col) + out.append(np.array(img)) + out = np.array(out) / 255.0 + return utils.np_img_to_torch_img(out) + + def push_validation_data(self, data_pack, validation_data): + val_loss = validation_data['val_loss'] + pred_corrs = np.concatenate([data_pack['queries'].numpy(), validation_data['pred'].cpu().numpy()], axis=-1) + pred_corrs = self.draw_corrs(data_pack['image'], pred_corrs) + gt_corrs = np.concatenate([data_pack['queries'].numpy(), data_pack['targets'].cpu().numpy()], axis=-1) + gt_corrs = self.draw_corrs(data_pack['image'], gt_corrs, (0, 255, 0)) + + gt_img = vutils.make_grid(gt_corrs, normalize=True, scale_each=True) + pred_img = vutils.make_grid(pred_corrs, normalize=True, scale_each=True) + tb_datapack = tensorboard_helper.TensorboardDatapack() + tb_datapack.set_training(False) + tb_datapack.set_iteration(self.iteration) + tb_datapack.add_scalar({'loss/val': val_loss}) + tb_datapack.add_image({'image/gt_corrs': gt_img}) + tb_datapack.add_image({'image/pred_corrs': pred_img}) + self.tb_pusher.push_to_tensorboard(tb_datapack) + + def train_batch(self, data_pack): + '''train for one batch of data + ''' + img = data_pack['image'].cuda() + query = data_pack['queries'].cuda() + target = data_pack['targets'].cuda() + + self.optim.zero_grad() + pred = self.model(img, query)['pred_corrs'] + loss = torch.nn.functional.mse_loss(pred, target) + if self.opt.cycle_consis and self.opt.bidirectional: + cycle = self.model(img, pred)['pred_corrs'] + mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE + if mask.sum() > 0: + cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask]) + loss += cycle_loss + elif self.opt.cycle_consis and not self.opt.bidirectional: + img_reverse = torch.cat([img[..., constants.MAX_SIZE:], img[..., :constants.MAX_SIZE]], axis=-1) + query_reverse = pred.clone() + query_reverse[..., 0] = query_reverse[..., 0] - 0.5 + cycle = self.model(img_reverse, query_reverse)['pred_corrs'] + cycle[..., 0] = cycle[..., 0] - 0.5 + mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE + if mask.sum() > 0: + cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask]) + loss += cycle_loss + loss_data = loss.data.item() + if np.isnan(loss_data): + print('loss is nan during training') + self.optim.zero_grad() + else: + loss.backward() + self.push_training_data(data_pack, pred, target, loss) + self.optim.step() + + def push_training_data(self, data_pack, pred, target, loss): + tb_datapack = tensorboard_helper.TensorboardDatapack() + tb_datapack.set_training(True) + tb_datapack.set_iteration(self.iteration) + tb_datapack.add_histogram({'distribution/pred': pred}) + tb_datapack.add_histogram({'distribution/target': target}) + tb_datapack.add_scalar({'loss/train': loss}) + self.tb_pusher.push_to_tensorboard(tb_datapack) + + def resume(self): + '''resume training: + resume from the recorded epoch, iteration, and saved weights. + resume from the model with the same name. 
+ + Arguments: + opt {[type]} -- [description] + ''' + if hasattr(self.opt, 'load_weights'): + assert self.opt.load_weights is None or self.opt.load_weights == False + # 1. load check point + checkpoint_path = os.path.join(self.opt.out, 'checkpoint.pth.tar') + if os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path) + else: + raise FileNotFoundError( + 'model check point cannnot found: {0}'.format(checkpoint_path)) + # 2. load data + self.epoch = checkpoint['epoch'] + self.iteration = checkpoint['iteration'] + self.load_pretrained_weights() + self.optim.load_state_dict(checkpoint['optim_state_dict']) + + def load_pretrained_weights(self): + ''' + load pretrained weights from another model + ''' + # if hasattr(self.opt, 'resume'): + # assert self.opt.resume is False + assert os.path.isfile(self.opt.load_weights_path), self.opt.load_weights_path + + saved_weights = torch.load(self.opt.load_weights_path)['model_state_dict'] + utils.safe_load_weights(self.model, saved_weights) + content_list = [] + content_list += [f'Loaded pretrained weights from {self.opt.load_weights_path}'] + utils.print_notification(content_list) diff --git a/imcui/third_party/COTR/COTR/trainers/tensorboard_helper.py b/imcui/third_party/COTR/COTR/trainers/tensorboard_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..4550fe05c300b1b6c2729c8677e7f402c69fb8e7 --- /dev/null +++ b/imcui/third_party/COTR/COTR/trainers/tensorboard_helper.py @@ -0,0 +1,97 @@ +import abc + +import tensorboardX + + +class TensorboardDatapack(): + '''data dictionary for pushing to tb + ''' + + def __init__(self): + self.SCALAR_NAME = 'scalar' + self.HISTOGRAM_NAME = 'histogram' + self.IMAGE_NAME = 'image' + self.TEXT_NAME = 'text' + self.datapack = {} + self.datapack[self.SCALAR_NAME] = {} + self.datapack[self.HISTOGRAM_NAME] = {} + self.datapack[self.IMAGE_NAME] = {} + self.datapack[self.TEXT_NAME] = {} + + def set_training(self, training): + self.training = training + + def set_iteration(self, iteration): + self.iteration = iteration + + def add_scalar(self, scalar_dict): + self.datapack[self.SCALAR_NAME].update(scalar_dict) + + def add_histogram(self, histogram_dict): + self.datapack[self.HISTOGRAM_NAME].update(histogram_dict) + + def add_image(self, image_dict): + self.datapack[self.IMAGE_NAME].update(image_dict) + + def add_text(self, text_dict): + self.datapack[self.TEXT_NAME].update(text_dict) + + +class TensorboardHelperBase(abc.ABC): + '''abstract base class for tb helpers + ''' + + def __init__(self, tb_writer): + self.tb_writer = tb_writer + + @abc.abstractmethod + def add_data(self, tb_datapack): + pass + + +class TensorboardScalarHelper(TensorboardHelperBase): + def add_data(self, tb_datapack): + scalar_dict = tb_datapack.datapack[tb_datapack.SCALAR_NAME] + for key, val in scalar_dict.items(): + self.tb_writer.add_scalar( + key, val, global_step=tb_datapack.iteration) + + +class TensorboardHistogramHelper(TensorboardHelperBase): + def add_data(self, tb_datapack): + histogram_dict = tb_datapack.datapack[tb_datapack.HISTOGRAM_NAME] + for key, val in histogram_dict.items(): + self.tb_writer.add_histogram( + key, val, global_step=tb_datapack.iteration) + + +class TensorboardImageHelper(TensorboardHelperBase): + def add_data(self, tb_datapack): + image_dict = tb_datapack.datapack[tb_datapack.IMAGE_NAME] + for key, val in image_dict.items(): + self.tb_writer.add_image( + key, val, global_step=tb_datapack.iteration) + + +class TensorboardTextHelper(TensorboardHelperBase): + def add_data(self, 
tb_datapack): + text_dict = tb_datapack.datapack[tb_datapack.TEXT_NAME] + for key, val in text_dict.items(): + self.tb_writer.add_text( + key, val, global_step=tb_datapack.iteration) + + +class TensorboardPusher(): + def __init__(self, opt): + self.tb_writer = tensorboardX.SummaryWriter(opt.tb_out) + scalar_helper = TensorboardScalarHelper(self.tb_writer) + histogram_helper = TensorboardHistogramHelper(self.tb_writer) + image_helper = TensorboardImageHelper(self.tb_writer) + text_helper = TensorboardTextHelper(self.tb_writer) + self.helper_list = [scalar_helper, + histogram_helper, image_helper, text_helper] + + def push_to_tensorboard(self, tb_datapack): + for helper in self.helper_list: + helper.add_data(tb_datapack) + self.tb_writer.flush() diff --git a/imcui/third_party/COTR/COTR/transformations/transform_basics.py b/imcui/third_party/COTR/COTR/transformations/transform_basics.py new file mode 100644 index 0000000000000000000000000000000000000000..26cdb8068cfb857fba6c680724013c0d4b4721da --- /dev/null +++ b/imcui/third_party/COTR/COTR/transformations/transform_basics.py @@ -0,0 +1,114 @@ +import numpy as np + +from COTR.transformations import transformations +from COTR.utils import constants + + +class Rotation(): + def __init__(self, quat): + """ + quaternion format (w, x, y, z) + """ + assert quat.dtype == np.float32 + self.quaternion = quat + + def __str__(self): + string = '{0}'.format(self.quaternion) + return string + + @classmethod + def from_matrix(cls, mat): + assert isinstance(mat, np.ndarray) + if mat.shape == (3, 3): + id_mat = np.eye(4) + id_mat[0:3, 0:3] = mat + mat = id_mat + assert mat.shape == (4, 4) + quat = transformations.quaternion_from_matrix(mat).astype(constants.DEFAULT_PRECISION) + return cls(quat) + + @property + def rotation_matrix(self): + return transformations.quaternion_matrix(self.quaternion).astype(constants.DEFAULT_PRECISION) + + @rotation_matrix.setter + def rotation_matrix(self, mat): + assert isinstance(mat, np.ndarray) + assert mat.shape == (4, 4) + quat = transformations.quaternion_from_matrix(mat) + self.quaternion = quat + + @property + def quaternion(self): + assert isinstance(self._quaternion, np.ndarray) + assert self._quaternion.shape == (4,) + assert np.isclose(np.linalg.norm(self._quaternion), 1.0), 'self._quaternion is not normalized or valid' + return self._quaternion + + @quaternion.setter + def quaternion(self, quat): + assert isinstance(quat, np.ndarray) + assert quat.shape == (4,) + if not np.isclose(np.linalg.norm(quat), 1.0): + print(f'WARNING: normalizing the input quatternion to unit quaternion: {np.linalg.norm(quat)}') + quat = quat / np.linalg.norm(quat) + assert np.isclose(np.linalg.norm(quat), 1.0), f'input quaternion is not normalized or valid: {quat}' + self._quaternion = quat + + +class UnstableRotation(): + def __init__(self, mat): + assert isinstance(mat, np.ndarray) + if mat.shape == (3, 3): + id_mat = np.eye(4) + id_mat[0:3, 0:3] = mat + mat = id_mat + assert mat.shape == (4, 4) + mat[:3, 3] = 0 + self._rotation_matrix = mat + + def __str__(self): + string = f'rotation_matrix: {self.rotation_matrix}' + return string + + @property + def rotation_matrix(self): + return self._rotation_matrix + + +class Translation(): + def __init__(self, vec): + assert vec.dtype == np.float32 + self.translation_vector = vec + + def __str__(self): + string = '{0}'.format(self.translation_vector) + return string + + @classmethod + def from_matrix(cls, mat): + assert isinstance(mat, np.ndarray) + assert mat.shape == (4, 4) + vec = 
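
The `Rotation` wrapper above stores a unit quaternion in (w, x, y, z) order as float32 and converts to and from 4x4 matrices through `transformations.py`. A round-trip sketch, assuming the repository root is on `PYTHONPATH` so the `COTR` package imports:

```python
import numpy as np

from COTR.transformations import transformations
from COTR.transformations.transform_basics import Rotation

mat = transformations.rotation_matrix(np.pi / 2, [0, 0, 1])      # 90 deg about z
rot = Rotation.from_matrix(mat)
print(rot.quaternion)                                            # ~[0.7071 0. 0. 0.7071] (w, x, y, z)
print(np.allclose(rot.rotation_matrix, mat, atol=1e-6))          # True (float32 round-trip)
```
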
transformations.translation_from_matrix(mat) + return cls(vec) + + @property + def translation_matrix(self): + return transformations.translation_matrix(self.translation_vector).astype(constants.DEFAULT_PRECISION) + + @translation_matrix.setter + def translation_matrix(self, mat): + assert isinstance(mat, np.ndarray) + assert mat.shape == (4, 4) + vec = transformations.translation_from_matrix(mat) + self.translation_vector = vec + + @property + def translation_vector(self): + return self._translation_vector + + @translation_vector.setter + def translation_vector(self, vec): + assert isinstance(vec, np.ndarray) + assert vec.shape == (3,) + self._translation_vector = vec diff --git a/imcui/third_party/COTR/COTR/transformations/transformations.py b/imcui/third_party/COTR/COTR/transformations/transformations.py new file mode 100644 index 0000000000000000000000000000000000000000..809ce6683c27e641de5b845ac3b5c53d6a3167f6 --- /dev/null +++ b/imcui/third_party/COTR/COTR/transformations/transformations.py @@ -0,0 +1,1951 @@ +# -*- coding: utf-8 -*- +# transformations.py + +# Copyright (c) 2006-2019, Christoph Gohlke +# Copyright (c) 2006-2019, The Regents of the University of California +# Produced at the Laboratory for Fluorescence Dynamics +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +"""Homogeneous Transformation Matrices and Quaternions. + +Transformations is a Python library for calculating 4x4 matrices for +translating, rotating, reflecting, scaling, shearing, projecting, +orthogonalizing, and superimposing arrays of 3D homogeneous coordinates +as well as for converting between rotation matrices, Euler angles, +and quaternions. Also includes an Arcball control object and +functions to decompose transformation matrices. + +:Author: + `Christoph Gohlke `_ + +:Organization: + Laboratory for Fluorescence Dynamics. 
University of California, Irvine + +:License: 3-clause BSD + +:Version: 2019.2.20 + +Requirements +------------ +* `CPython 2.7 or 3.5+ `_ +* `Numpy 1.14 `_ +* A Python distutils compatible C compiler (build) + +Revisions +--------- +2019.1.1 + Update copyright year. + +Notes +----- +Transformations.py is no longer actively developed and has a few known issues +and numerical instabilities. The module is mostly superseded by other modules +for 3D transformations and quaternions: + +* `Scipy.spatial.transform `_ +* `Transforms3d `_ + (includes most code of this module) +* `Numpy-quaternion `_ +* `Blender.mathutils `_ + +The API is not stable yet and is expected to change between revisions. + +Python 2.7 and 3.4 are deprecated. + +This Python code is not optimized for speed. Refer to the transformations.c +module for a faster implementation of some functions. + +Documentation in HTML format can be generated with epydoc. + +Matrices (M) can be inverted using numpy.linalg.inv(M), be concatenated using +numpy.dot(M0, M1), or transform homogeneous coordinate arrays (v) using +numpy.dot(M, v) for shape (4, \*) column vectors, respectively +numpy.dot(v, M.T) for shape (\*, 4) row vectors ("array of points"). + +This module follows the "column vectors on the right" and "row major storage" +(C contiguous) conventions. The translation components are in the right column +of the transformation matrix, i.e. M[:3, 3]. +The transpose of the transformation matrices may have to be used to interface +with other graphics systems, e.g. OpenGL's glMultMatrixd(). See also [16]. + +Calculations are carried out with numpy.float64 precision. + +Vector, point, quaternion, and matrix function arguments are expected to be +"array like", i.e. tuple, list, or numpy arrays. + +Return types are numpy arrays unless specified otherwise. + +Angles are in radians unless specified otherwise. + +Quaternions w+ix+jy+kz are represented as [w, x, y, z]. + +A triple of Euler angles can be applied/interpreted in 24 ways, which can +be specified using a 4 character string or encoded 4-tuple: + + *Axes 4-string*: e.g. 'sxyz' or 'ryxy' + + - first character : rotations are applied to 's'tatic or 'r'otating frame + - remaining characters : successive rotation axis 'x', 'y', or 'z' + + *Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1) + + - inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix. + - parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed + by 'z', or 'z' is followed by 'x'. Otherwise odd (1). + - repetition : first and last axis are same (1) or different (0). + - frame : rotations are applied to static (0) or rotating (1) frame. + +References +---------- +(1) Matrices and transformations. Ronald Goldman. + In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990. +(2) More matrices and transformations: shear and pseudo-perspective. + Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991. +(3) Decomposing a matrix into simple transformations. Spencer Thomas. + In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991. +(4) Recovering the data from the transformation matrix. Ronald Goldman. + In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991. +(5) Euler angle conversion. Ken Shoemake. + In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994. +(6) Arcball rotation control. Ken Shoemake. + In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994. +(7) Representing attitude: Euler angles, unit quaternions, and rotation + vectors. James Diebel. 2006. 
+(8) A discussion of the solution for the best rotation to relate two sets + of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828. +(9) Closed-form solution of absolute orientation using unit quaternions. + BKP Horn. J Opt Soc Am A. 1987. 4(4):629-642. +(10) Quaternions. Ken Shoemake. + http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf +(11) From quaternion to matrix and back. JMP van Waveren. 2005. + http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm +(12) Uniform random rotations. Ken Shoemake. + In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992. +(13) Quaternion in molecular modeling. CFF Karney. + J Mol Graph Mod, 25(5):595-604 +(14) New method for extracting the quaternion from a rotation matrix. + Itzhack Y Bar-Itzhack, J Guid Contr Dynam. 2000. 23(6): 1085-1087. +(15) Multiple View Geometry in Computer Vision. Hartley and Zissermann. + Cambridge University Press; 2nd Ed. 2004. Chapter 4, Algorithm 4.7, p 130. +(16) Column Vectors vs. Row Vectors. + http://steve.hollasch.net/cgindex/math/matrix/column-vec.html + +Examples +-------- +>>> alpha, beta, gamma = 0.123, -1.234, 2.345 +>>> origin, xaxis, yaxis, zaxis = [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1] +>>> I = identity_matrix() +>>> Rx = rotation_matrix(alpha, xaxis) +>>> Ry = rotation_matrix(beta, yaxis) +>>> Rz = rotation_matrix(gamma, zaxis) +>>> R = concatenate_matrices(Rx, Ry, Rz) +>>> euler = euler_from_matrix(R, 'rxyz') +>>> numpy.allclose([alpha, beta, gamma], euler) +True +>>> Re = euler_matrix(alpha, beta, gamma, 'rxyz') +>>> is_same_transform(R, Re) +True +>>> al, be, ga = euler_from_matrix(Re, 'rxyz') +>>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz')) +True +>>> qx = quaternion_about_axis(alpha, xaxis) +>>> qy = quaternion_about_axis(beta, yaxis) +>>> qz = quaternion_about_axis(gamma, zaxis) +>>> q = quaternion_multiply(qx, qy) +>>> q = quaternion_multiply(q, qz) +>>> Rq = quaternion_matrix(q) +>>> is_same_transform(R, Rq) +True +>>> S = scale_matrix(1.23, origin) +>>> T = translation_matrix([1, 2, 3]) +>>> Z = shear_matrix(beta, xaxis, origin, zaxis) +>>> R = random_rotation_matrix(numpy.random.rand(3)) +>>> M = concatenate_matrices(T, R, Z, S) +>>> scale, shear, angles, trans, persp = decompose_matrix(M) +>>> numpy.allclose(scale, 1.23) +True +>>> numpy.allclose(trans, [1, 2, 3]) +True +>>> numpy.allclose(shear, [0, math.tan(beta), 0]) +True +>>> is_same_transform(R, euler_matrix(axes='sxyz', *angles)) +True +>>> M1 = compose_matrix(scale, shear, angles, trans, persp) +>>> is_same_transform(M, M1) +True +>>> v0, v1 = random_vector(3), random_vector(3) +>>> M = rotation_matrix(angle_between_vectors(v0, v1), vector_product(v0, v1)) +>>> v2 = numpy.dot(v0, M[:3,:3].T) +>>> numpy.allclose(unit_vector(v1), unit_vector(v2)) +True + +""" + +from __future__ import division, print_function + +__version__ = '2019.2.20' +__docformat__ = 'restructuredtext en' + +import math + +import numpy + + +def identity_matrix(): + """Return 4x4 identity/unit matrix. + + >>> I = identity_matrix() + >>> numpy.allclose(I, numpy.dot(I, I)) + True + >>> numpy.sum(I), numpy.trace(I) + (4.0, 4.0) + >>> numpy.allclose(I, numpy.identity(4)) + True + + """ + return numpy.identity(4) + + +def translation_matrix(direction): + """Return matrix to translate by direction vector. 
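(Editorial aside, not part of the original file: the module docstring above spells out the core conventions, w-first quaternions, column vectors on the right, and the translation stored in M[:3, 3]. The minimal sketch below illustrates them with translation_matrix, rotation_matrix and concatenate_matrices, all defined in this module; the import alias `tf` is an assumption for readability, following the `COTR.transformations` import path used elsewhere in this repo.)

```python
import numpy as np
# Assumed import path; the COTR code in this repo imports the module the same way.
from COTR.transformations import transformations as tf

T = tf.translation_matrix([1.0, 2.0, 3.0])    # 4x4, translation stored in T[:3, 3]
R = tf.rotation_matrix(np.pi / 2, [0, 0, 1])  # rotate 90 degrees about +z
M = tf.concatenate_matrices(T, R)             # R is applied first, then T

p = np.array([1.0, 0.0, 0.0, 1.0])            # homogeneous point as a column vector
print(np.dot(M, p))                           # approximately [1., 3., 3., 1.]
```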
+ + >>> v = numpy.random.random(3) - 0.5 + >>> numpy.allclose(v, translation_matrix(v)[:3, 3]) + True + + """ + M = numpy.identity(4) + M[:3, 3] = direction[:3] + return M + + +def translation_from_matrix(matrix): + """Return translation vector from translation matrix. + + >>> v0 = numpy.random.random(3) - 0.5 + >>> v1 = translation_from_matrix(translation_matrix(v0)) + >>> numpy.allclose(v0, v1) + True + + """ + return numpy.array(matrix, copy=False)[:3, 3].copy() + + +def reflection_matrix(point, normal): + """Return matrix to mirror at plane defined by point and normal vector. + + >>> v0 = numpy.random.random(4) - 0.5 + >>> v0[3] = 1. + >>> v1 = numpy.random.random(3) - 0.5 + >>> R = reflection_matrix(v0, v1) + >>> numpy.allclose(2, numpy.trace(R)) + True + >>> numpy.allclose(v0, numpy.dot(R, v0)) + True + >>> v2 = v0.copy() + >>> v2[:3] += v1 + >>> v3 = v0.copy() + >>> v2[:3] -= v1 + >>> numpy.allclose(v2, numpy.dot(R, v3)) + True + + """ + normal = unit_vector(normal[:3]) + M = numpy.identity(4) + M[:3, :3] -= 2.0 * numpy.outer(normal, normal) + M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal + return M + + +def reflection_from_matrix(matrix): + """Return mirror plane point and normal vector from reflection matrix. + + >>> v0 = numpy.random.random(3) - 0.5 + >>> v1 = numpy.random.random(3) - 0.5 + >>> M0 = reflection_matrix(v0, v1) + >>> point, normal = reflection_from_matrix(M0) + >>> M1 = reflection_matrix(point, normal) + >>> is_same_transform(M0, M1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + # normal: unit eigenvector corresponding to eigenvalue -1 + w, V = numpy.linalg.eig(M[:3, :3]) + i = numpy.where(abs(numpy.real(w) + 1.0) < 1e-8)[0] + if not len(i): + raise ValueError('no unit eigenvector corresponding to eigenvalue -1') + normal = numpy.real(V[:, i[0]]).squeeze() + # point: any unit eigenvector corresponding to eigenvalue 1 + w, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError('no unit eigenvector corresponding to eigenvalue 1') + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + return point, normal + + +def rotation_matrix(angle, direction, point=None): + """Return matrix to rotate about axis defined by point and direction. + + >>> R = rotation_matrix(math.pi/2, [0, 0, 1], [1, 0, 0]) + >>> numpy.allclose(numpy.dot(R, [0, 0, 0, 1]), [1, -1, 0, 1]) + True + >>> angle = (random.random() - 0.5) * (2*math.pi) + >>> direc = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> R0 = rotation_matrix(angle, direc, point) + >>> R1 = rotation_matrix(angle-2*math.pi, direc, point) + >>> is_same_transform(R0, R1) + True + >>> R0 = rotation_matrix(angle, direc, point) + >>> R1 = rotation_matrix(-angle, -direc, point) + >>> is_same_transform(R0, R1) + True + >>> I = numpy.identity(4, numpy.float64) + >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc)) + True + >>> numpy.allclose(2, numpy.trace(rotation_matrix(math.pi/2, + ... 
direc, point))) + True + + """ + sina = math.sin(angle) + cosa = math.cos(angle) + direction = unit_vector(direction[:3]) + # rotation matrix around unit vector + R = numpy.diag([cosa, cosa, cosa]) + R += numpy.outer(direction, direction) * (1.0 - cosa) + direction *= sina + R += numpy.array([[0.0, -direction[2], direction[1]], + [direction[2], 0.0, -direction[0]], + [-direction[1], direction[0], 0.0]]) + M = numpy.identity(4) + M[:3, :3] = R + if point is not None: + # rotation not around origin + point = numpy.array(point[:3], dtype=numpy.float64, copy=False) + M[:3, 3] = point - numpy.dot(R, point) + return M + + +def rotation_from_matrix(matrix): + """Return rotation angle and axis from rotation matrix. + + >>> angle = (random.random() - 0.5) * (2*math.pi) + >>> direc = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> R0 = rotation_matrix(angle, direc, point) + >>> angle, direc, point = rotation_from_matrix(R0) + >>> R1 = rotation_matrix(angle, direc, point) + >>> is_same_transform(R0, R1) + True + + """ + R = numpy.array(matrix, dtype=numpy.float64, copy=False) + R33 = R[:3, :3] + # direction: unit eigenvector of R33 corresponding to eigenvalue of 1 + w, W = numpy.linalg.eig(R33.T) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError('no unit eigenvector corresponding to eigenvalue 1') + direction = numpy.real(W[:, i[-1]]).squeeze() + # point: unit eigenvector of R33 corresponding to eigenvalue of 1 + w, Q = numpy.linalg.eig(R) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError('no unit eigenvector corresponding to eigenvalue 1') + point = numpy.real(Q[:, i[-1]]).squeeze() + point /= point[3] + # rotation angle depending on direction + cosa = (numpy.trace(R33) - 1.0) / 2.0 + if abs(direction[2]) > 1e-8: + sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2] + elif abs(direction[1]) > 1e-8: + sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1] + else: + sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0] + angle = math.atan2(sina, cosa) + return angle, direction, point + + +def scale_matrix(factor, origin=None, direction=None): + """Return matrix to scale by factor around origin in direction. + + Use factor -1 for point symmetry. + + >>> v = (numpy.random.rand(4, 5) - 0.5) * 20 + >>> v[3] = 1 + >>> S = scale_matrix(-1.234) + >>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3]) + True + >>> factor = random.random() * 10 - 5 + >>> origin = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> S = scale_matrix(factor, origin) + >>> S = scale_matrix(factor, origin, direct) + + """ + if direction is None: + # uniform scaling + M = numpy.diag([factor, factor, factor, 1.0]) + if origin is not None: + M[:3, 3] = origin[:3] + M[:3, 3] *= 1.0 - factor + else: + # nonuniform scaling + direction = unit_vector(direction[:3]) + factor = 1.0 - factor + M = numpy.identity(4) + M[:3, :3] -= factor * numpy.outer(direction, direction) + if origin is not None: + M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction + return M + + +def scale_from_matrix(matrix): + """Return scaling factor, origin and direction from scaling matrix. 
+ + >>> factor = random.random() * 10 - 5 + >>> origin = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> S0 = scale_matrix(factor, origin) + >>> factor, origin, direction = scale_from_matrix(S0) + >>> S1 = scale_matrix(factor, origin, direction) + >>> is_same_transform(S0, S1) + True + >>> S0 = scale_matrix(factor, origin, direct) + >>> factor, origin, direction = scale_from_matrix(S0) + >>> S1 = scale_matrix(factor, origin, direction) + >>> is_same_transform(S0, S1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + M33 = M[:3, :3] + factor = numpy.trace(M33) - 2.0 + try: + # direction: unit eigenvector corresponding to eigenvalue factor + w, V = numpy.linalg.eig(M33) + i = numpy.where(abs(numpy.real(w) - factor) < 1e-8)[0][0] + direction = numpy.real(V[:, i]).squeeze() + direction /= vector_norm(direction) + except IndexError: + # uniform scaling + factor = (factor + 2.0) / 3.0 + direction = None + # origin: any eigenvector corresponding to eigenvalue 1 + w, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError('no eigenvector corresponding to eigenvalue 1') + origin = numpy.real(V[:, i[-1]]).squeeze() + origin /= origin[3] + return factor, origin, direction + + +def projection_matrix(point, normal, direction=None, + perspective=None, pseudo=False): + """Return matrix to project onto plane defined by point and normal. + + Using either perspective point, projection direction, or none of both. + + If pseudo is True, perspective projections will preserve relative depth + such that Perspective = dot(Orthogonal, PseudoPerspective). + + >>> P = projection_matrix([0, 0, 0], [1, 0, 0]) + >>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:]) + True + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> persp = numpy.random.random(3) - 0.5 + >>> P0 = projection_matrix(point, normal) + >>> P1 = projection_matrix(point, normal, direction=direct) + >>> P2 = projection_matrix(point, normal, perspective=persp) + >>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True) + >>> is_same_transform(P2, numpy.dot(P0, P3)) + True + >>> P = projection_matrix([3, 0, 0], [1, 1, 0], [1, 0, 0]) + >>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20 + >>> v0[3] = 1 + >>> v1 = numpy.dot(P, v0) + >>> numpy.allclose(v1[1], v0[1]) + True + >>> numpy.allclose(v1[0], 3-v1[1]) + True + + """ + M = numpy.identity(4) + point = numpy.array(point[:3], dtype=numpy.float64, copy=False) + normal = unit_vector(normal[:3]) + if perspective is not None: + # perspective projection + perspective = numpy.array(perspective[:3], dtype=numpy.float64, + copy=False) + M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal) + M[:3, :3] -= numpy.outer(perspective, normal) + if pseudo: + # preserve relative depth + M[:3, :3] -= numpy.outer(normal, normal) + M[:3, 3] = numpy.dot(point, normal) * (perspective+normal) + else: + M[:3, 3] = numpy.dot(point, normal) * perspective + M[3, :3] = -normal + M[3, 3] = numpy.dot(perspective, normal) + elif direction is not None: + # parallel projection + direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False) + scale = numpy.dot(direction, normal) + M[:3, :3] -= numpy.outer(direction, normal) / scale + M[:3, 3] = direction * (numpy.dot(point, normal) / scale) + else: + # orthogonal projection + M[:3, :3] -= numpy.outer(normal, normal) + M[:3, 3] = numpy.dot(point, normal) * 
normal + return M + + +def projection_from_matrix(matrix, pseudo=False): + """Return projection plane and perspective point from projection matrix. + + Return values are same as arguments for projection_matrix function: + point, normal, direction, perspective, and pseudo. + + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> persp = numpy.random.random(3) - 0.5 + >>> P0 = projection_matrix(point, normal) + >>> result = projection_from_matrix(P0) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + >>> P0 = projection_matrix(point, normal, direct) + >>> result = projection_from_matrix(P0) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False) + >>> result = projection_from_matrix(P0, pseudo=False) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True) + >>> result = projection_from_matrix(P0, pseudo=True) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + M33 = M[:3, :3] + w, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not pseudo and len(i): + # point: any eigenvector corresponding to eigenvalue 1 + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + # direction: unit eigenvector corresponding to eigenvalue 0 + w, V = numpy.linalg.eig(M33) + i = numpy.where(abs(numpy.real(w)) < 1e-8)[0] + if not len(i): + raise ValueError('no eigenvector corresponding to eigenvalue 0') + direction = numpy.real(V[:, i[0]]).squeeze() + direction /= vector_norm(direction) + # normal: unit eigenvector of M33.T corresponding to eigenvalue 0 + w, V = numpy.linalg.eig(M33.T) + i = numpy.where(abs(numpy.real(w)) < 1e-8)[0] + if len(i): + # parallel projection + normal = numpy.real(V[:, i[0]]).squeeze() + normal /= vector_norm(normal) + return point, normal, direction, None, False + else: + # orthogonal projection, where normal equals direction vector + return point, direction, None, None, False + else: + # perspective projection + i = numpy.where(abs(numpy.real(w)) > 1e-8)[0] + if not len(i): + raise ValueError( + 'no eigenvector not corresponding to eigenvalue 0') + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + normal = - M[3, :3] + perspective = M[:3, 3] / numpy.dot(point[:3], normal) + if pseudo: + perspective -= normal + return point, normal, None, perspective, pseudo + + +def clip_matrix(left, right, bottom, top, near, far, perspective=False): + """Return matrix to obtain normalized device coordinates from frustum. + + The frustum bounds are axis-aligned along x (left, right), + y (bottom, top) and z (near, far). + + Normalized device coordinates are in range [-1, 1] if coordinates are + inside the frustum. + + If perspective is True the frustum is a truncated pyramid with the + perspective point at origin and direction along z axis, otherwise an + orthographic canonical view volume (a box). + + Homogeneous coordinates transformed by the perspective clip matrix + need to be dehomogenized (divided by w coordinate). 
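(Editorial aside, not part of the original source: a tiny sketch of the perspective divide mentioned in the paragraph above, using clip_matrix exactly as defined below; the frustum bounds are arbitrary and the `tf` alias is an assumed import.)

```python
import numpy as np
from COTR.transformations import transformations as tf  # assumed import path

# Axis-aligned frustum: left, right, bottom, top, near, far (arbitrary values).
M = tf.clip_matrix(-1.0, 1.0, -1.0, 1.0, 0.1, 100.0, perspective=True)

v = np.dot(M, [-1.0, -1.0, 0.1, 1.0])  # a near-plane corner, homogeneous coordinates
ndc = v / v[3]                         # dehomogenize: divide by the w coordinate
# ndc is approximately [-1., -1., -1., 1.]
```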
+ + >>> frustum = numpy.random.rand(6) + >>> frustum[1] += frustum[0] + >>> frustum[3] += frustum[2] + >>> frustum[5] += frustum[4] + >>> M = clip_matrix(perspective=False, *frustum) + >>> numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1]) + array([-1., -1., -1., 1.]) + >>> numpy.dot(M, [frustum[1], frustum[3], frustum[5], 1]) + array([ 1., 1., 1., 1.]) + >>> M = clip_matrix(perspective=True, *frustum) + >>> v = numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1]) + >>> v / v[3] + array([-1., -1., -1., 1.]) + >>> v = numpy.dot(M, [frustum[1], frustum[3], frustum[4], 1]) + >>> v / v[3] + array([ 1., 1., -1., 1.]) + + """ + if left >= right or bottom >= top or near >= far: + raise ValueError('invalid frustum') + if perspective: + if near <= _EPS: + raise ValueError('invalid frustum: near <= 0') + t = 2.0 * near + M = [[t/(left-right), 0.0, (right+left)/(right-left), 0.0], + [0.0, t/(bottom-top), (top+bottom)/(top-bottom), 0.0], + [0.0, 0.0, (far+near)/(near-far), t*far/(far-near)], + [0.0, 0.0, -1.0, 0.0]] + else: + M = [[2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)], + [0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)], + [0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)], + [0.0, 0.0, 0.0, 1.0]] + return numpy.array(M) + + +def shear_matrix(angle, direction, point, normal): + """Return matrix to shear by angle along direction vector on shear plane. + + The shear plane is defined by a point and normal vector. The direction + vector must be orthogonal to the plane's normal vector. + + A point P is transformed by the shear matrix into P" such that + the vector P-P" is parallel to the direction vector and its extent is + given by the angle of P-P'-P", where P' is the orthogonal projection + of P onto the shear plane. + + >>> angle = (random.random() - 0.5) * 4*math.pi + >>> direct = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.cross(direct, numpy.random.random(3)) + >>> S = shear_matrix(angle, direct, point, normal) + >>> numpy.allclose(1, numpy.linalg.det(S)) + True + + """ + normal = unit_vector(normal[:3]) + direction = unit_vector(direction[:3]) + if abs(numpy.dot(normal, direction)) > 1e-6: + raise ValueError('direction and normal vectors are not orthogonal') + angle = math.tan(angle) + M = numpy.identity(4) + M[:3, :3] += angle * numpy.outer(direction, normal) + M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction + return M + + +def shear_from_matrix(matrix): + """Return shear angle, direction and plane from shear matrix. 
+ + >>> angle = (random.random() - 0.5) * 4*math.pi + >>> direct = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.cross(direct, numpy.random.random(3)) + >>> S0 = shear_matrix(angle, direct, point, normal) + >>> angle, direct, point, normal = shear_from_matrix(S0) + >>> S1 = shear_matrix(angle, direct, point, normal) + >>> is_same_transform(S0, S1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + M33 = M[:3, :3] + # normal: cross independent eigenvectors corresponding to the eigenvalue 1 + w, V = numpy.linalg.eig(M33) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-4)[0] + if len(i) < 2: + raise ValueError('no two linear independent eigenvectors found %s' % w) + V = numpy.real(V[:, i]).squeeze().T + lenorm = -1.0 + for i0, i1 in ((0, 1), (0, 2), (1, 2)): + n = numpy.cross(V[i0], V[i1]) + w = vector_norm(n) + if w > lenorm: + lenorm = w + normal = n + normal /= lenorm + # direction and angle + direction = numpy.dot(M33 - numpy.identity(3), normal) + angle = vector_norm(direction) + direction /= angle + angle = math.atan(angle) + # point: eigenvector corresponding to eigenvalue 1 + w, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError('no eigenvector corresponding to eigenvalue 1') + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + return angle, direction, point, normal + + +def decompose_matrix(matrix): + """Return sequence of transformations from transformation matrix. + + matrix : array_like + Non-degenerative homogeneous transformation matrix + + Return tuple of: + scale : vector of 3 scaling factors + shear : list of shear factors for x-y, x-z, y-z axes + angles : list of Euler angles about static x, y, z axes + translate : translation vector along x, y, z axes + perspective : perspective partition of matrix + + Raise ValueError if matrix is of wrong type or degenerative. 
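(Editorial aside, not part of the original source: a compact sketch of the decompose/compose round trip described above, built only from functions defined in this module; the `tf` alias is again an assumed import.)

```python
from COTR.transformations import transformations as tf  # assumed import path

M0 = tf.concatenate_matrices(tf.translation_matrix([1, 2, 3]),
                             tf.euler_matrix(0.1, -0.2, 0.3),
                             tf.scale_matrix(2.0))
scale, shear, angles, translate, perspective = tf.decompose_matrix(M0)
# scale ~ [2, 2, 2], translate ~ [1, 2, 3]; euler_matrix(*angles) gives the same rotation
M1 = tf.compose_matrix(scale, shear, angles, translate, perspective)
assert tf.is_same_transform(M0, M1)
```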
+ + >>> T0 = translation_matrix([1, 2, 3]) + >>> scale, shear, angles, trans, persp = decompose_matrix(T0) + >>> T1 = translation_matrix(trans) + >>> numpy.allclose(T0, T1) + True + >>> S = scale_matrix(0.123) + >>> scale, shear, angles, trans, persp = decompose_matrix(S) + >>> scale[0] + 0.123 + >>> R0 = euler_matrix(1, 2, 3) + >>> scale, shear, angles, trans, persp = decompose_matrix(R0) + >>> R1 = euler_matrix(*angles) + >>> numpy.allclose(R0, R1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=True).T + if abs(M[3, 3]) < _EPS: + raise ValueError('M[3, 3] is zero') + M /= M[3, 3] + P = M.copy() + P[:, 3] = 0.0, 0.0, 0.0, 1.0 + if not numpy.linalg.det(P): + raise ValueError('matrix is singular') + + scale = numpy.zeros((3, )) + shear = [0.0, 0.0, 0.0] + angles = [0.0, 0.0, 0.0] + + if any(abs(M[:3, 3]) > _EPS): + perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T)) + M[:, 3] = 0.0, 0.0, 0.0, 1.0 + else: + perspective = numpy.array([0.0, 0.0, 0.0, 1.0]) + + translate = M[3, :3].copy() + M[3, :3] = 0.0 + + row = M[:3, :3].copy() + scale[0] = vector_norm(row[0]) + row[0] /= scale[0] + shear[0] = numpy.dot(row[0], row[1]) + row[1] -= row[0] * shear[0] + scale[1] = vector_norm(row[1]) + row[1] /= scale[1] + shear[0] /= scale[1] + shear[1] = numpy.dot(row[0], row[2]) + row[2] -= row[0] * shear[1] + shear[2] = numpy.dot(row[1], row[2]) + row[2] -= row[1] * shear[2] + scale[2] = vector_norm(row[2]) + row[2] /= scale[2] + shear[1:] /= scale[2] + + if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0: + numpy.negative(scale, scale) + numpy.negative(row, row) + + angles[1] = math.asin(-row[0, 2]) + if math.cos(angles[1]): + angles[0] = math.atan2(row[1, 2], row[2, 2]) + angles[2] = math.atan2(row[0, 1], row[0, 0]) + else: + # angles[0] = math.atan2(row[1, 0], row[1, 1]) + angles[0] = math.atan2(-row[2, 1], row[1, 1]) + angles[2] = 0.0 + + return scale, shear, angles, translate, perspective + + +def compose_matrix(scale=None, shear=None, angles=None, translate=None, + perspective=None): + """Return transformation matrix from sequence of transformations. + + This is the inverse of the decompose_matrix function. 
+ + Sequence of transformations: + scale : vector of 3 scaling factors + shear : list of shear factors for x-y, x-z, y-z axes + angles : list of Euler angles about static x, y, z axes + translate : translation vector along x, y, z axes + perspective : perspective partition of matrix + + >>> scale = numpy.random.random(3) - 0.5 + >>> shear = numpy.random.random(3) - 0.5 + >>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi) + >>> trans = numpy.random.random(3) - 0.5 + >>> persp = numpy.random.random(4) - 0.5 + >>> M0 = compose_matrix(scale, shear, angles, trans, persp) + >>> result = decompose_matrix(M0) + >>> M1 = compose_matrix(*result) + >>> is_same_transform(M0, M1) + True + + """ + M = numpy.identity(4) + if perspective is not None: + P = numpy.identity(4) + P[3, :] = perspective[:4] + M = numpy.dot(M, P) + if translate is not None: + T = numpy.identity(4) + T[:3, 3] = translate[:3] + M = numpy.dot(M, T) + if angles is not None: + R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz') + M = numpy.dot(M, R) + if shear is not None: + Z = numpy.identity(4) + Z[1, 2] = shear[2] + Z[0, 2] = shear[1] + Z[0, 1] = shear[0] + M = numpy.dot(M, Z) + if scale is not None: + S = numpy.identity(4) + S[0, 0] = scale[0] + S[1, 1] = scale[1] + S[2, 2] = scale[2] + M = numpy.dot(M, S) + M /= M[3, 3] + return M + + +def orthogonalization_matrix(lengths, angles): + """Return orthogonalization matrix for crystallographic cell coordinates. + + Angles are expected in degrees. + + The de-orthogonalization matrix is the inverse. + + >>> O = orthogonalization_matrix([10, 10, 10], [90, 90, 90]) + >>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10) + True + >>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7]) + >>> numpy.allclose(numpy.sum(O), 43.063229) + True + + """ + a, b, c = lengths + angles = numpy.radians(angles) + sina, sinb, _ = numpy.sin(angles) + cosa, cosb, cosg = numpy.cos(angles) + co = (cosa * cosb - cosg) / (sina * sinb) + return numpy.array([ + [a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0], + [-a*sinb*co, b*sina, 0.0, 0.0], + [a*cosb, b*cosa, c, 0.0], + [0.0, 0.0, 0.0, 1.0]]) + + +def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True): + """Return affine transform matrix to register two point sets. + + v0 and v1 are shape (ndims, \*) arrays of at least ndims non-homogeneous + coordinates, where ndims is the dimensionality of the coordinate space. + + If shear is False, a similarity transformation matrix is returned. + If also scale is False, a rigid/Euclidean transformation matrix + is returned. + + By default the algorithm by Hartley and Zissermann [15] is used. + If usesvd is True, similarity and Euclidean transformation matrices + are calculated by minimizing the weighted sum of squared deviations + (RMSD) according to the algorithm by Kabsch [8]. + Otherwise, and if ndims is 3, the quaternion based algorithm by Horn [9] + is used, which is slower when using this Python implementation. + + The returned matrix performs rotation, translation and uniform scaling + (if specified). + + >>> v0 = [[0, 1031, 1031, 0], [0, 0, 1600, 1600]] + >>> v1 = [[675, 826, 826, 677], [55, 52, 281, 277]] + >>> affine_matrix_from_points(v0, v1) + array([[ 0.14549, 0.00062, 675.50008], + [ 0.00048, 0.14094, 53.24971], + [ 0. , 0. , 1. 
]]) + >>> T = translation_matrix(numpy.random.random(3)-0.5) + >>> R = random_rotation_matrix(numpy.random.random(3)) + >>> S = scale_matrix(random.random()) + >>> M = concatenate_matrices(T, R, S) + >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20 + >>> v0[3] = 1 + >>> v1 = numpy.dot(M, v0) + >>> v0[:3] += numpy.random.normal(0, 1e-8, 300).reshape(3, -1) + >>> M = affine_matrix_from_points(v0[:3], v1[:3]) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + + More examples in superimposition_matrix() + + """ + v0 = numpy.array(v0, dtype=numpy.float64, copy=True) + v1 = numpy.array(v1, dtype=numpy.float64, copy=True) + + ndims = v0.shape[0] + if ndims < 2 or v0.shape[1] < ndims or v0.shape != v1.shape: + raise ValueError('input arrays are of wrong shape or type') + + # move centroids to origin + t0 = -numpy.mean(v0, axis=1) + M0 = numpy.identity(ndims+1) + M0[:ndims, ndims] = t0 + v0 += t0.reshape(ndims, 1) + t1 = -numpy.mean(v1, axis=1) + M1 = numpy.identity(ndims+1) + M1[:ndims, ndims] = t1 + v1 += t1.reshape(ndims, 1) + + if shear: + # Affine transformation + A = numpy.concatenate((v0, v1), axis=0) + u, s, vh = numpy.linalg.svd(A.T) + vh = vh[:ndims].T + B = vh[:ndims] + C = vh[ndims:2*ndims] + t = numpy.dot(C, numpy.linalg.pinv(B)) + t = numpy.concatenate((t, numpy.zeros((ndims, 1))), axis=1) + M = numpy.vstack((t, ((0.0,)*ndims) + (1.0,))) + elif usesvd or ndims != 3: + # Rigid transformation via SVD of covariance matrix + u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T)) + # rotation matrix from SVD orthonormal bases + R = numpy.dot(u, vh) + if numpy.linalg.det(R) < 0.0: + # R does not constitute right handed system + R -= numpy.outer(u[:, ndims-1], vh[ndims-1, :]*2.0) + s[-1] *= -1.0 + # homogeneous transformation matrix + M = numpy.identity(ndims+1) + M[:ndims, :ndims] = R + else: + # Rigid transformation matrix via quaternion + # compute symmetric matrix N + xx, yy, zz = numpy.sum(v0 * v1, axis=1) + xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1) + xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1) + N = [[xx+yy+zz, 0.0, 0.0, 0.0], + [yz-zy, xx-yy-zz, 0.0, 0.0], + [zx-xz, xy+yx, yy-xx-zz, 0.0], + [xy-yx, zx+xz, yz+zy, zz-xx-yy]] + # quaternion: eigenvector corresponding to most positive eigenvalue + w, V = numpy.linalg.eigh(N) + q = V[:, numpy.argmax(w)] + q /= vector_norm(q) # unit quaternion + # homogeneous transformation matrix + M = quaternion_matrix(q) + + if scale and not shear: + # Affine transformation; scale is ratio of RMS deviations from centroid + v0 *= v0 + v1 *= v1 + M[:ndims, :ndims] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0)) + + # move centroids back + M = numpy.dot(numpy.linalg.inv(M1), numpy.dot(M, M0)) + M /= M[ndims, ndims] + return M + + +def superimposition_matrix(v0, v1, scale=False, usesvd=True): + """Return matrix to transform given 3D point set into second point set. + + v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 points. + + The parameters scale and usesvd are explained in the more general + affine_matrix_from_points function. + + The returned matrix is a similarity or Euclidean transformation matrix. + This function has a fast C implementation in transformations.c. 
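(Editorial aside, not part of the original source: a short sketch of rigid point-set alignment with superimposition_matrix; the data are random, so it is illustrative only, and the `tf` alias is assumed as in the earlier sketches.)

```python
import numpy as np
from COTR.transformations import transformations as tf  # assumed import path

src = np.vstack([np.random.rand(3, 50), np.ones((1, 50))])  # homogeneous 4 x N points
M_true = tf.concatenate_matrices(tf.translation_matrix([0.5, -0.2, 1.0]),
                                 tf.random_rotation_matrix())
dst = np.dot(M_true, src)

M_est = tf.superimposition_matrix(src, dst)  # rigid: rotation + translation only
assert np.allclose(np.dot(M_est, src), dst)
```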
+ + >>> v0 = numpy.random.rand(3, 10) + >>> M = superimposition_matrix(v0, v0) + >>> numpy.allclose(M, numpy.identity(4)) + True + >>> R = random_rotation_matrix(numpy.random.random(3)) + >>> v0 = [[1,0,0], [0,1,0], [0,0,1], [1,1,1]] + >>> v1 = numpy.dot(R, v0) + >>> M = superimposition_matrix(v0, v1) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20 + >>> v0[3] = 1 + >>> v1 = numpy.dot(R, v0) + >>> M = superimposition_matrix(v0, v1) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> S = scale_matrix(random.random()) + >>> T = translation_matrix(numpy.random.random(3)-0.5) + >>> M = concatenate_matrices(T, R, S) + >>> v1 = numpy.dot(M, v0) + >>> v0[:3] += numpy.random.normal(0, 1e-9, 300).reshape(3, -1) + >>> M = superimposition_matrix(v0, v1, scale=True) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> v = numpy.empty((4, 100, 3)) + >>> v[:, :, 0] = v0 + >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False) + >>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0])) + True + + """ + v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3] + v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3] + return affine_matrix_from_points(v0, v1, shear=False, + scale=scale, usesvd=usesvd) + + +def euler_matrix(ai, aj, ak, axes='sxyz'): + """Return homogeneous rotation matrix from Euler angles and axis sequence. + + ai, aj, ak : Euler's roll, pitch and yaw angles + axes : One of 24 axis sequences as string or encoded tuple + + >>> R = euler_matrix(1, 2, 3, 'syxz') + >>> numpy.allclose(numpy.sum(R[0]), -1.34786452) + True + >>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1)) + >>> numpy.allclose(numpy.sum(R[0]), -0.383436184) + True + >>> ai, aj, ak = (4*math.pi) * (numpy.random.random(3) - 0.5) + >>> for axes in _AXES2TUPLE.keys(): + ... R = euler_matrix(ai, aj, ak, axes) + >>> for axes in _TUPLE2AXES.keys(): + ... R = euler_matrix(ai, aj, ak, axes) + + """ + try: + firstaxis, parity, repetition, frame = _AXES2TUPLE[axes] + except (AttributeError, KeyError): + _TUPLE2AXES[axes] # noqa: validation + firstaxis, parity, repetition, frame = axes + + i = firstaxis + j = _NEXT_AXIS[i+parity] + k = _NEXT_AXIS[i-parity+1] + + if frame: + ai, ak = ak, ai + if parity: + ai, aj, ak = -ai, -aj, -ak + + si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak) + ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak) + cc, cs = ci*ck, ci*sk + sc, ss = si*ck, si*sk + + M = numpy.identity(4) + if repetition: + M[i, i] = cj + M[i, j] = sj*si + M[i, k] = sj*ci + M[j, i] = sj*sk + M[j, j] = -cj*ss+cc + M[j, k] = -cj*cs-sc + M[k, i] = -sj*ck + M[k, j] = cj*sc+cs + M[k, k] = cj*cc-ss + else: + M[i, i] = cj*ck + M[i, j] = sj*sc-cs + M[i, k] = sj*cc+ss + M[j, i] = cj*sk + M[j, j] = sj*ss+cc + M[j, k] = sj*cs-sc + M[k, i] = -sj + M[k, j] = cj*si + M[k, k] = cj*ci + return M + + +def euler_from_matrix(matrix, axes='sxyz'): + """Return Euler angles from rotation matrix for specified axis sequence. + + axes : One of 24 axis sequences as string or encoded tuple + + Note that many Euler angle triplets can describe one matrix. + + >>> R0 = euler_matrix(1, 2, 3, 'syxz') + >>> al, be, ga = euler_from_matrix(R0, 'syxz') + >>> R1 = euler_matrix(al, be, ga, 'syxz') + >>> numpy.allclose(R0, R1) + True + >>> angles = (4*math.pi) * (numpy.random.random(3) - 0.5) + >>> for axes in _AXES2TUPLE.keys(): + ... R0 = euler_matrix(axes=axes, *angles) + ... 
R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes)) + ... if not numpy.allclose(R0, R1): print(axes, "failed") + + """ + try: + firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] + except (AttributeError, KeyError): + _TUPLE2AXES[axes] # noqa: validation + firstaxis, parity, repetition, frame = axes + + i = firstaxis + j = _NEXT_AXIS[i+parity] + k = _NEXT_AXIS[i-parity+1] + + M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3] + if repetition: + sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k]) + if sy > _EPS: + ax = math.atan2(M[i, j], M[i, k]) + ay = math.atan2(sy, M[i, i]) + az = math.atan2(M[j, i], -M[k, i]) + else: + ax = math.atan2(-M[j, k], M[j, j]) + ay = math.atan2(sy, M[i, i]) + az = 0.0 + else: + cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i]) + if cy > _EPS: + ax = math.atan2(M[k, j], M[k, k]) + ay = math.atan2(-M[k, i], cy) + az = math.atan2(M[j, i], M[i, i]) + else: + ax = math.atan2(-M[j, k], M[j, j]) + ay = math.atan2(-M[k, i], cy) + az = 0.0 + + if parity: + ax, ay, az = -ax, -ay, -az + if frame: + ax, az = az, ax + return ax, ay, az + + +def euler_from_quaternion(quaternion, axes='sxyz'): + """Return Euler angles from quaternion for specified axis sequence. + + >>> angles = euler_from_quaternion([0.99810947, 0.06146124, 0, 0]) + >>> numpy.allclose(angles, [0.123, 0, 0]) + True + + """ + return euler_from_matrix(quaternion_matrix(quaternion), axes) + + +def quaternion_from_euler(ai, aj, ak, axes='sxyz'): + """Return quaternion from Euler angles and axis sequence. + + ai, aj, ak : Euler's roll, pitch and yaw angles + axes : One of 24 axis sequences as string or encoded tuple + + >>> q = quaternion_from_euler(1, 2, 3, 'ryxz') + >>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435]) + True + + """ + try: + firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] + except (AttributeError, KeyError): + _TUPLE2AXES[axes] # noqa: validation + firstaxis, parity, repetition, frame = axes + + i = firstaxis + 1 + j = _NEXT_AXIS[i+parity-1] + 1 + k = _NEXT_AXIS[i-parity] + 1 + + if frame: + ai, ak = ak, ai + if parity: + aj = -aj + + ai /= 2.0 + aj /= 2.0 + ak /= 2.0 + ci = math.cos(ai) + si = math.sin(ai) + cj = math.cos(aj) + sj = math.sin(aj) + ck = math.cos(ak) + sk = math.sin(ak) + cc = ci*ck + cs = ci*sk + sc = si*ck + ss = si*sk + + q = numpy.empty((4, )) + if repetition: + q[0] = cj*(cc - ss) + q[i] = cj*(cs + sc) + q[j] = sj*(cc + ss) + q[k] = sj*(cs - sc) + else: + q[0] = cj*cc + sj*ss + q[i] = cj*sc - sj*cs + q[j] = cj*ss + sj*cc + q[k] = cj*cs - sj*sc + if parity: + q[j] *= -1.0 + + return q + + +def quaternion_about_axis(angle, axis): + """Return quaternion for rotation about axis. + + >>> q = quaternion_about_axis(0.123, [1, 0, 0]) + >>> numpy.allclose(q, [0.99810947, 0.06146124, 0, 0]) + True + + """ + q = numpy.array([0.0, axis[0], axis[1], axis[2]]) + qlen = vector_norm(q) + if qlen > _EPS: + q *= math.sin(angle/2.0) / qlen + q[0] = math.cos(angle/2.0) + return q + + +def quaternion_matrix(quaternion): + """Return homogeneous rotation matrix from quaternion. 
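(Editorial aside, not part of the original source: quaternions in this module are w-first, [w, x, y, z]; the sketch below shows the round trip with quaternion_from_matrix, reusing the assumed `tf` alias.)

```python
import numpy as np
from COTR.transformations import transformations as tf  # assumed import path

q = tf.quaternion_about_axis(0.5, [0, 0, 1])  # [cos(0.25), 0, 0, sin(0.25)], w first
M = tf.quaternion_matrix(q)                   # 4x4 homogeneous rotation matrix
q_back = tf.quaternion_from_matrix(M)
# q and -q encode the same rotation, so compare up to sign.
assert np.allclose(q, q_back) or np.allclose(q, -q_back)
```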
+ + >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0]) + >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0])) + True + >>> M = quaternion_matrix([1, 0, 0, 0]) + >>> numpy.allclose(M, numpy.identity(4)) + True + >>> M = quaternion_matrix([0, 1, 0, 0]) + >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1])) + True + + """ + q = numpy.array(quaternion, dtype=numpy.float64, copy=True) + n = numpy.dot(q, q) + if n < _EPS: + return numpy.identity(4) + q *= math.sqrt(2.0 / n) + q = numpy.outer(q, q) + return numpy.array([ + [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0], + [q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0], + [q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0], + [0.0, 0.0, 0.0, 1.0]]) + + +def quaternion_from_matrix(matrix, isprecise=False): + """Return quaternion from rotation matrix. + + If isprecise is True, the input matrix is assumed to be a precise rotation + matrix and a faster algorithm is used. + + >>> q = quaternion_from_matrix(numpy.identity(4), True) + >>> numpy.allclose(q, [1, 0, 0, 0]) + True + >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1])) + >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0]) + True + >>> R = rotation_matrix(0.123, (1, 2, 3)) + >>> q = quaternion_from_matrix(R, True) + >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786]) + True + >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0], + ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611]) + True + >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0], + ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603]) + True + >>> R = random_rotation_matrix() + >>> q = quaternion_from_matrix(R) + >>> is_same_transform(R, quaternion_matrix(q)) + True + >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False), + ... quaternion_from_matrix(R, isprecise=True)) + True + >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0) + >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False), + ... quaternion_from_matrix(R, isprecise=True)) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4] + if isprecise: + q = numpy.empty((4, )) + t = numpy.trace(M) + if t > M[3, 3]: + q[0] = t + q[3] = M[1, 0] - M[0, 1] + q[2] = M[0, 2] - M[2, 0] + q[1] = M[2, 1] - M[1, 2] + else: + i, j, k = 0, 1, 2 + if M[1, 1] > M[0, 0]: + i, j, k = 1, 2, 0 + if M[2, 2] > M[i, i]: + i, j, k = 2, 0, 1 + t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] + q[i] = t + q[j] = M[i, j] + M[j, i] + q[k] = M[k, i] + M[i, k] + q[3] = M[k, j] - M[j, k] + q = q[[3, 0, 1, 2]] + q *= 0.5 / math.sqrt(t * M[3, 3]) + else: + m00 = M[0, 0] + m01 = M[0, 1] + m02 = M[0, 2] + m10 = M[1, 0] + m11 = M[1, 1] + m12 = M[1, 2] + m20 = M[2, 0] + m21 = M[2, 1] + m22 = M[2, 2] + # symmetric matrix K + K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0], + [m01+m10, m11-m00-m22, 0.0, 0.0], + [m02+m20, m12+m21, m22-m00-m11, 0.0], + [m21-m12, m02-m20, m10-m01, m00+m11+m22]]) + K /= 3.0 + # quaternion is eigenvector of K that corresponds to largest eigenvalue + w, V = numpy.linalg.eigh(K) + q = V[[3, 0, 1, 2], numpy.argmax(w)] + if q[0] < 0.0: + numpy.negative(q, q) + return q + + +def quaternion_multiply(quaternion1, quaternion0): + """Return multiplication of two quaternions. 
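(Editorial aside, not part of the original source: argument order matters in quaternion_multiply; as in the module-level examples above, quaternion_multiply(q1, q0) composes like the matrix product numpy.dot(R1, R0). A small consistency check, again assuming the `tf` alias:)

```python
from COTR.transformations import transformations as tf  # assumed import path

qx = tf.quaternion_about_axis(0.3, [1, 0, 0])   # rotation about x
qz = tf.quaternion_about_axis(1.1, [0, 0, 1])   # rotation about z

q = tf.quaternion_multiply(qz, qx)              # rotate about x first, then about z
M = tf.concatenate_matrices(tf.rotation_matrix(1.1, [0, 0, 1]),
                            tf.rotation_matrix(0.3, [1, 0, 0]))
assert tf.is_same_transform(tf.quaternion_matrix(q), M)
```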
+ + >>> q = quaternion_multiply([4, 1, -2, 3], [8, -5, 6, 7]) + >>> numpy.allclose(q, [28, -44, -14, 48]) + True + + """ + w0, x0, y0, z0 = quaternion0 + w1, x1, y1, z1 = quaternion1 + return numpy.array([ + -x1*x0 - y1*y0 - z1*z0 + w1*w0, + x1*w0 + y1*z0 - z1*y0 + w1*x0, + -x1*z0 + y1*w0 + z1*x0 + w1*y0, + x1*y0 - y1*x0 + z1*w0 + w1*z0], dtype=numpy.float64) + + +def quaternion_conjugate(quaternion): + """Return conjugate of quaternion. + + >>> q0 = random_quaternion() + >>> q1 = quaternion_conjugate(q0) + >>> q1[0] == q0[0] and all(q1[1:] == -q0[1:]) + True + + """ + q = numpy.array(quaternion, dtype=numpy.float64, copy=True) + numpy.negative(q[1:], q[1:]) + return q + + +def quaternion_inverse(quaternion): + """Return inverse of quaternion. + + >>> q0 = random_quaternion() + >>> q1 = quaternion_inverse(q0) + >>> numpy.allclose(quaternion_multiply(q0, q1), [1, 0, 0, 0]) + True + + """ + q = numpy.array(quaternion, dtype=numpy.float64, copy=True) + numpy.negative(q[1:], q[1:]) + return q / numpy.dot(q, q) + + +def quaternion_real(quaternion): + """Return real part of quaternion. + + >>> quaternion_real([3, 0, 1, 2]) + 3.0 + + """ + return float(quaternion[0]) + + +def quaternion_imag(quaternion): + """Return imaginary part of quaternion. + + >>> quaternion_imag([3, 0, 1, 2]) + array([ 0., 1., 2.]) + + """ + return numpy.array(quaternion[1:4], dtype=numpy.float64, copy=True) + + +def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True): + """Return spherical linear interpolation between two quaternions. + + >>> q0 = random_quaternion() + >>> q1 = random_quaternion() + >>> q = quaternion_slerp(q0, q1, 0) + >>> numpy.allclose(q, q0) + True + >>> q = quaternion_slerp(q0, q1, 1, 1) + >>> numpy.allclose(q, q1) + True + >>> q = quaternion_slerp(q0, q1, 0.5) + >>> angle = math.acos(numpy.dot(q0, q)) + >>> numpy.allclose(2, math.acos(numpy.dot(q0, q1)) / angle) or \ + numpy.allclose(2, math.acos(-numpy.dot(q0, q1)) / angle) + True + + """ + q0 = unit_vector(quat0[:4]) + q1 = unit_vector(quat1[:4]) + if fraction == 0.0: + return q0 + elif fraction == 1.0: + return q1 + d = numpy.dot(q0, q1) + if abs(abs(d) - 1.0) < _EPS: + return q0 + if shortestpath and d < 0.0: + # invert rotation + d = -d + numpy.negative(q1, q1) + angle = math.acos(d) + spin * math.pi + if abs(angle) < _EPS: + return q0 + isin = 1.0 / math.sin(angle) + q0 *= math.sin((1.0 - fraction) * angle) * isin + q1 *= math.sin(fraction * angle) * isin + q0 += q1 + return q0 + + +def random_quaternion(rand=None): + """Return uniform random unit quaternion. + + rand: array like or None + Three independent random variables that are uniformly distributed + between 0 and 1. + + >>> q = random_quaternion() + >>> numpy.allclose(1, vector_norm(q)) + True + >>> q = random_quaternion(numpy.random.random(3)) + >>> len(q.shape), q.shape[0]==4 + (1, True) + + """ + if rand is None: + rand = numpy.random.rand(3) + else: + assert len(rand) == 3 + r1 = numpy.sqrt(1.0 - rand[0]) + r2 = numpy.sqrt(rand[0]) + pi2 = math.pi * 2.0 + t1 = pi2 * rand[1] + t2 = pi2 * rand[2] + return numpy.array([numpy.cos(t2)*r2, numpy.sin(t1)*r1, + numpy.cos(t1)*r1, numpy.sin(t2)*r2]) + + +def random_rotation_matrix(rand=None): + """Return uniform random rotation matrix. + + rand: array like + Three independent random variables that are uniformly distributed + between 0 and 1 for each returned quaternion. 
+ + >>> R = random_rotation_matrix() + >>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4)) + True + + """ + return quaternion_matrix(random_quaternion(rand)) + + +class Arcball(object): + """Virtual Trackball Control. + + >>> ball = Arcball() + >>> ball = Arcball(initial=numpy.identity(4)) + >>> ball.place([320, 320], 320) + >>> ball.down([500, 250]) + >>> ball.drag([475, 275]) + >>> R = ball.matrix() + >>> numpy.allclose(numpy.sum(R), 3.90583455) + True + >>> ball = Arcball(initial=[1, 0, 0, 0]) + >>> ball.place([320, 320], 320) + >>> ball.setaxes([1, 1, 0], [-1, 1, 0]) + >>> ball.constrain = True + >>> ball.down([400, 200]) + >>> ball.drag([200, 400]) + >>> R = ball.matrix() + >>> numpy.allclose(numpy.sum(R), 0.2055924) + True + >>> ball.next() + + """ + + def __init__(self, initial=None): + """Initialize virtual trackball control. + + initial : quaternion or rotation matrix + + """ + self._axis = None + self._axes = None + self._radius = 1.0 + self._center = [0.0, 0.0] + self._vdown = numpy.array([0.0, 0.0, 1.0]) + self._constrain = False + if initial is None: + self._qdown = numpy.array([1.0, 0.0, 0.0, 0.0]) + else: + initial = numpy.array(initial, dtype=numpy.float64) + if initial.shape == (4, 4): + self._qdown = quaternion_from_matrix(initial) + elif initial.shape == (4, ): + initial /= vector_norm(initial) + self._qdown = initial + else: + raise ValueError("initial not a quaternion or matrix") + self._qnow = self._qpre = self._qdown + + def place(self, center, radius): + """Place Arcball, e.g. when window size changes. + + center : sequence[2] + Window coordinates of trackball center. + radius : float + Radius of trackball in window coordinates. + + """ + self._radius = float(radius) + self._center[0] = center[0] + self._center[1] = center[1] + + def setaxes(self, *axes): + """Set axes to constrain rotations.""" + if axes is None: + self._axes = None + else: + self._axes = [unit_vector(axis) for axis in axes] + + @property + def constrain(self): + """Return state of constrain to axis mode.""" + return self._constrain + + @constrain.setter + def constrain(self, value): + """Set state of constrain to axis mode.""" + self._constrain = bool(value) + + def down(self, point): + """Set initial cursor window coordinates and pick constrain-axis.""" + self._vdown = arcball_map_to_sphere(point, self._center, self._radius) + self._qdown = self._qpre = self._qnow + if self._constrain and self._axes is not None: + self._axis = arcball_nearest_axis(self._vdown, self._axes) + self._vdown = arcball_constrain_to_axis(self._vdown, self._axis) + else: + self._axis = None + + def drag(self, point): + """Update current cursor window coordinates.""" + vnow = arcball_map_to_sphere(point, self._center, self._radius) + if self._axis is not None: + vnow = arcball_constrain_to_axis(vnow, self._axis) + self._qpre = self._qnow + t = numpy.cross(self._vdown, vnow) + if numpy.dot(t, t) < _EPS: + self._qnow = self._qdown + else: + q = [numpy.dot(self._vdown, vnow), t[0], t[1], t[2]] + self._qnow = quaternion_multiply(q, self._qdown) + + def next(self, acceleration=0.0): + """Continue rotation in direction of last drag.""" + q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False) + self._qpre, self._qnow = self._qnow, q + + def matrix(self): + """Return homogeneous rotation matrix.""" + return quaternion_matrix(self._qnow) + + +def arcball_map_to_sphere(point, center, radius): + """Return unit sphere coordinates from window coordinates.""" + v0 = (point[0] - center[0]) / radius + v1 = (center[1] - 
point[1]) / radius + n = v0*v0 + v1*v1 + if n > 1.0: + # position outside of sphere + n = math.sqrt(n) + return numpy.array([v0/n, v1/n, 0.0]) + else: + return numpy.array([v0, v1, math.sqrt(1.0 - n)]) + + +def arcball_constrain_to_axis(point, axis): + """Return sphere point perpendicular to axis.""" + v = numpy.array(point, dtype=numpy.float64, copy=True) + a = numpy.array(axis, dtype=numpy.float64, copy=True) + v -= a * numpy.dot(a, v) # on plane + n = vector_norm(v) + if n > _EPS: + if v[2] < 0.0: + numpy.negative(v, v) + v /= n + return v + if a[2] == 1.0: + return numpy.array([1.0, 0.0, 0.0]) + return unit_vector([-a[1], a[0], 0.0]) + + +def arcball_nearest_axis(point, axes): + """Return axis, which arc is nearest to point.""" + point = numpy.array(point, dtype=numpy.float64, copy=False) + nearest = None + mx = -1.0 + for axis in axes: + t = numpy.dot(arcball_constrain_to_axis(point, axis), point) + if t > mx: + nearest = axis + mx = t + return nearest + + +# epsilon for testing whether a number is close to zero +_EPS = numpy.finfo(float).eps * 4.0 + +# axis sequences for Euler angles +_NEXT_AXIS = [1, 2, 0, 1] + +# map axes strings to/from tuples of inner axis, parity, repetition, frame +_AXES2TUPLE = { + 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0), + 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0), + 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0), + 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0), + 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1), + 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1), + 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1), + 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)} + +_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items()) + + +def vector_norm(data, axis=None, out=None): + """Return length, i.e. Euclidean norm, of ndarray along axis. + + >>> v = numpy.random.random(3) + >>> n = vector_norm(v) + >>> numpy.allclose(n, numpy.linalg.norm(v)) + True + >>> v = numpy.random.rand(6, 5, 3) + >>> n = vector_norm(v, axis=-1) + >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2))) + True + >>> n = vector_norm(v, axis=1) + >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1))) + True + >>> v = numpy.random.rand(5, 4, 3) + >>> n = numpy.empty((5, 3)) + >>> vector_norm(v, axis=1, out=n) + >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1))) + True + >>> vector_norm([]) + 0.0 + >>> vector_norm([1]) + 1.0 + + """ + data = numpy.array(data, dtype=numpy.float64, copy=True) + if out is None: + if data.ndim == 1: + return math.sqrt(numpy.dot(data, data)) + data *= data + out = numpy.atleast_1d(numpy.sum(data, axis=axis)) + numpy.sqrt(out, out) + return out + else: + data *= data + numpy.sum(data, axis=axis, out=out) + numpy.sqrt(out, out) + + +def unit_vector(data, axis=None, out=None): + """Return ndarray normalized by length, i.e. Euclidean norm, along axis. 
+ + >>> v0 = numpy.random.random(3) + >>> v1 = unit_vector(v0) + >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0)) + True + >>> v0 = numpy.random.rand(5, 4, 3) + >>> v1 = unit_vector(v0, axis=-1) + >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2) + >>> numpy.allclose(v1, v2) + True + >>> v1 = unit_vector(v0, axis=1) + >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1) + >>> numpy.allclose(v1, v2) + True + >>> v1 = numpy.empty((5, 4, 3)) + >>> unit_vector(v0, axis=1, out=v1) + >>> numpy.allclose(v1, v2) + True + >>> list(unit_vector([])) + [] + >>> list(unit_vector([1])) + [1.0] + + """ + if out is None: + data = numpy.array(data, dtype=numpy.float64, copy=True) + if data.ndim == 1: + data /= math.sqrt(numpy.dot(data, data)) + return data + else: + if out is not data: + out[:] = numpy.array(data, copy=False) + data = out + length = numpy.atleast_1d(numpy.sum(data*data, axis)) + numpy.sqrt(length, length) + if axis is not None: + length = numpy.expand_dims(length, axis) + data /= length + if out is None: + return data + + +def random_vector(size): + """Return array of random doubles in the half-open interval [0.0, 1.0). + + >>> v = random_vector(10000) + >>> numpy.all(v >= 0) and numpy.all(v < 1) + True + >>> v0 = random_vector(10) + >>> v1 = random_vector(10) + >>> numpy.any(v0 == v1) + False + + """ + return numpy.random.random(size) + + +def vector_product(v0, v1, axis=0): + """Return vector perpendicular to vectors. + + >>> v = vector_product([2, 0, 0], [0, 3, 0]) + >>> numpy.allclose(v, [0, 0, 6]) + True + >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]] + >>> v1 = [[3], [0], [0]] + >>> v = vector_product(v0, v1) + >>> numpy.allclose(v, [[0, 0, 0, 0], [0, 0, 6, 6], [0, -6, 0, -6]]) + True + >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]] + >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]] + >>> v = vector_product(v0, v1, axis=1) + >>> numpy.allclose(v, [[0, 0, 6], [0, -6, 0], [6, 0, 0], [0, -6, 6]]) + True + + """ + return numpy.cross(v0, v1, axis=axis) + + +def angle_between_vectors(v0, v1, directed=True, axis=0): + """Return angle between vectors. + + If directed is False, the input vectors are interpreted as undirected axes, + i.e. the maximum angle is pi/2. + + >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3]) + >>> numpy.allclose(a, math.pi) + True + >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3], directed=False) + >>> numpy.allclose(a, 0) + True + >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]] + >>> v1 = [[3], [0], [0]] + >>> a = angle_between_vectors(v0, v1) + >>> numpy.allclose(a, [0, 1.5708, 1.5708, 0.95532]) + True + >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]] + >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]] + >>> a = angle_between_vectors(v0, v1, axis=1) + >>> numpy.allclose(a, [1.5708, 1.5708, 1.5708, 0.95532]) + True + + """ + v0 = numpy.array(v0, dtype=numpy.float64, copy=False) + v1 = numpy.array(v1, dtype=numpy.float64, copy=False) + dot = numpy.sum(v0 * v1, axis=axis) + dot /= vector_norm(v0, axis=axis) * vector_norm(v1, axis=axis) + dot = numpy.clip(dot, -1.0, 1.0) + return numpy.arccos(dot if directed else numpy.fabs(dot)) + + +def inverse_matrix(matrix): + """Return inverse of square transformation matrix. + + >>> M0 = random_rotation_matrix() + >>> M1 = inverse_matrix(M0.T) + >>> numpy.allclose(M1, numpy.linalg.inv(M0.T)) + True + >>> for size in range(1, 7): + ... M0 = numpy.random.rand(size, size) + ... M1 = inverse_matrix(M0) + ... 
if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size) + + """ + return numpy.linalg.inv(matrix) + + +def concatenate_matrices(*matrices): + """Return concatenation of series of transformation matrices. + + >>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5 + >>> numpy.allclose(M, concatenate_matrices(M)) + True + >>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T)) + True + + """ + M = numpy.identity(4) + for i in matrices: + M = numpy.dot(M, i) + return M + + +def is_same_transform(matrix0, matrix1): + """Return True if two matrices perform same transformation. + + >>> is_same_transform(numpy.identity(4), numpy.identity(4)) + True + >>> is_same_transform(numpy.identity(4), random_rotation_matrix()) + False + + """ + matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True) + matrix0 /= matrix0[3, 3] + matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True) + matrix1 /= matrix1[3, 3] + return numpy.allclose(matrix0, matrix1) + + +def is_same_quaternion(q0, q1): + """Return True if two quaternions are equal.""" + q0 = numpy.array(q0) + q1 = numpy.array(q1) + return numpy.allclose(q0, q1) or numpy.allclose(q0, -q1) + + +def _import_module(name, package=None, warn=True, postfix='_py', ignore='_'): + """Try import all public attributes from module into global namespace. + + Existing attributes with name clashes are renamed with prefix. + Attributes starting with underscore are ignored by default. + + Return True on successful import. + + """ + import warnings + from importlib import import_module + try: + if not package: + module = import_module(name) + else: + module = import_module('.' + name, package=package) + except ImportError as err: + if warn: + warnings.warn(str(err)) + else: + for attr in dir(module): + if ignore and attr.startswith(ignore): + continue + if postfix: + if attr in globals(): + globals()[attr + postfix] = globals()[attr] + elif warn: + warnings.warn('no Python implementation of ' + attr) + globals()[attr] = getattr(module, attr) + return True + + +_import_module('_transformations', __package__, warn=False) + + +if __name__ == '__main__': + import doctest + import random # noqa: used in doctests + try: + numpy.set_printoptions(suppress=True, precision=5, legacy='1.13') + except TypeError: + numpy.set_printoptions(suppress=True, precision=5) + doctest.testmod() diff --git a/imcui/third_party/COTR/COTR/utils/constants.py b/imcui/third_party/COTR/COTR/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae0035c8a180a9035a272e0932be5108828771f --- /dev/null +++ b/imcui/third_party/COTR/COTR/utils/constants.py @@ -0,0 +1,3 @@ +DEFAULT_PRECISION = 'float32' +MAX_SIZE = 256 +VALID_NN_OVERLAPPING_THRESH = 0.1 diff --git a/imcui/third_party/COTR/COTR/utils/debug_utils.py b/imcui/third_party/COTR/COTR/utils/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..74db4f44148fd784604053084e277f8841d0d0a4 --- /dev/null +++ b/imcui/third_party/COTR/COTR/utils/debug_utils.py @@ -0,0 +1,15 @@ +def embed_breakpoint(debug_info='', terminate=True): + print('\nyou are inside a break point') + if debug_info: + print('debug info: {0}'.format(debug_info)) + print('') + embedding = ('import IPython\n' + 'import matplotlib.pyplot as plt\n' + 'IPython.embed()\n' + ) + if terminate: + embedding += ( + 'assert 0, \'force termination\'\n' + ) + + return embedding diff --git a/imcui/third_party/COTR/COTR/utils/utils.py b/imcui/third_party/COTR/COTR/utils/utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..1d3422188ad93dea94926259ec1088552b11ef31 --- /dev/null +++ b/imcui/third_party/COTR/COTR/utils/utils.py @@ -0,0 +1,271 @@ +import random +import smtplib +import ssl +from collections import namedtuple + +from COTR.utils import debug_utils + +import numpy as np +import torch +import cv2 +import matplotlib.pyplot as plt +import PIL + + +''' +ImagePatch: patch: patch content, np array or None + x: left bound in original resolution + y: upper bound in original resolution + w: width of patch + h: height of patch + ow: width of original resolution + oh: height of original resolution +''' +ImagePatch = namedtuple('ImagePatch', ['patch', 'x', 'y', 'w', 'h', 'ow', 'oh']) +Point3D = namedtuple("Point3D", ["id", "arr_idx", "image_ids"]) +Point2D = namedtuple("Point2D", ["id_3d", "xy"]) + + +class CropCamConfig(): + def __init__(self, x, y, w, h, out_w, out_h, orig_w, orig_h): + ''' + xy: left upper corner + ''' + # assert x > 0 and x < orig_w + # assert y > 0 and y < orig_h + # assert w < orig_w and h < orig_h + # assert x - w / 2 > 0 and x + w / 2 < orig_w + # assert y - h / 2 > 0 and y + h / 2 < orig_h + # assert h / w == out_h / out_w + self.x = x + self.y = y + self.w = w + self.h = h + self.out_w = out_w + self.out_h = out_h + self.orig_w = orig_w + self.orig_h = orig_h + + def __str__(self): + out = f'original image size(h,w): [{self.orig_h}, {self.orig_w}]\n' + out += f'crop at(x,y): [{self.x}, {self.y}]\n' + out += f'crop size(h,w): [{self.h}, {self.w}]\n' + out += f'resize crop to(h,w): [{self.out_h}, {self.out_w}]' + return out + + +def fix_randomness(seed=42): + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.manual_seed(seed) + np.random.seed(seed) + + +def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + + +def float_image_resize(img, shape, interp=PIL.Image.BILINEAR): + missing_channel = False + if len(img.shape) == 2: + missing_channel = True + img = img[..., None] + layers = [] + img = img.transpose(2, 0, 1) + for l in img: + l = np.array(PIL.Image.fromarray(l).resize(shape[::-1], resample=interp)) + assert l.shape[:2] == shape + layers.append(l) + if missing_channel: + return np.stack(layers, axis=-1)[..., 0] + else: + return np.stack(layers, axis=-1) + + +def is_nan(x): + """ + get mask of nan values. + :param x: torch or numpy var. + :return: a N-D array of bool. True -> nan, False -> ok. + """ + return x != x + + +def has_nan(x) -> bool: + """ + check whether x contains nan. + :param x: torch or numpy var. + :return: single bool, True -> x containing nan, False -> ok. + """ + if x is None: + return False + return is_nan(x).any() + + +def confirm(question='OK to continue?'): + """ + Ask user to enter Y or N (case-insensitive). + :return: True if the answer is Y. 
+ :rtype: bool + """ + answer = "" + while answer not in ["y", "n"]: + answer = input(question + ' [y/n] ').lower() + return answer == "y" + + +def print_notification(content_list, notification_type='NOTIFICATION'): + print('---------------------- {0} ----------------------'.format(notification_type)) + print() + for content in content_list: + print(content) + print() + print('----------------------------------------------------') + + +def torch_img_to_np_img(torch_img): + '''convert a torch image to matplotlib-able numpy image + torch use Channels x Height x Width + numpy use Height x Width x Channels + Arguments: + torch_img {[type]} -- [description] + ''' + assert isinstance(torch_img, torch.Tensor), 'cannot process data type: {0}'.format(type(torch_img)) + if len(torch_img.shape) == 4 and (torch_img.shape[1] == 3 or torch_img.shape[1] == 1): + return np.transpose(torch_img.detach().cpu().numpy(), (0, 2, 3, 1)) + if len(torch_img.shape) == 3 and (torch_img.shape[0] == 3 or torch_img.shape[0] == 1): + return np.transpose(torch_img.detach().cpu().numpy(), (1, 2, 0)) + elif len(torch_img.shape) == 2: + return torch_img.detach().cpu().numpy() + else: + raise ValueError('cannot process this image') + + +def np_img_to_torch_img(np_img): + """convert a numpy image to torch image + numpy use Height x Width x Channels + torch use Channels x Height x Width + + Arguments: + np_img {[type]} -- [description] + """ + assert isinstance(np_img, np.ndarray), 'cannot process data type: {0}'.format(type(np_img)) + if len(np_img.shape) == 4 and (np_img.shape[3] == 3 or np_img.shape[3] == 1): + return torch.from_numpy(np.transpose(np_img, (0, 3, 1, 2))) + if len(np_img.shape) == 3 and (np_img.shape[2] == 3 or np_img.shape[2] == 1): + return torch.from_numpy(np.transpose(np_img, (2, 0, 1))) + elif len(np_img.shape) == 2: + return torch.from_numpy(np_img) + else: + raise ValueError('cannot process this image with shape: {0}'.format(np_img.shape)) + + +def safe_load_weights(model, saved_weights): + try: + model.load_state_dict(saved_weights) + except RuntimeError: + try: + weights = saved_weights + weights = {k.replace('module.', ''): v for k, v in weights.items()} + model.load_state_dict(weights) + except RuntimeError: + try: + weights = saved_weights + weights = {'module.' 
+ k: v for k, v in weights.items()} + model.load_state_dict(weights) + except RuntimeError: + try: + pretrained_dict = saved_weights + model_dict = model.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if ((k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape))} + assert len(pretrained_dict) != 0 + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + non_match_keys = set(model.state_dict().keys()) - set(pretrained_dict.keys()) + notification = [] + notification += ['pretrained weights PARTIALLY loaded, following are missing:'] + notification += [str(non_match_keys)] + print_notification(notification, 'WARNING') + except Exception as e: + print(f'pretrained weights loading failed {e}') + exit() + print('weights safely loaded') + + +def visualize_corrs(img1, img2, corrs, mask=None): + if mask is None: + mask = np.ones(len(corrs)).astype(bool) + + scale1 = 1.0 + scale2 = 1.0 + if img1.shape[1] > img2.shape[1]: + scale2 = img1.shape[1] / img2.shape[1] + w = img1.shape[1] + else: + scale1 = img2.shape[1] / img1.shape[1] + w = img2.shape[1] + # Resize if too big + max_w = 400 + if w > max_w: + scale1 *= max_w / w + scale2 *= max_w / w + img1 = cv2.resize(img1, (0, 0), fx=scale1, fy=scale1) + img2 = cv2.resize(img2, (0, 0), fx=scale2, fy=scale2) + + x1, x2 = corrs[:, :2], corrs[:, 2:] + h1, w1 = img1.shape[:2] + h2, w2 = img2.shape[:2] + img = np.zeros((h1 + h2, max(w1, w2), 3), dtype=img1.dtype) + img[:h1, :w1] = img1 + img[h1:, :w2] = img2 + # Move keypoints to coordinates to image coordinates + x1 = x1 * scale1 + x2 = x2 * scale2 + # recompute the coordinates for the second image + x2p = x2 + np.array([[0, h1]]) + fig = plt.figure(frameon=False) + fig = plt.imshow(img) + + cols = [ + [0.0, 0.67, 0.0], + [0.9, 0.1, 0.1], + ] + lw = .5 + alpha = 1 + + # Draw outliers + _x1 = x1[~mask] + _x2p = x2p[~mask] + xs = np.stack([_x1[:, 0], _x2p[:, 0]], axis=1).T + ys = np.stack([_x1[:, 1], _x2p[:, 1]], axis=1).T + plt.plot( + xs, ys, + alpha=alpha, + linestyle="-", + linewidth=lw, + aa=False, + color=cols[1], + ) + + + # Draw Inliers + _x1 = x1[mask] + _x2p = x2p[mask] + xs = np.stack([_x1[:, 0], _x2p[:, 0]], axis=1).T + ys = np.stack([_x1[:, 1], _x2p[:, 1]], axis=1).T + plt.plot( + xs, ys, + alpha=alpha, + linestyle="-", + linewidth=lw, + aa=False, + color=cols[0], + ) + plt.scatter(xs, ys) + + fig.axes.get_xaxis().set_visible(False) + fig.axes.get_yaxis().set_visible(False) + ax = plt.gca() + ax.set_axis_off() + plt.show() diff --git a/imcui/third_party/COTR/demo_face.py b/imcui/third_party/COTR/demo_face.py new file mode 100644 index 0000000000000000000000000000000000000000..4975356d3246c484703d2d6cf60fe9e1a2e67be8 --- /dev/null +++ b/imcui/third_party/COTR/demo_face.py @@ -0,0 +1,69 @@ +''' +COTR demo for human face +We use an off-the-shelf face landmarks detector: https://github.com/1adrianb/face-alignment +''' +import argparse +import os +import time + +import cv2 +import numpy as np +import torch +import imageio +import matplotlib.pyplot as plt + +from COTR.utils import utils, debug_utils +from COTR.models import build_model +from COTR.options.options import * +from COTR.options.options_utils import * +from COTR.inference.inference_helper import triangulate_corr +from COTR.inference.sparse_engine import SparseEngine + +utils.fix_randomness(0) +torch.set_grad_enabled(False) + + +def main(opt): + model = build_model(opt) + model = model.cuda() + weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict'] + 
utils.safe_load_weights(model, weights) + model = model.eval() + + img_a = imageio.imread('./sample_data/imgs/face_1.png', pilmode='RGB') + img_b = imageio.imread('./sample_data/imgs/face_2.png', pilmode='RGB') + queries = np.load('./sample_data/face_landmarks.npy')[0] + + engine = SparseEngine(model, 32, mode='stretching') + corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, queries_a=queries, force=False) + + f, axarr = plt.subplots(1, 2) + axarr[0].imshow(img_a) + axarr[0].scatter(*queries.T, s=1) + axarr[0].title.set_text('Reference Face') + axarr[0].axis('off') + axarr[1].imshow(img_b) + axarr[1].scatter(*corrs[:, 2:].T, s=1) + axarr[1].title.set_text('Target Face') + axarr[1].axis('off') + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + set_COTR_arguments(parser) + parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory') + parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id') + + opt = parser.parse_args() + opt.command = ' '.join(sys.argv) + + layer_2_channels = {'layer1': 256, + 'layer2': 512, + 'layer3': 1024, + 'layer4': 2048, } + opt.dim_feedforward = layer_2_channels[opt.layer] + if opt.load_weights: + opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar') + print_opt(opt) + main(opt) diff --git a/imcui/third_party/COTR/demo_guided_matching.py b/imcui/third_party/COTR/demo_guided_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..1e68d922e0785144ef3087cee7c50453f484076c --- /dev/null +++ b/imcui/third_party/COTR/demo_guided_matching.py @@ -0,0 +1,85 @@ +''' +Feature-free COTR guided matching for keypoints. +We use DISK(https://github.com/cvlab-epfl/disk) keypoints location. +We apply RANSAC + F matrix to further prune outliers. +Note: This script doesn't use descriptors. 
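+Matching rule used below (editor's summary of the code): COTR predicts, for every
+DISK keypoint of image A, a location in image B, and vice versa; a pair (i, j) is
+kept only when j is the nearest keypoint of B to the prediction for i AND i is the
+nearest keypoint of A to the prediction for j (mutual nearest neighbours), and the
+surviving pairs are then pruned with RANSAC on the fundamental matrix.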
+''' +import argparse +import os +import time + +import cv2 +import numpy as np +import torch +import imageio +from scipy.spatial import distance_matrix + +from COTR.utils import utils, debug_utils +from COTR.models import build_model +from COTR.options.options import * +from COTR.options.options_utils import * +from COTR.inference.sparse_engine import SparseEngine, FasterSparseEngine + +utils.fix_randomness(0) +torch.set_grad_enabled(False) + + +def main(opt): + model = build_model(opt) + model = model.cuda() + weights = torch.load(opt.load_weights_path)['model_state_dict'] + utils.safe_load_weights(model, weights) + model = model.eval() + + img_a = imageio.imread('./sample_data/imgs/21526113_4379776807.jpg') + img_b = imageio.imread('./sample_data/imgs/21126421_4537535153.jpg') + kp_a = np.load('./sample_data/21526113_4379776807.jpg.disk.kpts.npy') + kp_b = np.load('./sample_data/21126421_4537535153.jpg.disk.kpts.npy') + + if opt.faster_infer: + engine = FasterSparseEngine(model, 32, mode='tile') + else: + engine = SparseEngine(model, 32, mode='tile') + t0 = time.time() + corrs_a_b = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=kp_a.shape[0], queries_a=kp_a, force=True) + corrs_b_a = engine.cotr_corr_multiscale(img_b, img_a, np.linspace(0.5, 0.0625, 4), 1, max_corrs=kp_b.shape[0], queries_a=kp_b, force=True) + t1 = time.time() + print(f'COTR spent {t1-t0} seconds.') + inds_a_b = np.argmin(distance_matrix(corrs_a_b[:, 2:], kp_b), axis=1) + matched_a_b = np.stack([np.arange(kp_a.shape[0]), inds_a_b]).T + inds_b_a = np.argmin(distance_matrix(corrs_b_a[:, 2:], kp_a), axis=1) + matched_b_a = np.stack([np.arange(kp_b.shape[0]), inds_b_a]).T + + good = 0 + final_matches = [] + for m_ab in matched_a_b: + for m_ba in matched_b_a: + if (m_ab == m_ba[::-1]).all(): + good += 1 + final_matches.append(m_ab) + break + final_matches = np.array(final_matches) + final_corrs = np.concatenate([kp_a[final_matches[:, 0]], kp_b[final_matches[:, 1]]], axis=1) + _, mask = cv2.findFundamentalMat(final_corrs[:, :2], final_corrs[:, 2:], cv2.FM_RANSAC, ransacReprojThreshold=5, confidence=0.999999) + utils.visualize_corrs(img_a, img_b, final_corrs[np.where(mask[:, 0])]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + set_COTR_arguments(parser) + parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory') + parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id') + parser.add_argument('--faster_infer', type=str2bool, default=False, help='use fatser inference') + + opt = parser.parse_args() + opt.command = ' '.join(sys.argv) + + layer_2_channels = {'layer1': 256, + 'layer2': 512, + 'layer3': 1024, + 'layer4': 2048, } + opt.dim_feedforward = layer_2_channels[opt.layer] + if opt.load_weights: + opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar') + print_opt(opt) + main(opt) diff --git a/imcui/third_party/COTR/demo_homography.py b/imcui/third_party/COTR/demo_homography.py new file mode 100644 index 0000000000000000000000000000000000000000..cf5962ec68494a4a8817ac46ece468c952a71cd7 --- /dev/null +++ b/imcui/third_party/COTR/demo_homography.py @@ -0,0 +1,84 @@ +''' +COTR demo for homography estimation +''' +import argparse +import os +import time + +import cv2 +import numpy as np +import torch +import imageio +import matplotlib.pyplot as plt + +from COTR.utils import utils, debug_utils +from COTR.models import 
build_model +from COTR.options.options import * +from COTR.options.options_utils import * +from COTR.inference.inference_helper import triangulate_corr +from COTR.inference.sparse_engine import SparseEngine + +utils.fix_randomness(0) +torch.set_grad_enabled(False) + + +def main(opt): + model = build_model(opt) + model = model.cuda() + weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict'] + utils.safe_load_weights(model, weights) + model = model.eval() + + img_a = imageio.imread('./sample_data/imgs/paint_1.JPG', pilmode='RGB') + img_b = imageio.imread('./sample_data/imgs/paint_2.jpg', pilmode='RGB') + rep_img = imageio.imread('./sample_data/imgs/Meisje_met_de_parel.jpg', pilmode='RGB') + rep_mask = np.ones(rep_img.shape[:2]) + lu_corner = [932, 1025] + ru_corner = [2469, 901] + lb_corner = [908, 2927] + rb_corner = [2436, 3080] + queries = np.array([lu_corner, ru_corner, lb_corner, rb_corner]).astype(np.float32) + rep_coord = np.array([[0, 0], [rep_img.shape[1], 0], [0, rep_img.shape[0]], [rep_img.shape[1], rep_img.shape[0]]]).astype(np.float32) + + engine = SparseEngine(model, 32, mode='stretching') + corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, queries_a=queries, force=True) + + T = cv2.getPerspectiveTransform(rep_coord, corrs[:, 2:].astype(np.float32)) + vmask = cv2.warpPerspective(rep_mask, T, (img_b.shape[1], img_b.shape[0])) > 0 + warped = cv2.warpPerspective(rep_img, T, (img_b.shape[1], img_b.shape[0])) + out = warped * vmask[..., None] + img_b * (~vmask[..., None]) + + f, axarr = plt.subplots(1, 4) + axarr[0].imshow(rep_img) + axarr[0].title.set_text('Virtual Paint') + axarr[0].axis('off') + axarr[1].imshow(img_a) + axarr[1].title.set_text('Annotated Frame') + axarr[1].axis('off') + axarr[2].imshow(img_b) + axarr[2].title.set_text('Target Frame') + axarr[2].axis('off') + axarr[3].imshow(out) + axarr[3].title.set_text('Overlay') + axarr[3].axis('off') + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + set_COTR_arguments(parser) + parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory') + parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id') + + opt = parser.parse_args() + opt.command = ' '.join(sys.argv) + + layer_2_channels = {'layer1': 256, + 'layer2': 512, + 'layer3': 1024, + 'layer4': 2048, } + opt.dim_feedforward = layer_2_channels[opt.layer] + if opt.load_weights: + opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar') + print_opt(opt) + main(opt) diff --git a/imcui/third_party/COTR/demo_reconstruction.py b/imcui/third_party/COTR/demo_reconstruction.py new file mode 100644 index 0000000000000000000000000000000000000000..06262668446595b3dc368e416ad5373dc4b26873 --- /dev/null +++ b/imcui/third_party/COTR/demo_reconstruction.py @@ -0,0 +1,92 @@ +''' +COTR two view reconstruction with known extrinsic/intrinsic demo +''' +import argparse +import os +import time + +import numpy as np +import torch +import imageio +import open3d as o3d + +from COTR.utils import utils, debug_utils +from COTR.models import build_model +from COTR.options.options import * +from COTR.options.options_utils import * +from COTR.inference.sparse_engine import SparseEngine, FasterSparseEngine +from COTR.projector import pcd_projector + +utils.fix_randomness(0) +torch.set_grad_enabled(False) + + +def triangulate_rays_to_pcd(center_a, dir_a, center_b, dir_b): 
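+    # Editor's note: for each row this is the classic closest-point problem for
+    # two rays P1(t) = A + t*a and P2(s) = B + s*b (a and b are normalized below).
+    # With c = B - A, minimizing |P1(t) - P2(s)|^2 over t gives
+    #     t = ((a.c)(b.b) - (b.c)(a.b)) / ((a.a)(b.b) - (a.b)^2)
+    # and the returned D = A + t*a is the point on ray A closest to ray B, used as
+    # the triangulated 3D point. For example, rays from (0,0,0) along +x and from
+    # (1,1,0) along +y give t = 1 and D = (1,0,0).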
+ A = center_a + a = dir_a / np.linalg.norm(dir_a, axis=1, keepdims=True) + B = center_b + b = dir_b / np.linalg.norm(dir_b, axis=1, keepdims=True) + c = B - A + D = A + a * ((-np.sum(a * b, axis=1) * np.sum(b * c, axis=1) + np.sum(a * c, axis=1) * np.sum(b * b, axis=1)) / (np.sum(a * a, axis=1) * np.sum(b * b, axis=1) - np.sum(a * b, axis=1) * np.sum(a * b, axis=1)))[..., None] + return D + + +def main(opt): + model = build_model(opt) + model = model.cuda() + weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict'] + utils.safe_load_weights(model, weights) + model = model.eval() + + img_a = imageio.imread('./sample_data/imgs/img_0.jpg', pilmode='RGB') + img_b = imageio.imread('./sample_data/imgs/img_1.jpg', pilmode='RGB') + + if opt.faster_infer: + engine = FasterSparseEngine(model, 32, mode='tile') + else: + engine = SparseEngine(model, 32, mode='tile') + t0 = time.time() + corrs = engine.cotr_corr_multiscale_with_cycle_consistency(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=opt.max_corrs, queries_a=None) + t1 = time.time() + print(f'spent {t1-t0} seconds for {opt.max_corrs} correspondences.') + + camera_a = np.load('./sample_data/camera_0.npy', allow_pickle=True).item() + camera_b = np.load('./sample_data/camera_1.npy', allow_pickle=True).item() + center_a = camera_a['cam_center'] + center_b = camera_b['cam_center'] + rays_a = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(corrs[:, :2], np.ones([corrs.shape[0], 1]) * 2, camera_a['intrinsic'], motion=camera_a['c2w']) + rays_b = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(corrs[:, 2:], np.ones([corrs.shape[0], 1]) * 2, camera_b['intrinsic'], motion=camera_b['c2w']) + dir_a = rays_a - center_a + dir_b = rays_b - center_b + center_a = np.array([center_a] * corrs.shape[0]) + center_b = np.array([center_b] * corrs.shape[0]) + points = triangulate_rays_to_pcd(center_a, dir_a, center_b, dir_b) + colors = (img_a[tuple(np.floor(corrs[:, :2]).astype(int)[:, ::-1].T)] / 255 + img_b[tuple(np.floor(corrs[:, 2:]).astype(int)[:, ::-1].T)] / 255) / 2 + colors = np.array(colors) + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + pcd.colors = o3d.utility.Vector3dVector(colors) + o3d.visualization.draw_geometries([pcd]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + set_COTR_arguments(parser) + parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory') + parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id') + parser.add_argument('--max_corrs', type=int, default=2048, help='number of correspondences') + parser.add_argument('--faster_infer', type=str2bool, default=False, help='use fatser inference') + + opt = parser.parse_args() + opt.command = ' '.join(sys.argv) + + layer_2_channels = {'layer1': 256, + 'layer2': 512, + 'layer3': 1024, + 'layer4': 2048, } + opt.dim_feedforward = layer_2_channels[opt.layer] + if opt.load_weights: + opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar') + print_opt(opt) + main(opt) diff --git a/imcui/third_party/COTR/demo_single_pair.py b/imcui/third_party/COTR/demo_single_pair.py new file mode 100644 index 0000000000000000000000000000000000000000..babc99589542f342e3225f2c345e3aa05535a2f1 --- /dev/null +++ b/imcui/third_party/COTR/demo_single_pair.py @@ -0,0 +1,66 @@ +''' +COTR demo for a single image pair +''' +import argparse +import os +import time + 
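+# Editor's note: this demo estimates cycle-consistent COTR correspondences for the
+# image pair, visualizes them, triangulates them into a dense flow field with
+# triangulate_corr, and warps image B onto image A via cv2.remap for a blended overlay.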
+import cv2 +import numpy as np +import torch +import imageio +import matplotlib.pyplot as plt + +from COTR.utils import utils, debug_utils +from COTR.models import build_model +from COTR.options.options import * +from COTR.options.options_utils import * +from COTR.inference.inference_helper import triangulate_corr +from COTR.inference.sparse_engine import SparseEngine + +utils.fix_randomness(0) +torch.set_grad_enabled(False) + + +def main(opt): + model = build_model(opt) + model = model.cuda() + weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict'] + utils.safe_load_weights(model, weights) + model = model.eval() + + img_a = imageio.imread('./sample_data/imgs/cathedral_1.jpg', pilmode='RGB') + img_b = imageio.imread('./sample_data/imgs/cathedral_2.jpg', pilmode='RGB') + + engine = SparseEngine(model, 32, mode='tile') + t0 = time.time() + corrs = engine.cotr_corr_multiscale_with_cycle_consistency(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=opt.max_corrs, queries_a=None) + t1 = time.time() + + utils.visualize_corrs(img_a, img_b, corrs) + print(f'spent {t1-t0} seconds for {opt.max_corrs} correspondences.') + dense = triangulate_corr(corrs, img_a.shape, img_b.shape) + warped = cv2.remap(img_b, dense[..., 0].astype(np.float32), dense[..., 1].astype(np.float32), interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT) + plt.imshow(warped / 255 * 0.5 + img_a / 255 * 0.5) + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + set_COTR_arguments(parser) + parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory') + parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id') + parser.add_argument('--max_corrs', type=int, default=100, help='number of correspondences') + + opt = parser.parse_args() + opt.command = ' '.join(sys.argv) + + layer_2_channels = {'layer1': 256, + 'layer2': 512, + 'layer3': 1024, + 'layer4': 2048, } + opt.dim_feedforward = layer_2_channels[opt.layer] + if opt.load_weights: + opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar') + print_opt(opt) + main(opt) diff --git a/imcui/third_party/COTR/demo_wbs.py b/imcui/third_party/COTR/demo_wbs.py new file mode 100644 index 0000000000000000000000000000000000000000..1592ca506cdc0fa06c0c5290b741a3f60f6023ae --- /dev/null +++ b/imcui/third_party/COTR/demo_wbs.py @@ -0,0 +1,71 @@ +''' +Manually passing scale to COTR, skip the scale difference estimation. 
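+The scale is supplied through the areas argument of cotr_corr_multiscale
+(img_a_area and img_b_area below); with both set to 1.0 the two images are treated
+as having the same scale, so the engine's own scale-difference estimation is not
+needed (editor's note, inferred from the code below).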
+''' +import argparse +import os +import time + +import cv2 +import numpy as np +import torch +import imageio +from scipy.spatial import distance_matrix +import matplotlib.pyplot as plt + +from COTR.utils import utils, debug_utils +from COTR.models import build_model +from COTR.options.options import * +from COTR.options.options_utils import * +from COTR.inference.sparse_engine import SparseEngine + +utils.fix_randomness(0) +torch.set_grad_enabled(False) + + +def main(opt): + model = build_model(opt) + model = model.cuda() + weights = torch.load(opt.load_weights_path)['model_state_dict'] + utils.safe_load_weights(model, weights) + model = model.eval() + + img_a = imageio.imread('./sample_data/imgs/petrzin_01.png') + img_b = imageio.imread('./sample_data/imgs/petrzin_02.png') + img_a_area = 1.0 + img_b_area = 1.0 + gt_corrs = np.loadtxt('./sample_data/petrzin_pts.txt') + kp_a = gt_corrs[:, :2] + kp_b = gt_corrs[:, 2:] + + engine = SparseEngine(model, 32, mode='tile') + t0 = time.time() + corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.75, 0.1, 4), 1, max_corrs=kp_a.shape[0], queries_a=kp_a, force=True, areas=[img_a_area, img_b_area]) + t1 = time.time() + print(f'COTR spent {t1-t0} seconds.') + + utils.visualize_corrs(img_a, img_b, corrs) + plt.imshow(img_b) + plt.scatter(kp_b[:,0], kp_b[:,1]) + plt.scatter(corrs[:,2], corrs[:,3]) + plt.plot(np.stack([kp_b[:,0], corrs[:,2]], axis=1).T, np.stack([kp_b[:,1], corrs[:,3]], axis=1).T, color=[1,0,0]) + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + set_COTR_arguments(parser) + parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory') + parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id') + + opt = parser.parse_args() + opt.command = ' '.join(sys.argv) + + layer_2_channels = {'layer1': 256, + 'layer2': 512, + 'layer3': 1024, + 'layer4': 2048, } + opt.dim_feedforward = layer_2_channels[opt.layer] + if opt.load_weights: + opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar') + print_opt(opt) + main(opt) diff --git a/imcui/third_party/COTR/environment.yml b/imcui/third_party/COTR/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..bc9ffc910bf9809b616504348d9b369043dee5cb --- /dev/null +++ b/imcui/third_party/COTR/environment.yml @@ -0,0 +1,104 @@ +name: cotr_env +channels: + - pytorch + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - backcall=0.2.0=pyhd3eb1b0_0 + - blas=1.0=mkl + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2021.4.13=h06a4308_1 + - cairo=1.16.0=hf32fb01_1 + - certifi=2020.12.5=py37h06a4308_0 + - cudatoolkit=10.2.89=hfd86e86_1 + - cycler=0.10.0=py37_0 + - dbus=1.13.18=hb2f20db_0 + - decorator=5.0.6=pyhd3eb1b0_0 + - expat=2.3.0=h2531618_2 + - ffmpeg=4.0=hcdf2ecd_0 + - fontconfig=2.13.1=h6c09931_0 + - freeglut=3.0.0=hf484d3e_5 + - freetype=2.10.4=h5ab3b9f_0 + - glib=2.68.1=h36276a3_0 + - graphite2=1.3.14=h23475e2_0 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - harfbuzz=1.8.8=hffaf4a1_0 + - hdf5=1.10.2=hba1933b_1 + - icu=58.2=he6710b0_3 + - imageio=2.9.0=pyhd3eb1b0_0 + - intel-openmp=2021.2.0=h06a4308_610 + - ipython=7.22.0=py37hb070fc8_0 + - ipython_genutils=0.2.0=pyhd3eb1b0_1 + - jasper=2.0.14=h07fcdf6_1 + - jedi=0.17.0=py37_0 + - jpeg=9b=h024ee3a_2 + - kiwisolver=1.3.1=py37h2531618_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.33.1=h53a641e_7 + - 
libffi=3.3=he6710b0_2 + - libgcc-ng=9.1.0=hdf63c60_0 + - libgfortran-ng=7.3.0=hdf63c60_0 + - libglu=9.0.0=hf484d3e_1 + - libopencv=3.4.2=hb342d67_1 + - libopus=1.3.1=h7b6447c_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.1.0=h2733197_1 + - libuuid=1.0.3=h1bed415_2 + - libuv=1.40.0=h7b6447c_0 + - libvpx=1.7.0=h439df22_0 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.10=hb55368b_3 + - lz4-c=1.9.3=h2531618_0 + - matplotlib=3.3.4=py37h06a4308_0 + - matplotlib-base=3.3.4=py37h62a2d02_0 + - mkl=2020.2=256 + - mkl-service=2.3.0=py37he8ac12f_0 + - mkl_fft=1.3.0=py37h54f3939_0 + - mkl_random=1.1.1=py37h0573a6f_0 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.2=hff7bd54_1 + - numpy=1.19.2=py37h54aff64_0 + - numpy-base=1.19.2=py37hfa32c7d_0 + - olefile=0.46=py37_0 + - opencv=3.4.2=py37h6fd60c2_1 + - openssl=1.1.1k=h27cfd23_0 + - parso=0.8.2=pyhd3eb1b0_0 + - pcre=8.44=he6710b0_0 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pillow=8.2.0=py37he98fc37_0 + - pip=21.0.1=py37h06a4308_0 + - pixman=0.40.0=h7b6447c_0 + - prompt-toolkit=3.0.17=pyh06a4308_0 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - py-opencv=3.4.2=py37hb342d67_1 + - pygments=2.8.1=pyhd3eb1b0_0 + - pyparsing=2.4.7=pyhd3eb1b0_0 + - pyqt=5.9.2=py37h05f1152_2 + - python=3.7.10=hdb3f193_0 + - python-dateutil=2.8.1=pyhd3eb1b0_0 + - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0 + - qt=5.9.7=h5867ecd_1 + - readline=8.1=h27cfd23_0 + - scipy=1.2.1=py37h7c811a0_0 + - setuptools=52.0.0=py37h06a4308_0 + - sip=4.19.8=py37hf484d3e_0 + - six=1.15.0=py37h06a4308_0 + - sqlite=3.35.4=hdfb4753_0 + - tk=8.6.10=hbc83047_0 + - torchaudio=0.7.2=py37 + - torchvision=0.8.2=py37_cu102 + - tornado=6.1=py37h27cfd23_0 + - tqdm=4.59.0=pyhd3eb1b0_1 + - traitlets=5.0.5=pyhd3eb1b0_0 + - typing_extensions=3.7.4.3=pyha847dfd_0 + - vispy=0.5.3=py37hee6b756_0 + - wcwidth=0.2.5=py_0 + - wheel=0.36.2=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 + - pip: + - tables==3.6.1 diff --git a/imcui/third_party/COTR/scripts/prepare_megadepth_split.py b/imcui/third_party/COTR/scripts/prepare_megadepth_split.py new file mode 100644 index 0000000000000000000000000000000000000000..9da905d157cc88a167d91664689eccb8286107ee --- /dev/null +++ b/imcui/third_party/COTR/scripts/prepare_megadepth_split.py @@ -0,0 +1,36 @@ +import os +import random +import json + + +# read the json +valid_list_json_path = './megadepth_valid_list.json' +assert os.path.isfile(valid_list_json_path), 'Change to the valid list json' +with open(valid_list_json_path, 'r') as f: + all_list = json.load(f) + +# build scene - image dictionary +scene_img_dict = {} +for item in all_list: + if not item[:4] in scene_img_dict: + scene_img_dict[item[:4]] = [] + scene_img_dict[item[:4]].append(item) + +train_split = [] +val_split = [] +test_split = [] +for k in sorted(scene_img_dict.keys()): + if int(k) == 204: + val_split += scene_img_dict[k] + elif int(k) <= 240 and int(k) != 204: + train_split += scene_img_dict[k] + else: + test_split += scene_img_dict[k] + +# save split to json +with open('megadepth_train.json', 'w') as outfile: + json.dump(sorted(train_split), outfile, indent=4) +with open('megadepth_val.json', 'w') as outfile: + json.dump(sorted(val_split), outfile, indent=4) +with open('megadepth_test.json', 'w') as outfile: + json.dump(sorted(test_split), outfile, indent=4) diff --git a/imcui/third_party/COTR/scripts/prepare_megadepth_valid_list.py b/imcui/third_party/COTR/scripts/prepare_megadepth_valid_list.py new file mode 100644 index 
0000000000000000000000000000000000000000..e08cf1fa75ecba25c46ee749d76d1356f7838dc9 --- /dev/null +++ b/imcui/third_party/COTR/scripts/prepare_megadepth_valid_list.py @@ -0,0 +1,41 @@ +import os +import json + +import tables +from tqdm import tqdm +import numpy as np + + +def read_all_imgs(base_dir): + all_imgs = [] + for cur, dirs, files in os.walk(base_dir): + if 'imgs' in cur: + all_imgs += [os.path.join(cur, f) for f in files] + all_imgs.sort() + return all_imgs + + +def filter_semantic_depth(imgs): + valid_imgs = [] + for item in tqdm(imgs): + f_name = os.path.splitext(os.path.basename(item))[0] + '.h5' + depth_dir = os.path.abspath(os.path.join(os.path.dirname(item), '../depths')) + depth_path = os.path.join(depth_dir, f_name) + depth_h5 = tables.open_file(depth_path, mode='r') + _depth = np.array(depth_h5.root.depth) + if _depth.min() >= 0: + prefix = os.path.abspath(os.path.join(item, '../../../../')) + '/' + rel_image_path = item.replace(prefix, '') + valid_imgs.append(rel_image_path) + depth_h5.close() + valid_imgs.sort() + return valid_imgs + + +if __name__ == "__main__": + MegaDepth_v1 = '/media/jiangwei/data_ssd/MegaDepth_v1/' + assert os.path.isdir(MegaDepth_v1), 'Change to your local path' + all_imgs = read_all_imgs(MegaDepth_v1) + valid_imgs = filter_semantic_depth(all_imgs) + with open('megadepth_valid_list.json', 'w') as outfile: + json.dump(valid_imgs, outfile, indent=4) diff --git a/imcui/third_party/COTR/scripts/prepare_nn_distance_mat.py b/imcui/third_party/COTR/scripts/prepare_nn_distance_mat.py new file mode 100644 index 0000000000000000000000000000000000000000..507eb3161d7c6c638dfaf99b560fa06c531cbec2 --- /dev/null +++ b/imcui/third_party/COTR/scripts/prepare_nn_distance_mat.py @@ -0,0 +1,145 @@ +''' +compute distance matrix for megadepth using ComputeCanada +''' + +import sys +sys.path.append('..') + + +import os +import argparse +import numpy as np +from joblib import Parallel, delayed +from tqdm import tqdm + +from COTR.options.options import * +from COTR.options.options_utils import * +from COTR.utils import debug_utils, utils, constants +from COTR.datasets import colmap_helper +from COTR.projector import pcd_projector +from COTR.global_configs import dataset_config + + +assert colmap_helper.COVISIBILITY_CHECK, 'Please enable COVISIBILITY_CHECK' +assert colmap_helper.LOAD_PCD, 'Please enable LOAD_PCD' + +OFFSET_THRESHOLD = 1.0 + + +def get_index_pairs(dist_mat, cells): + pairs = [] + for row in range(dist_mat.shape[0]): + for col in range(dist_mat.shape[0]): + if dist_mat[row][col] == -1: + pairs.append([row, col]) + if len(pairs) == cells: + return pairs + return pairs + + +def load_dist_mat(path, size=None): + if os.path.isfile(path): + dist_mat = np.load(path) + assert dist_mat.shape[0] == dist_mat.shape[1] + else: + dist_mat = np.ones([size, size], dtype=np.float32) * -1 + assert dist_mat.shape[0] == dist_mat.shape[1] + return dist_mat + + +def distance_between_two_caps(caps): + cap_1, cap_2 = caps + try: + if len(np.intersect1d(cap_1.point3d_id, cap_2.point3d_id)) == 0: + return 0.0 + pcd = cap_2.point_cloud_world + extrin_cap_1 = cap_1.cam_pose.world_to_camera[0:3, :] + intrin_cap_1 = cap_1.pinhole_cam.intrinsic_mat + size = cap_1.pinhole_cam.shape[:2] + reproj = pcd_projector.PointCloudProjector.pcd_3d_to_pcd_2d_np(pcd[:, 0:3], intrin_cap_1, extrin_cap_1, size, keep_z=True, crop=True, filter_neg=True, norm_coord=False) + reproj = pcd_projector.PointCloudProjector.pcd_2d_to_img_2d_np(reproj, size)[..., 0] + # 1. 
calculate the iou + query_mask = cap_1.depth_map > 0 + reproj_mask = reproj > 0 + intersection_mask = query_mask * reproj_mask + union_mask = query_mask | reproj_mask + if union_mask.sum() == 0: + return 0.0 + intersection_mask = (abs(cap_1.depth_map - reproj) * intersection_mask < OFFSET_THRESHOLD) * intersection_mask + ratio = intersection_mask.sum() / union_mask.sum() + if ratio == 0.0: + return 0.0 + return ratio + except Exception as e: + print(e) + return 0.0 + + +def fill_covisibility(scene, dist_mat): + for i in range(dist_mat.shape[0]): + nns = scene.get_covisible_caps(scene[i]) + covis_list = [scene.img_id_to_index_dict[cap.image_id] for cap in nns] + for j in range(dist_mat.shape[0]): + if j not in covis_list: + dist_mat[i][j] = 0 + return dist_mat + + +def main(opt): + # fast fail + try: + dist_mat = load_dist_mat(opt.out_path) + if dist_mat.min() >= 0.0: + print(f'{opt.out_path} is complete!') + exit() + else: + print('continue working') + except Exception as e: + print(e) + print('first time start working') + scene_dir = opt.scenes_name_list[0]['scene_dir'] + image_dir = opt.scenes_name_list[0]['image_dir'] + depth_dir = opt.scenes_name_list[0]['depth_dir'] + scene = colmap_helper.ColmapWithDepthAsciiReader.read_sfm_scene_given_valid_list_path(scene_dir, image_dir, depth_dir, dataset_config[opt.dataset_name]['valid_list_json'], opt.crop_cam) + size = len(scene.captures) + dist_mat = load_dist_mat(opt.out_path, size) + if opt.use_ram: + scene.read_data_to_ram(['depth']) + if dist_mat.max() == -1: + dist_mat = fill_covisibility(scene, dist_mat) + np.save(opt.out_path, dist_mat) + pairs = get_index_pairs(dist_mat, opt.cells) + in_pairs = [[scene[p[0]], scene[p[1]]] for p in pairs] + results = Parallel(n_jobs=opt.num_cpus)(delayed(distance_between_two_caps)(pair) for pair in tqdm(in_pairs, desc='calculating distance matrix', total=len(in_pairs))) + for i, p in enumerate(pairs): + r, c = p + dist_mat[r][c] = results[i] + np.save(opt.out_path, dist_mat) + print(f'finished from {pairs[0][0]}-{pairs[0][1]} -> {pairs[-1][0]}-{pairs[-1][1]}') + print(f'in total {len(pairs)} cells') + print(f'progress {(dist_mat >= 0).sum() / dist_mat.size}') + print(f'save at {opt.out_path}') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + set_general_arguments(parser) + parser.add_argument('--dataset_name', type=str, default='megadepth', help='dataset name') + parser.add_argument('--use_ram', type=str2bool, default=False, help='load image/depth to ram') + parser.add_argument('--info_level', type=str, default='rgbd', help='the information level of dataset') + parser.add_argument('--scene', type=str, default='0000', required=True, help='what scene want to use') + parser.add_argument('--seq', type=str, default='0', required=True, help='what seq want to use') + parser.add_argument('--crop_cam', choices=['no_crop', 'crop_center', 'crop_center_and_resize'], type=str, default='no_crop', help='crop the center of image to avoid changing aspect ratio, resize to make the operations batch-able.') + parser.add_argument('--cells', type=int, default=10000, help='the number of cells to be computed in this run') + parser.add_argument('--num_cpus', type=int, default=6, help='num of cores') + + opt = parser.parse_args() + opt.scenes_name_list = options_utils.build_scenes_name_list_from_opt(opt) + opt.out_dir = os.path.join(os.path.dirname(opt.scenes_name_list[0]['depth_dir']), 'dist_mat') + opt.out_path = os.path.join(opt.out_dir, 'dist_mat.npy') + os.makedirs(opt.out_dir, exist_ok=True) + if 
opt.confirm: + confirm_opt(opt) + else: + print_opt(opt) + main(opt) diff --git a/imcui/third_party/COTR/scripts/rectify_megadepth.py b/imcui/third_party/COTR/scripts/rectify_megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..257dee530d2d8fac96c994b07e9f36cba984fedb --- /dev/null +++ b/imcui/third_party/COTR/scripts/rectify_megadepth.py @@ -0,0 +1,299 @@ +''' +rectify the SfM model from SIMPLE_RADIAL to PINHOLE +''' +import os + +command_1 = 'colmap image_undistorter --image_path={0} --input_path={1} --output_path={2}' +command_2 = 'colmap model_converter --input_path={0} --output_path={1} --output_type=TXT' +command_3 = 'mv {0} {1}' +command_4 = 'python sort_images_txt.py --reference={0} --unordered={1} --save_to={2}' + +MegaDepth_v1_SfM = '/media/jiangwei/data_ssd/MegaDepth_v1_SfM/' +assert os.path.isdir(MegaDepth_v1_SfM), 'Change to your local path' +all_scenes = [ + '0000/sparse/manhattan/0', + '0000/sparse/manhattan/1', + '0001/sparse/manhattan/0', + '0002/sparse/manhattan/0', + '0003/sparse/manhattan/0', + '0004/sparse/manhattan/0', + '0004/sparse/manhattan/1', + '0004/sparse/manhattan/2', + '0005/sparse/manhattan/0', + '0005/sparse/manhattan/1', + '0007/sparse/manhattan/0', + '0007/sparse/manhattan/1', + '0008/sparse/manhattan/0', + '0011/sparse/manhattan/0', + '0012/sparse/manhattan/0', + '0013/sparse/manhattan/0', + '0015/sparse/manhattan/0', + '0015/sparse/manhattan/1', + '0016/sparse/manhattan/0', + '0017/sparse/manhattan/0', + '0019/sparse/manhattan/0', + '0019/sparse/manhattan/1', + '0020/sparse/manhattan/0', + '0020/sparse/manhattan/1', + '0021/sparse/manhattan/0', + '0022/sparse/manhattan/0', + '0023/sparse/manhattan/0', + '0023/sparse/manhattan/1', + '0024/sparse/manhattan/0', + '0025/sparse/manhattan/0', + '0025/sparse/manhattan/1', + '0026/sparse/manhattan/0', + '0027/sparse/manhattan/0', + '0032/sparse/manhattan/0', + '0032/sparse/manhattan/1', + '0033/sparse/manhattan/0', + '0034/sparse/manhattan/0', + '0035/sparse/manhattan/0', + '0036/sparse/manhattan/0', + '0037/sparse/manhattan/0', + '0039/sparse/manhattan/0', + '0041/sparse/manhattan/0', + '0041/sparse/manhattan/1', + '0042/sparse/manhattan/0', + '0043/sparse/manhattan/0', + '0044/sparse/manhattan/0', + '0046/sparse/manhattan/0', + '0046/sparse/manhattan/1', + '0046/sparse/manhattan/2', + '0047/sparse/manhattan/0', + '0048/sparse/manhattan/0', + '0049/sparse/manhattan/0', + '0050/sparse/manhattan/0', + '0056/sparse/manhattan/0', + '0057/sparse/manhattan/0', + '0058/sparse/manhattan/0', + '0058/sparse/manhattan/1', + '0060/sparse/manhattan/0', + '0061/sparse/manhattan/0', + '0062/sparse/manhattan/0', + '0062/sparse/manhattan/1', + '0063/sparse/manhattan/0', + '0063/sparse/manhattan/1', + '0063/sparse/manhattan/2', + '0063/sparse/manhattan/3', + '0064/sparse/manhattan/0', + '0065/sparse/manhattan/0', + '0067/sparse/manhattan/0', + '0070/sparse/manhattan/0', + '0071/sparse/manhattan/0', + '0071/sparse/manhattan/1', + '0076/sparse/manhattan/0', + '0078/sparse/manhattan/0', + '0080/sparse/manhattan/0', + '0083/sparse/manhattan/0', + '0086/sparse/manhattan/0', + '0087/sparse/manhattan/0', + '0087/sparse/manhattan/1', + '0090/sparse/manhattan/0', + '0092/sparse/manhattan/0', + '0092/sparse/manhattan/1', + '0094/sparse/manhattan/0', + '0095/sparse/manhattan/0', + '0095/sparse/manhattan/1', + '0095/sparse/manhattan/2', + '0098/sparse/manhattan/0', + '0099/sparse/manhattan/0', + '0100/sparse/manhattan/0', + '0101/sparse/manhattan/0', + '0102/sparse/manhattan/0', + '0103/sparse/manhattan/0', + 
'0104/sparse/manhattan/0', + '0104/sparse/manhattan/1', + '0105/sparse/manhattan/0', + '0107/sparse/manhattan/0', + '0115/sparse/manhattan/0', + '0117/sparse/manhattan/0', + '0117/sparse/manhattan/1', + '0117/sparse/manhattan/2', + '0121/sparse/manhattan/0', + '0121/sparse/manhattan/1', + '0122/sparse/manhattan/0', + '0129/sparse/manhattan/0', + '0130/sparse/manhattan/0', + '0130/sparse/manhattan/1', + '0130/sparse/manhattan/2', + '0133/sparse/manhattan/0', + '0133/sparse/manhattan/1', + '0137/sparse/manhattan/0', + '0137/sparse/manhattan/1', + '0137/sparse/manhattan/2', + '0141/sparse/manhattan/0', + '0143/sparse/manhattan/0', + '0147/sparse/manhattan/0', + '0147/sparse/manhattan/1', + '0148/sparse/manhattan/0', + '0148/sparse/manhattan/1', + '0149/sparse/manhattan/0', + '0150/sparse/manhattan/0', + '0151/sparse/manhattan/0', + '0156/sparse/manhattan/0', + '0160/sparse/manhattan/0', + '0160/sparse/manhattan/1', + '0160/sparse/manhattan/2', + '0162/sparse/manhattan/0', + '0162/sparse/manhattan/1', + '0168/sparse/manhattan/0', + '0175/sparse/manhattan/0', + '0176/sparse/manhattan/0', + '0176/sparse/manhattan/1', + '0176/sparse/manhattan/2', + '0177/sparse/manhattan/0', + '0178/sparse/manhattan/0', + '0178/sparse/manhattan/1', + '0181/sparse/manhattan/0', + '0183/sparse/manhattan/0', + '0185/sparse/manhattan/0', + '0186/sparse/manhattan/0', + '0189/sparse/manhattan/0', + '0190/sparse/manhattan/0', + '0197/sparse/manhattan/0', + '0200/sparse/manhattan/0', + '0200/sparse/manhattan/1', + '0204/sparse/manhattan/0', + '0204/sparse/manhattan/1', + '0205/sparse/manhattan/0', + '0205/sparse/manhattan/1', + '0209/sparse/manhattan/1', + '0212/sparse/manhattan/0', + '0212/sparse/manhattan/1', + '0214/sparse/manhattan/0', + '0214/sparse/manhattan/1', + '0217/sparse/manhattan/0', + '0223/sparse/manhattan/0', + '0223/sparse/manhattan/1', + '0223/sparse/manhattan/2', + '0224/sparse/manhattan/0', + '0224/sparse/manhattan/1', + '0229/sparse/manhattan/0', + '0231/sparse/manhattan/0', + '0235/sparse/manhattan/0', + '0237/sparse/manhattan/0', + '0238/sparse/manhattan/0', + '0240/sparse/manhattan/0', + '0243/sparse/manhattan/0', + '0252/sparse/manhattan/0', + '0257/sparse/manhattan/0', + '0258/sparse/manhattan/0', + '0265/sparse/manhattan/0', + '0265/sparse/manhattan/1', + '0269/sparse/manhattan/0', + '0269/sparse/manhattan/1', + '0269/sparse/manhattan/2', + '0271/sparse/manhattan/0', + '0275/sparse/manhattan/0', + '0277/sparse/manhattan/0', + '0277/sparse/manhattan/1', + '0281/sparse/manhattan/0', + '0285/sparse/manhattan/0', + '0286/sparse/manhattan/0', + '0286/sparse/manhattan/1', + '0290/sparse/manhattan/0', + '0290/sparse/manhattan/1', + '0294/sparse/manhattan/0', + '0299/sparse/manhattan/0', + '0303/sparse/manhattan/0', + '0306/sparse/manhattan/0', + '0307/sparse/manhattan/0', + '0312/sparse/manhattan/0', + '0312/sparse/manhattan/1', + '0323/sparse/manhattan/0', + '0326/sparse/manhattan/0', + '0327/sparse/manhattan/0', + '0327/sparse/manhattan/1', + '0327/sparse/manhattan/2', + '0331/sparse/manhattan/0', + '0335/sparse/manhattan/0', + '0335/sparse/manhattan/1', + '0341/sparse/manhattan/0', + '0341/sparse/manhattan/1', + '0348/sparse/manhattan/0', + '0349/sparse/manhattan/0', + '0349/sparse/manhattan/1', + '0360/sparse/manhattan/0', + '0360/sparse/manhattan/1', + '0360/sparse/manhattan/2', + '0366/sparse/manhattan/0', + '0377/sparse/manhattan/0', + '0380/sparse/manhattan/0', + '0387/sparse/manhattan/0', + '0389/sparse/manhattan/0', + '0389/sparse/manhattan/1', + '0394/sparse/manhattan/0', + 
'0394/sparse/manhattan/1', + '0402/sparse/manhattan/0', + '0402/sparse/manhattan/1', + '0406/sparse/manhattan/0', + '0407/sparse/manhattan/0', + '0411/sparse/manhattan/0', + '0411/sparse/manhattan/1', + '0412/sparse/manhattan/0', + '0412/sparse/manhattan/1', + '0412/sparse/manhattan/2', + '0430/sparse/manhattan/0', + '0430/sparse/manhattan/1', + '0430/sparse/manhattan/2', + '0443/sparse/manhattan/0', + '0446/sparse/manhattan/0', + '0455/sparse/manhattan/0', + '0472/sparse/manhattan/0', + '0472/sparse/manhattan/1', + '0474/sparse/manhattan/0', + '0474/sparse/manhattan/1', + '0474/sparse/manhattan/2', + '0476/sparse/manhattan/0', + '0476/sparse/manhattan/1', + '0476/sparse/manhattan/2', + '0478/sparse/manhattan/0', + '0478/sparse/manhattan/1', + '0482/sparse/manhattan/0', + '0493/sparse/manhattan/0', + '0493/sparse/manhattan/1', + '0494/sparse/manhattan/1', + '0496/sparse/manhattan/0', + '0505/sparse/manhattan/0', + '0559/sparse/manhattan/0', + '0733/sparse/manhattan/0', + '0733/sparse/manhattan/1', + '0768/sparse/manhattan/0', + '0860/sparse/manhattan/0', + '0860/sparse/manhattan/1', + '1001/sparse/manhattan/0', + '1017/sparse/manhattan/0', + '1589/sparse/manhattan/0', + '3346/sparse/manhattan/0', + '4541/sparse/manhattan/0', + '5000/sparse/manhattan/0', + '5001/sparse/manhattan/0', + '5002/sparse/manhattan/0', + '5003/sparse/manhattan/0', + '5004/sparse/manhattan/0', + '5005/sparse/manhattan/0', + '5006/sparse/manhattan/0', + '5007/sparse/manhattan/0', + '5008/sparse/manhattan/0', + '5009/sparse/manhattan/0', + '5010/sparse/manhattan/0', + '5011/sparse/manhattan/0', + '5012/sparse/manhattan/0', + '5013/sparse/manhattan/0', + '5014/sparse/manhattan/0', + '5015/sparse/manhattan/0', + '5016/sparse/manhattan/0', + '5017/sparse/manhattan/0', + '5018/sparse/manhattan/0', +] + +with open('rectify.sh', "w") as fid: + for s in all_scenes: + s = os.path.join(MegaDepth_v1_SfM, s) + new_dir = s + '_rectified' + img_dir = s[:s.find('sparse')] + 'images' + fid.write(command_1.format(img_dir, s, new_dir) + '\n') + fid.write(command_2.format(new_dir + '/sparse', new_dir + '/sparse') + '\n') + fid.write(command_3.format(new_dir + '/sparse/images.txt', new_dir + '/sparse/unorder_images.txt') + '\n') + fid.write(command_4.format(s + '/images.txt', new_dir + '/sparse/unorder_images.txt', new_dir + '/sparse/images.txt') + '\n') diff --git a/imcui/third_party/COTR/scripts/sort_images_txt.py b/imcui/third_party/COTR/scripts/sort_images_txt.py new file mode 100644 index 0000000000000000000000000000000000000000..aa52bddc294a62610aa657b01a8ec366ce9f3267 --- /dev/null +++ b/imcui/third_party/COTR/scripts/sort_images_txt.py @@ -0,0 +1,78 @@ +import sys +assert sys.version_info >= (3, 7), 'ordered dict is required' +import os +import argparse +import re + +from tqdm import tqdm + + +def read_images_meta(images_txt_path): + images_meta = {} + with open(images_txt_path, "r") as fid: + line = fid.readline() + assert line == '# Image list with two lines of data per image:\n' + line = fid.readline() + assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n' + line = fid.readline() + assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n' + line = fid.readline() + assert re.search('^# Number of images: \d+, mean observations per image: [-+]?\d*\.\d+|\d+\n$', line) + num_images, mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line) + num_images = int(num_images) + mean_ob_per_img = float(mean_ob_per_img) + + for _ in tqdm(range(num_images), desc='reading images meta'): + l = fid.readline() + elems = 
l.split() + image_id = int(elems[0]) + l2 = fid.readline() + images_meta[image_id] = [l, l2] + return images_meta + + +def read_header(images_txt_path): + header = [] + with open(images_txt_path, "r") as fid: + line = fid.readline() + assert line == '# Image list with two lines of data per image:\n' + header.append(line) + line = fid.readline() + assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n' + header.append(line) + line = fid.readline() + assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n' + header.append(line) + line = fid.readline() + assert re.search('^# Number of images: \d+, mean observations per image: [-+]?\d*\.\d+|\d+\n$', line) + header.append(line) + return header + + +def export_images_txt(save_to, header, content): + assert not os.path.isfile(save_to), 'you are overriding existing files' + with open(save_to, "w") as fid: + for l in header: + fid.write(l) + for k, item in content.items(): + for l in item: + fid.write(l) + + +def main(opt): + reference = read_images_meta(opt.reference_images_txt) + unordered = read_images_meta(opt.unordered_images_txt) + ordered = {} + for k in reference.keys(): + ordered[k] = unordered[k] + header = read_header(opt.unordered_images_txt) + export_images_txt(opt.save_to, header, ordered) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--reference_images_txt', type=str, default=None, required=True) + parser.add_argument('--unordered_images_txt', type=str, default=None, required=True) + parser.add_argument('--save_to', type=str, default=None, required=True) + opt = parser.parse_args() + main(opt) diff --git a/imcui/third_party/COTR/train_cotr.py b/imcui/third_party/COTR/train_cotr.py new file mode 100644 index 0000000000000000000000000000000000000000..dd06fda41c7c97edaac784195895f8a2b36d07f3 --- /dev/null +++ b/imcui/third_party/COTR/train_cotr.py @@ -0,0 +1,149 @@ +import argparse +import subprocess +import pprint + +import numpy as np +import torch +# import torch.multiprocessing +# torch.multiprocessing.set_sharing_strategy('file_system') +from torch.utils.data import DataLoader + +from COTR.models import build_model +from COTR.utils import debug_utils, utils +from COTR.datasets import cotr_dataset +from COTR.trainers.cotr_trainer import COTRTrainer +from COTR.global_configs import general_config +from COTR.options.options import * +from COTR.options.options_utils import * + + +utils.fix_randomness(0) + + +def train(opt): + pprint.pprint(dict(os.environ), width=1) + result = subprocess.Popen(["nvidia-smi"], stdout=subprocess.PIPE) + print(result.stdout.read().decode()) + device = torch.cuda.current_device() + print(f'can see {torch.cuda.device_count()} gpus') + print(f'current using gpu at {device} -- {torch.cuda.get_device_name(device)}') + # dummy = torch.rand(3758725612).to(device) + # del dummy + torch.cuda.empty_cache() + model = build_model(opt) + model = model.to(device) + if opt.enable_zoom: + train_dset = cotr_dataset.COTRZoomDataset(opt, 'train') + val_dset = cotr_dataset.COTRZoomDataset(opt, 'val') + else: + train_dset = cotr_dataset.COTRDataset(opt, 'train') + val_dset = cotr_dataset.COTRDataset(opt, 'val') + + train_loader = DataLoader(train_dset, batch_size=opt.batch_size, + shuffle=opt.shuffle_data, num_workers=opt.workers, + worker_init_fn=utils.worker_init_fn, pin_memory=True) + val_loader = DataLoader(val_dset, batch_size=opt.batch_size, + shuffle=opt.shuffle_data, num_workers=opt.workers, + drop_last=True, worker_init_fn=utils.worker_init_fn, 
pin_memory=True) + + optim_list = [{"params": model.transformer.parameters(), "lr": opt.learning_rate}, + {"params": model.corr_embed.parameters(), "lr": opt.learning_rate}, + {"params": model.query_proj.parameters(), "lr": opt.learning_rate}, + {"params": model.input_proj.parameters(), "lr": opt.learning_rate}, + ] + if opt.lr_backbone > 0: + optim_list.append({"params": model.backbone.parameters(), "lr": opt.lr_backbone}) + + optim = torch.optim.Adam(optim_list) + trainer = COTRTrainer(opt, model, optim, None, train_loader, val_loader) + trainer.train() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + set_general_arguments(parser) + set_dataset_arguments(parser) + set_nn_arguments(parser) + set_COTR_arguments(parser) + parser.add_argument('--num_kp', type=int, + default=100) + parser.add_argument('--kp_pool', type=int, + default=100) + parser.add_argument('--enable_zoom', type=str2bool, + default=False) + parser.add_argument('--zoom_start', type=float, + default=1.0) + parser.add_argument('--zoom_end', type=float, + default=0.1) + parser.add_argument('--zoom_levels', type=int, + default=10) + parser.add_argument('--zoom_jitter', type=float, + default=0.5) + + parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory') + parser.add_argument('--tb_dir', type=str, default=general_config['tb_out'], help='tensorboard runs directory') + + parser.add_argument('--learning_rate', type=float, + default=1e-4, help='learning rate') + parser.add_argument('--lr_backbone', type=float, + default=1e-5, help='backbone learning rate') + parser.add_argument('--batch_size', type=int, + default=32, help='batch size for training') + parser.add_argument('--cycle_consis', type=str2bool, default=True, + help='cycle consistency') + parser.add_argument('--bidirectional', type=str2bool, default=True, + help='left2right and right2left') + parser.add_argument('--max_iter', type=int, + default=200000, help='total training iterations') + parser.add_argument('--valid_iter', type=int, + default=1000, help='interval of validation') + parser.add_argument('--resume', type=str2bool, default=False, + help='resume training with the same model name') + parser.add_argument('--cc_resume', type=str2bool, default=False, + help='resume from last run if possible') + parser.add_argument('--need_rotation', type=str2bool, default=False, + help='rotation augmentation') + parser.add_argument('--max_rotation', type=float, default=0, + help='max rotation for data augmentation') + parser.add_argument('--rotation_chance', type=float, default=0, + help='the probability of being rotated') + parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id') + parser.add_argument('--suffix', type=str, default='', help='model suffix') + + opt = parser.parse_args() + opt.command = ' '.join(sys.argv) + + layer_2_channels = {'layer1': 256, + 'layer2': 512, + 'layer3': 1024, + 'layer4': 2048, } + opt.dim_feedforward = layer_2_channels[opt.layer] + opt.num_queries = opt.num_kp + + opt.name = get_compact_naming_cotr(opt) + opt.out = os.path.join(opt.out_dir, opt.name) + opt.tb_out = os.path.join(opt.tb_dir, opt.name) + + if opt.cc_resume: + if os.path.isfile(os.path.join(opt.out, 'checkpoint.pth.tar')): + print('resuming from last run') + opt.load_weights = None + opt.resume = True + else: + opt.resume = False + assert (bool(opt.load_weights) and opt.resume) == False + if opt.load_weights: + opt.load_weights_path = 
os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar') + if opt.resume: + opt.load_weights_path = os.path.join(opt.out, 'checkpoint.pth.tar') + + opt.scenes_name_list = build_scenes_name_list_from_opt(opt) + + if opt.confirm: + confirm_opt(opt) + else: + print_opt(opt) + + save_opt(opt) + train(opt) diff --git a/imcui/third_party/DKM/demo/demo_fundamental.py b/imcui/third_party/DKM/demo/demo_fundamental.py new file mode 100644 index 0000000000000000000000000000000000000000..e19766d5d3ce1abf0d18483cbbce71b2696983be --- /dev/null +++ b/imcui/third_party/DKM/demo/demo_fundamental.py @@ -0,0 +1,37 @@ +from PIL import Image +import torch +import torch.nn.functional as F +import numpy as np +from dkm.utils.utils import tensor_to_pil +import cv2 +from dkm import DKMv3_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + + # Create model + dkm_model = DKMv3_outdoor(device=device) + + + W_A, H_A = Image.open(im1_path).size + W_B, H_B = Image.open(im2_path).size + + # Match + warp, certainty = dkm_model.match(im1_path, im2_path, device=device) + # Sample matches for estimation + matches, certainty = dkm_model.sample(warp, certainty) + kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) + F, mask = cv2.findFundamentalMat( + kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000 + ) + # TODO: some better visualization \ No newline at end of file diff --git a/imcui/third_party/DKM/demo/demo_match.py b/imcui/third_party/DKM/demo/demo_match.py new file mode 100644 index 0000000000000000000000000000000000000000..fb901894d8654a884819162d3b9bb8094529e034 --- /dev/null +++ b/imcui/third_party/DKM/demo/demo_match.py @@ -0,0 +1,48 @@ +from PIL import Image +import torch +import torch.nn.functional as F +import numpy as np +from dkm.utils.utils import tensor_to_pil + +from dkm import DKMv3_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) + parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + save_path = args.save_path + + # Create model + dkm_model = DKMv3_outdoor(device=device) + + H, W = 864, 1152 + + im1 = Image.open(im1_path).resize((W, H)) + im2 = Image.open(im2_path).resize((W, H)) + + # Match + warp, certainty = dkm_model.match(im1_path, im2_path, device=device) + # Sampling not needed, but can be done with model.sample(warp, certainty) + dkm_model.sample(warp, certainty) + x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1) + x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1) + + im2_transfer_rgb = F.grid_sample( + x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False + )[0] + im1_transfer_rgb = F.grid_sample( + x1[None], warp[:, W:, :2][None], 
mode="bilinear", align_corners=False + )[0] + warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2) + white_im = torch.ones((H,2*W),device=device) + vis_im = certainty * warp_im + (1 - certainty) * white_im + tensor_to_pil(vis_im, unnormalize=False).save(save_path) diff --git a/imcui/third_party/DKM/dkm/__init__.py b/imcui/third_party/DKM/dkm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b47632780acc7762bcccc348e2025fe99f3726 --- /dev/null +++ b/imcui/third_party/DKM/dkm/__init__.py @@ -0,0 +1,4 @@ +from .models import ( + DKMv3_outdoor, + DKMv3_indoor, + ) diff --git a/imcui/third_party/DKM/dkm/benchmarks/__init__.py b/imcui/third_party/DKM/dkm/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57643fd314a2301138aecdc804a5877d0ce9274e --- /dev/null +++ b/imcui/third_party/DKM/dkm/benchmarks/__init__.py @@ -0,0 +1,4 @@ +from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark +from .scannet_benchmark import ScanNetBenchmark +from .megadepth1500_benchmark import Megadepth1500Benchmark +from .megadepth_dense_benchmark import MegadepthDenseBenchmark diff --git a/imcui/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..079622fdaf77c75aeadd675629f2512c45d04c2d --- /dev/null +++ b/imcui/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py @@ -0,0 +1,100 @@ +from PIL import Image +import numpy as np + +import os + +import torch +from tqdm import tqdm + +from dkm.utils import * + + +class HpatchesDenseBenchmark: + """WARNING: HPATCHES grid goes from [0,n-1] instead of [0.5,n-0.5]""" + + def __init__(self, dataset_path) -> None: + seqs_dir = "hpatches-sequences-release" + self.seqs_path = os.path.join(dataset_path, seqs_dir) + self.seq_names = sorted(os.listdir(self.seqs_path)) + + def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup): + # Get matches in output format on the grid [0, n] where the center of the top-left coordinate is [0.5, 0.5] + offset = ( + 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] + ) + query_coords = ( + torch.stack( + ( + wq * (query_coords[..., 0] + 1) / 2, + hq * (query_coords[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + query_to_support = ( + torch.stack( + ( + wsup * (query_to_support[..., 0] + 1) / 2, + hsup * (query_to_support[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + return query_coords, query_to_support + + def inside_image(self, x, w, h): + return torch.logical_and( + x[:, 0] < (w - 1), + torch.logical_and(x[:, 1] < (h - 1), (x > 0).prod(dim=-1)), + ) + + def benchmark(self, model): + use_cuda = torch.cuda.is_available() + device = torch.device("cuda:0" if use_cuda else "cpu") + aepes = [] + pcks = [] + for seq_idx, seq_name in tqdm( + enumerate(self.seq_names), total=len(self.seq_names) + ): + if seq_name[0] == "i": + continue + im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm") + im1 = Image.open(im1_path) + w1, h1 = im1.size + for im_idx in range(2, 7): + im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm") + im2 = Image.open(im2_path) + w2, h2 = im2.size + matches, certainty = model.match(im2, im1, do_pred_in_og_res=True) + matches, certainty = matches.reshape(-1, 4), certainty.reshape(-1) + inv_homography = torch.from_numpy( + np.loadtxt( + 
os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx)) + ) + ).to(device) + homography = torch.linalg.inv(inv_homography) + pos_a, pos_b = self.convert_coordinates( + matches[:, :2], matches[:, 2:], w2, h2, w1, h1 + ) + pos_a, pos_b = pos_a.double(), pos_b.double() + pos_a_h = torch.cat( + [pos_a, torch.ones([pos_a.shape[0], 1], device=device)], dim=1 + ) + pos_b_proj_h = (homography @ pos_a_h.t()).t() + pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:] + mask = self.inside_image(pos_b_proj, w1, h1) + residual = pos_b - pos_b_proj + dist = (residual**2).sum(dim=1).sqrt()[mask] + aepes.append(torch.mean(dist).item()) + pck1 = (dist < 1.0).float().mean().item() + pck3 = (dist < 3.0).float().mean().item() + pck5 = (dist < 5.0).float().mean().item() + pcks.append([pck1, pck3, pck5]) + m_pcks = np.mean(np.array(pcks), axis=0) + return { + "hp_pck1": m_pcks[0], + "hp_pck3": m_pcks[1], + "hp_pck5": m_pcks[2], + } diff --git a/imcui/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..781a291f0c358cbd435790b0a639f2a2510145b2 --- /dev/null +++ b/imcui/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py @@ -0,0 +1,119 @@ +import pickle +import h5py +import numpy as np +import torch +from dkm.utils import * +from PIL import Image +from tqdm import tqdm + + +class Yfcc100mBenchmark: + def __init__(self, data_root="data/yfcc100m_test") -> None: + self.scenes = [ + "buckingham_palace", + "notre_dame_front_facade", + "reichstag", + "sacre_coeur", + ] + self.data_root = data_root + + def benchmark(self, model, r=2): + model.train(False) + with torch.no_grad(): + data_root = self.data_root + meta_info = open( + f"{data_root}/yfcc_test_pairs_with_gt.txt", "r" + ).readlines() + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + for scene_ind in range(len(self.scenes)): + scene = self.scenes[scene_ind] + pairs = np.array( + pickle.load( + open(f"{data_root}/pairs/{scene}-te-1000-pairs.pkl", "rb") + ) + ) + scene_dir = f"{data_root}/yfcc100m/{scene}/test/" + calibs = open(scene_dir + "calibration.txt", "r").read().split("\n") + images = open(scene_dir + "images.txt", "r").read().split("\n") + pair_inds = np.random.choice( + range(len(pairs)), size=len(pairs), replace=False + ) + for pairind in tqdm(pair_inds): + idx1, idx2 = pairs[pairind] + params = meta_info[1000 * scene_ind + pairind].split() + rot1, rot2 = int(params[2]), int(params[3]) + calib1 = h5py.File(scene_dir + calibs[idx1], "r") + K1, R1, t1, _, _ = get_pose(calib1) + calib2 = h5py.File(scene_dir + calibs[idx2], "r") + K2, R2, t2, _, _ = get_pose(calib2) + + R, t = compute_relative_pose(R1, t1, R2, t2) + im1 = images[idx1] + im2 = images[idx2] + im1 = Image.open(scene_dir + im1).rotate(rot1 * 90, expand=True) + w1, h1 = im1.size + im2 = Image.open(scene_dir + im2).rotate(rot2 * 90, expand=True) + w2, h2 = im2.size + K1 = rotate_intrinsic(K1, rot1) + K2 = rotate_intrinsic(K2, rot2) + + dense_matches, dense_certainty = model.match(im1, im2) + dense_certainty = dense_certainty ** (1 / r) + sparse_matches, sparse_confidence = model.sample( + dense_matches, dense_certainty, 10000 + ) + scale1 = 480 / min(w1, h1) + scale2 = 480 / min(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1 = K1 * scale1 + K2 = K2 * scale2 + + kpts1 = sparse_matches[:, :2] + kpts1 = np.stack( + (w1 * kpts1[:, 0] / 2, h1 * kpts1[:, 1] / 2), axis=-1 + ) + kpts2 = sparse_matches[:, 
2:] + kpts2 = np.stack( + (w2 * kpts2[:, 0] / 2, h2 * kpts2[:, 1] / 2), axis=-1 + ) + try: + threshold = 1.0 + norm_threshold = threshold / ( + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])) + ) + R_est, t_est, mask = estimate_pose( + kpts1, + kpts2, + K1[:2, :2], + K2[:2, :2], + norm_threshold, + conf=0.9999999, + ) + T1_to_2 = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2, R, t) + e_pose = max(e_t, e_R) + except: + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + thresholds = [5, 10, 20] + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/imcui/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3febe5ca9e3a683bc7122cec635c4f54b66f7c --- /dev/null +++ b/imcui/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py @@ -0,0 +1,114 @@ +from PIL import Image +import numpy as np + +import os + +from tqdm import tqdm +from dkm.utils import pose_auc +import cv2 + + +class HpatchesHomogBenchmark: + """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]""" + + def __init__(self, dataset_path) -> None: + seqs_dir = "hpatches-sequences-release" + self.seqs_path = os.path.join(dataset_path, seqs_dir) + self.seq_names = sorted(os.listdir(self.seqs_path)) + # Ignore seqs is same as LoFTR. 
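
The pose and homography benchmarks in this directory all reduce their per-pair errors with `dkm.utils.pose_auc` over a list of thresholds (5/10/20 degrees for relative pose above, 1 to 10 normalized corner distances for the homography benchmark below). As a point of reference, here is a minimal standalone sketch of that SuperGlue-style AUC (area under the recall-vs-error curve, clipped at each threshold); it is illustrative and may differ from the actual `dkm.utils.pose_auc` in details.

```python
# Illustrative sketch of a SuperGlue-style pose AUC; `errors` is a 1-D array of
# per-pair errors, `thresholds` a list of cutoffs such as [5, 10, 20].
import numpy as np

def pose_auc_sketch(errors, thresholds):
    errors = np.sort(np.asarray(errors, dtype=np.float64))
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.concatenate(([0.0], errors))
    recall = np.concatenate(([0.0], recall))
    aucs = []
    for t in thresholds:
        last = np.searchsorted(errors, t)
        e = np.concatenate((errors[:last], [t]))
        r = np.concatenate((recall[:last], [recall[last - 1]]))
        # trapezoidal integration of recall over error, normalized by the threshold
        area = np.sum((e[1:] - e[:-1]) * (r[1:] + r[:-1]) / 2.0)
        aucs.append(area / t)
    return aucs

# e.g. pose_auc_sketch(tot_e_pose, [5, 10, 20]) -> [auc_5, auc_10, auc_20]
```
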
+ self.ignore_seqs = set( + [ + "i_contruction", + "i_crownnight", + "i_dc", + "i_pencils", + "i_whitebuilding", + "v_artisans", + "v_astronautis", + "v_talent", + ] + ) + + def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup): + offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think) + query_coords = ( + np.stack( + ( + wq * (query_coords[..., 0] + 1) / 2, + hq * (query_coords[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + query_to_support = ( + np.stack( + ( + wsup * (query_to_support[..., 0] + 1) / 2, + hsup * (query_to_support[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + return query_coords, query_to_support + + def benchmark(self, model, model_name = None): + n_matches = [] + homog_dists = [] + for seq_idx, seq_name in tqdm( + enumerate(self.seq_names), total=len(self.seq_names) + ): + if seq_name in self.ignore_seqs: + continue + im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm") + im1 = Image.open(im1_path) + w1, h1 = im1.size + for im_idx in range(2, 7): + im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm") + im2 = Image.open(im2_path) + w2, h2 = im2.size + H = np.loadtxt( + os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx)) + ) + dense_matches, dense_certainty = model.match( + im1_path, im2_path + ) + good_matches, _ = model.sample(dense_matches, dense_certainty, 5000) + pos_a, pos_b = self.convert_coordinates( + good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2 + ) + try: + H_pred, inliers = cv2.findHomography( + pos_a, + pos_b, + method = cv2.RANSAC, + confidence = 0.99999, + ransacReprojThreshold = 3 * min(w2, h2) / 480, + ) + except: + H_pred = None + if H_pred is None: + H_pred = np.zeros((3, 3)) + H_pred[2, 2] = 1.0 + corners = np.array( + [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]] + ) + real_warped_corners = np.dot(corners, np.transpose(H)) + real_warped_corners = ( + real_warped_corners[:, :2] / real_warped_corners[:, 2:] + ) + warped_corners = np.dot(corners, np.transpose(H_pred)) + warped_corners = warped_corners[:, :2] / warped_corners[:, 2:] + mean_dist = np.mean( + np.linalg.norm(real_warped_corners - warped_corners, axis=1) + ) / (min(w2, h2) / 480.0) + homog_dists.append(mean_dist) + n_matches = np.array(n_matches) + thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + auc = pose_auc(np.array(homog_dists), thresholds) + return { + "hpatches_homog_auc_3": auc[2], + "hpatches_homog_auc_5": auc[4], + "hpatches_homog_auc_10": auc[9], + } diff --git a/imcui/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1193745ff18d239165aeb3376642fb17033874 --- /dev/null +++ b/imcui/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py @@ -0,0 +1,124 @@ +import numpy as np +import torch +from dkm.utils import * +from PIL import Image +from tqdm import tqdm +import torch.nn.functional as F + +class Megadepth1500Benchmark: + def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + if scene_names is None: + self.scene_names = [ + "0015_0.1_0.3.npz", + "0015_0.3_0.5.npz", + "0022_0.1_0.3.npz", + "0022_0.3_0.5.npz", + "0022_0.5_0.7.npz", + ] + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + + def benchmark(self, model): + with torch.no_grad(): + data_root = 
self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + for scene_ind in range(len(self.scenes)): + scene = self.scenes[scene_ind] + pairs = scene["pair_infos"] + intrinsics = scene["intrinsics"] + poses = scene["poses"] + im_paths = scene["image_paths"] + pair_inds = range(len(pairs)) + for pairind in tqdm(pair_inds): + idx1, idx2 = pairs[pairind][0] + K1 = intrinsics[idx1].copy() + T1 = poses[idx1].copy() + R1, t1 = T1[:3, :3], T1[:3, 3] + K2 = intrinsics[idx2].copy() + T2 = poses[idx2].copy() + R2, t2 = T2[:3, :3], T2[:3, 3] + R, t = compute_relative_pose(R1, t1, R2, t2) + im1_path = f"{data_root}/{im_paths[idx1]}" + im2_path = f"{data_root}/{im_paths[idx2]}" + im1 = Image.open(im1_path) + w1, h1 = im1.size + im2 = Image.open(im2_path) + w2, h2 = im2.size + scale1 = 1200 / max(w1, h1) + scale2 = 1200 / max(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1[:2] = K1[:2] * scale1 + K2[:2] = K2[:2] * scale2 + dense_matches, dense_certainty = model.match(im1_path, im2_path) + sparse_matches,_ = model.sample( + dense_matches, dense_certainty, 5000 + ) + kpts1 = sparse_matches[:, :2] + kpts1 = ( + torch.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2, + h1 * (kpts1[:, 1] + 1) / 2, + ), + axis=-1, + ) + ) + kpts2 = sparse_matches[:, 2:] + kpts2 = ( + torch.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2, + h2 * (kpts2[:, 1] + 1) / 2, + ), + axis=-1, + ) + ) + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + norm_threshold = 0.5 / ( + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1.cpu().numpy(), + kpts2.cpu().numpy(), + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + thresholds = [5, 10, 20] + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/imcui/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..0b370644497efd62563105e68e692e10ff339669 --- /dev/null +++ b/imcui/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py @@ -0,0 +1,86 @@ +import torch +import numpy as np +import tqdm +from dkm.datasets import MegadepthBuilder +from dkm.utils import warp_kpts +from torch.utils.data import ConcatDataset + + +class MegadepthDenseBenchmark: + def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000, device=None) -> None: + mega = MegadepthBuilder(data_root=data_root) + self.dataset = ConcatDataset( + mega.build_scenes(split="test_loftr", ht=h, wt=w) + ) # fixed resolution of 384,512 + self.num_samples = num_samples + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device + + def geometric_dist(self, 
depth1, depth2, T_1to2, K1, K2, dense_matches): + b, h1, w1, d = dense_matches.shape + with torch.no_grad(): + x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2) + # x1 = torch.stack((2*x1[...,0]/w1-1,2*x1[...,1]/h1-1),dim=-1) + mask, x2 = warp_kpts( + x1.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + ) + x2 = torch.stack( + (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1 + ) + prob = mask.float().reshape(b, h1, w1) + x2_hat = dense_matches[..., 2:] + x2_hat = torch.stack( + (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1 + ) + gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1) + gd = gd[prob == 1] + pck_1 = (gd < 1.0).float().mean() + pck_3 = (gd < 3.0).float().mean() + pck_5 = (gd < 5.0).float().mean() + gd = gd.mean() + return gd, pck_1, pck_3, pck_5 + + def benchmark(self, model, batch_size=8): + model.train(False) + with torch.no_grad(): + gd_tot = 0.0 + pck_1_tot = 0.0 + pck_3_tot = 0.0 + pck_5_tot = 0.0 + sampler = torch.utils.data.WeightedRandomSampler( + torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples + ) + dataloader = torch.utils.data.DataLoader( + self.dataset, batch_size=8, num_workers=batch_size, sampler=sampler + ) + for data in tqdm.tqdm(dataloader): + im1, im2, depth1, depth2, T_1to2, K1, K2 = ( + data["query"], + data["support"], + data["query_depth"].to(self.device), + data["support_depth"].to(self.device), + data["T_1to2"].to(self.device), + data["K1"].to(self.device), + data["K2"].to(self.device), + ) + matches, certainty = model.match(im1, im2, batched=True) + gd, pck_1, pck_3, pck_5 = self.geometric_dist( + depth1, depth2, T_1to2, K1, K2, matches + ) + gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = ( + gd_tot + gd, + pck_1_tot + pck_1, + pck_3_tot + pck_3, + pck_5_tot + pck_5, + ) + return { + "mega_pck_1": pck_1_tot.item() / len(dataloader), + "mega_pck_3": pck_3_tot.item() / len(dataloader), + "mega_pck_5": pck_5_tot.item() / len(dataloader), + } diff --git a/imcui/third_party/DKM/dkm/benchmarks/scannet_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/scannet_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..ca938cb462c351845ce035f8be0714cf81214452 --- /dev/null +++ b/imcui/third_party/DKM/dkm/benchmarks/scannet_benchmark.py @@ -0,0 +1,143 @@ +import os.path as osp +import numpy as np +import torch +from dkm.utils import * +from PIL import Image +from tqdm import tqdm + + +class ScanNetBenchmark: + def __init__(self, data_root="data/scannet") -> None: + self.data_root = data_root + + def benchmark(self, model, model_name = None): + model.train(False) + with torch.no_grad(): + data_root = self.data_root + tmp = np.load(osp.join(data_root, "test.npz")) + pairs, rel_pose = tmp["name"], tmp["rel_pose"] + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + pair_inds = np.random.choice( + range(len(pairs)), size=len(pairs), replace=False + ) + for pairind in tqdm(pair_inds, smoothing=0.9): + scene = pairs[pairind] + scene_name = f"scene0{scene[0]}_00" + im1_path = osp.join( + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[2]}.jpg", + ) + im1 = Image.open(im1_path) + im2_path = osp.join( + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[3]}.jpg", + ) + im2 = Image.open(im2_path) + T_gt = rel_pose[pairind].reshape(3, 4) + R, t = T_gt[:3, :3], T_gt[:3, 3] + K = np.stack( + [ + np.array([float(i) for i in r.split()]) + for r in open( + osp.join( + self.data_root, + "scans_test", + 
scene_name, + "intrinsic", + "intrinsic_color.txt", + ), + "r", + ) + .read() + .split("\n") + if r + ] + ) + w1, h1 = im1.size + w2, h2 = im2.size + K1 = K.copy() + K2 = K.copy() + dense_matches, dense_certainty = model.match(im1_path, im2_path) + sparse_matches, sparse_certainty = model.sample( + dense_matches, dense_certainty, 5000 + ) + scale1 = 480 / min(w1, h1) + scale2 = 480 / min(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1 = K1 * scale1 + K2 = K2 * scale2 + + offset = 0.5 + kpts1 = sparse_matches[:, :2] + kpts1 = ( + np.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2 - offset, + h1 * (kpts1[:, 1] + 1) / 2 - offset, + ), + axis=-1, + ) + ) + kpts2 = sparse_matches[:, 2:] + kpts2 = ( + np.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2 - offset, + h2 * (kpts2[:, 1] + 1) / 2 - offset, + ), + axis=-1, + ) + ) + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + norm_threshold = 0.5 / ( + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1, + kpts2, + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + thresholds = [5, 10, 20] + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/imcui/third_party/DKM/dkm/checkpointing/__init__.py b/imcui/third_party/DKM/dkm/checkpointing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22f5afe727aa6f6e8fffa9ecf5be69cbff686577 --- /dev/null +++ b/imcui/third_party/DKM/dkm/checkpointing/__init__.py @@ -0,0 +1 @@ +from .checkpoint import CheckPoint diff --git a/imcui/third_party/DKM/dkm/checkpointing/checkpoint.py b/imcui/third_party/DKM/dkm/checkpointing/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..715eeb587ebb87ed0d1bcf9940e048adbe35cde2 --- /dev/null +++ b/imcui/third_party/DKM/dkm/checkpointing/checkpoint.py @@ -0,0 +1,31 @@ +import os +import torch +from torch.nn.parallel.data_parallel import DataParallel +from torch.nn.parallel.distributed import DistributedDataParallel +from loguru import logger + + +class CheckPoint: + def __init__(self, dir=None, name="tmp"): + self.name = name + self.dir = dir + os.makedirs(self.dir, exist_ok=True) + + def __call__( + self, + model, + optimizer, + lr_scheduler, + n, + ): + assert model is not None + if isinstance(model, (DataParallel, DistributedDataParallel)): + model = model.module + states = { + "model": model.state_dict(), + "n": n, + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + } + torch.save(states, self.dir + self.name + f"_latest.pth") + logger.info(f"Saved states {list(states.keys())}, at step {n}") diff --git a/imcui/third_party/DKM/dkm/datasets/__init__.py 
b/imcui/third_party/DKM/dkm/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b81083212edaf345c30f0cb1116c5f9de284ce6 --- /dev/null +++ b/imcui/third_party/DKM/dkm/datasets/__init__.py @@ -0,0 +1 @@ +from .megadepth import MegadepthBuilder diff --git a/imcui/third_party/DKM/dkm/datasets/megadepth.py b/imcui/third_party/DKM/dkm/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..c580607e910ce1926b7711b5473aa82b20865369 --- /dev/null +++ b/imcui/third_party/DKM/dkm/datasets/megadepth.py @@ -0,0 +1,177 @@ +import os +import random +from PIL import Image +import h5py +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader, ConcatDataset + +from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops +import torchvision.transforms.functional as tvf +from dkm.utils.transforms import GeometricSequential +import kornia.augmentation as K + + +class MegadepthScene: + def __init__( + self, + data_root, + scene_info, + ht=384, + wt=512, + min_overlap=0.0, + shake_t=0, + rot_prob=0.0, + normalize=True, + ) -> None: + self.data_root = data_root + self.image_paths = scene_info["image_paths"] + self.depth_paths = scene_info["depth_paths"] + self.intrinsics = scene_info["intrinsics"] + self.poses = scene_info["poses"] + self.pairs = scene_info["pairs"] + self.overlaps = scene_info["overlaps"] + threshold = self.overlaps > min_overlap + self.pairs = self.pairs[threshold] + self.overlaps = self.overlaps[threshold] + if len(self.pairs) > 100000: + pairinds = np.random.choice( + np.arange(0, len(self.pairs)), 100000, replace=False + ) + self.pairs = self.pairs[pairinds] + self.overlaps = self.overlaps[pairinds] + # counts, bins = np.histogram(self.overlaps,20) + # print(counts) + self.im_transform_ops = get_tuple_transform_ops( + resize=(ht, wt), normalize=normalize + ) + self.depth_transform_ops = get_depth_tuple_transform_ops( + resize=(ht, wt), normalize=False + ) + self.wt, self.ht = wt, ht + self.shake_t = shake_t + self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob)) + + def load_im(self, im_ref, crop=None): + im = Image.open(im_ref) + return im + + def load_depth(self, depth_ref, crop=None): + depth = np.array(h5py.File(depth_ref, "r")["depth"]) + return torch.from_numpy(depth) + + def __len__(self): + return len(self.pairs) + + def scale_intrinsic(self, K, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]]) + return sK @ K + + def rand_shake(self, *things): + t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2) + return [ + tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0]) + for thing in things + ], t + + def __getitem__(self, pair_idx): + # read intrinsics of original size + idx1, idx2 = self.pairs[pair_idx] + K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3) + K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T1 = self.poses[idx1] + T2 = self.poses[idx2] + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[ + :4, :4 + ] # (4, 4) + + # Load positive pair data + im1, im2 = self.image_paths[idx1], self.image_paths[idx2] + depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2] + im_src_ref = os.path.join(self.data_root, im1) + im_pos_ref = os.path.join(self.data_root, im2) + depth_src_ref = os.path.join(self.data_root, depth1) + 
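
Two small geometric conventions used by `MegadepthScene` above (and by `ScanNetScene` further below) are worth spelling out: resizing an image from `(wi, hi)` to `(wt, ht)` rescales the intrinsics as `K <- diag(wt/wi, ht/hi, 1) @ K` (see `scale_intrinsic`), and the relative transform is composed as `T_1to2 = T2 @ inv(T1)` for world-to-camera extrinsics. A short self-contained check is sketched below; the sizes and camera values are made-up examples, not dataset values.

```python
# Self-contained check of the intrinsic rescaling and relative-pose composition
# used in __getitem__ above; all numbers here are illustrative assumptions.
import numpy as np

wi, hi, wt, ht = 1600, 1200, 512, 384
K = np.array([[1170.0, 0.0, 800.0],
              [0.0, 1170.0, 600.0],
              [0.0, 0.0, 1.0]])
K_resized = np.diag([wt / wi, ht / hi, 1.0]) @ K   # fx, cx scale by wt/wi; fy, cy by ht/hi

X_cam = np.array([0.3, -0.2, 4.0])                  # a point in the camera frame
u, v, w = K @ X_cam
ur, vr, wr = K_resized @ X_cam
assert np.allclose([ur / wr, vr / wr], [u / w * wt / wi, v / w * ht / hi])

# Relative pose: with world-to-camera extrinsics T = [R | t], T_1to2 = T2 @ inv(T1)
# is equivalent to R = R2 @ R1.T and t = t2 - R @ t1.
def make_T(R, t):
    T = np.eye(4)
    T[:3, :3], T[:3, 3] = R, t
    return T

R1, t1 = np.eye(3), np.array([0.1, 0.0, 0.0])
R2, t2 = np.eye(3), np.array([0.0, 0.2, 0.0])
T_1to2 = make_T(R2, t2) @ np.linalg.inv(make_T(R1, t1))
assert np.allclose(T_1to2[:3, :3], R2 @ R1.T)
assert np.allclose(T_1to2[:3, 3], t2 - (R2 @ R1.T) @ t1)
```
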
depth_pos_ref = os.path.join(self.data_root, depth2) + # return torch.randn((1000,1000)) + im_src = self.load_im(im_src_ref) + im_pos = self.load_im(im_pos_ref) + depth_src = self.load_depth(depth_src_ref) + depth_pos = self.load_depth(depth_pos_ref) + + # Recompute camera intrinsic matrix due to the resize + K1 = self.scale_intrinsic(K1, im_src.width, im_src.height) + K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height) + # Process images + im_src, im_pos = self.im_transform_ops((im_src, im_pos)) + depth_src, depth_pos = self.depth_transform_ops( + (depth_src[None, None], depth_pos[None, None]) + ) + [im_src, im_pos, depth_src, depth_pos], t = self.rand_shake( + im_src, im_pos, depth_src, depth_pos + ) + im_src, Hq = self.H_generator(im_src[None]) + depth_src = self.H_generator.apply_transform(depth_src, Hq) + K1[:2, 2] += t + K2[:2, 2] += t + K1 = Hq[0] @ K1 + data_dict = { + "query": im_src[0], + "query_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0], + "support": im_pos, + "support_identifier": self.image_paths[idx2] + .split("/")[-1] + .split(".jpg")[0], + "query_depth": depth_src[0, 0], + "support_depth": depth_pos[0, 0], + "K1": K1, + "K2": K2, + "T_1to2": T_1to2, + } + return data_dict + + +class MegadepthBuilder: + def __init__(self, data_root="data/megadepth") -> None: + self.data_root = data_root + self.scene_info_root = os.path.join(data_root, "prep_scene_info") + self.all_scenes = os.listdir(self.scene_info_root) + self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"] + self.test_scenes_loftr = ["0015.npy", "0022.npy"] + + def build_scenes(self, split="train", min_overlap=0.0, **kwargs): + if split == "train": + scene_names = set(self.all_scenes) - set(self.test_scenes) + elif split == "train_loftr": + scene_names = set(self.all_scenes) - set(self.test_scenes_loftr) + elif split == "test": + scene_names = self.test_scenes + elif split == "test_loftr": + scene_names = self.test_scenes_loftr + else: + raise ValueError(f"Split {split} not available") + scenes = [] + for scene_name in scene_names: + scene_info = np.load( + os.path.join(self.scene_info_root, scene_name), allow_pickle=True + ).item() + scenes.append( + MegadepthScene( + self.data_root, scene_info, min_overlap=min_overlap, **kwargs + ) + ) + return scenes + + def weight_scenes(self, concat_dataset, alpha=0.5): + ns = [] + for d in concat_dataset.datasets: + ns.append(len(d)) + ws = torch.cat([torch.ones(n) / n**alpha for n in ns]) + return ws + + +if __name__ == "__main__": + mega_test = ConcatDataset(MegadepthBuilder().build_scenes(split="train")) + mega_test[0] diff --git a/imcui/third_party/DKM/dkm/datasets/scannet.py b/imcui/third_party/DKM/dkm/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac39b41480f7585c4755cc30e0677ef74ed5e0c --- /dev/null +++ b/imcui/third_party/DKM/dkm/datasets/scannet.py @@ -0,0 +1,151 @@ +import os +import random +from PIL import Image +import cv2 +import h5py +import numpy as np +import torch +from torch.utils.data import ( + Dataset, + DataLoader, + ConcatDataset) + +import torchvision.transforms.functional as tvf +import kornia.augmentation as K +import os.path as osp +import matplotlib.pyplot as plt +from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops +from dkm.utils.transforms import GeometricSequential + +from tqdm import tqdm + +class ScanNetScene: + def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) 
-> None: + self.scene_root = osp.join(data_root,"scans","scans_train") + self.data_names = scene_info['name'] + self.overlaps = scene_info['score'] + # Only sample 10s + valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0 + self.overlaps = self.overlaps[valid] + self.data_names = self.data_names[valid] + if len(self.data_names) > 10000: + pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False) + self.data_names = self.data_names[pairinds] + self.overlaps = self.overlaps[pairinds] + self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True) + self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False) + self.wt, self.ht = wt, ht + self.shake_t = shake_t + self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob)) + + def load_im(self, im_ref, crop=None): + im = Image.open(im_ref) + return im + + def load_depth(self, depth_ref, crop=None): + depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED) + depth = depth / 1000 + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + def __len__(self): + return len(self.data_names) + + def scale_intrinsic(self, K, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], + [0, sy, 0], + [0, 0, 1]]) + return sK@K + + def read_scannet_pose(self,path): + """ Read ScanNet's Camera2World pose and transform it to World2Camera. + + Returns: + pose_w2c (np.ndarray): (4, 4) + """ + cam2world = np.loadtxt(path, delimiter=' ') + world2cam = np.linalg.inv(cam2world) + return world2cam + + + def read_scannet_intrinsic(self,path): + """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. + """ + intrinsic = np.loadtxt(path, delimiter=' ') + return intrinsic[:-1, :-1] + + def __getitem__(self, pair_idx): + # read intrinsics of original size + data_name = self.data_names[pair_idx] + scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name + scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + + # read the intrinsic of depthmap + K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root, + scene_name, + 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter + # read and compute relative poses + T1 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_1}.txt')) + T2 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_2}.txt')) + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4) + + # Load positive pair data + im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg') + im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg') + depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png') + depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png') + + im_src = self.load_im(im_src_ref) + im_pos = self.load_im(im_pos_ref) + depth_src = self.load_depth(depth_src_ref) + depth_pos = self.load_depth(depth_pos_ref) + + # Recompute camera intrinsic matrix due to the resize + K1 = self.scale_intrinsic(K1, im_src.width, im_src.height) + K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height) + # Process images + im_src, im_pos = self.im_transform_ops((im_src, im_pos)) + depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None])) + + data_dict = {'query': im_src, + 'support': im_pos, + 'query_depth': depth_src[0,0], + 
'support_depth': depth_pos[0,0], + 'K1': K1, + 'K2': K2, + 'T_1to2':T_1to2, + } + return data_dict + + +class ScanNetBuilder: + def __init__(self, data_root = 'data/scannet') -> None: + self.data_root = data_root + self.scene_info_root = os.path.join(data_root,'scannet_indices') + self.all_scenes = os.listdir(self.scene_info_root) + + def build_scenes(self, split = 'train', min_overlap=0., **kwargs): + # Note: split doesn't matter here as we always use same scannet_train scenes + scene_names = self.all_scenes + scenes = [] + for scene_name in tqdm(scene_names): + scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True) + scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs)) + return scenes + + def weight_scenes(self, concat_dataset, alpha=.5): + ns = [] + for d in concat_dataset.datasets: + ns.append(len(d)) + ws = torch.cat([torch.ones(n)/n**alpha for n in ns]) + return ws + + +if __name__ == "__main__": + mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train')) + mega_test[0] \ No newline at end of file diff --git a/imcui/third_party/DKM/dkm/losses/__init__.py b/imcui/third_party/DKM/dkm/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71914f50d891079d204a07c57367159888f892de --- /dev/null +++ b/imcui/third_party/DKM/dkm/losses/__init__.py @@ -0,0 +1 @@ +from .depth_match_regression_loss import DepthRegressionLoss diff --git a/imcui/third_party/DKM/dkm/losses/depth_match_regression_loss.py b/imcui/third_party/DKM/dkm/losses/depth_match_regression_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..80da70347b4b4addc721e2a14ed489f8683fd48a --- /dev/null +++ b/imcui/third_party/DKM/dkm/losses/depth_match_regression_loss.py @@ -0,0 +1,128 @@ +from einops.einops import rearrange +import torch +import torch.nn as nn +import torch.nn.functional as F +from dkm.utils.utils import warp_kpts + + +class DepthRegressionLoss(nn.Module): + def __init__( + self, + robust=True, + center_coords=False, + scale_normalize=False, + ce_weight=0.01, + local_loss=True, + local_dist=4.0, + local_largest_scale=8, + ): + super().__init__() + self.robust = robust # measured in pixels + self.center_coords = center_coords + self.scale_normalize = scale_normalize + self.ce_weight = ce_weight + self.local_loss = local_loss + self.local_dist = local_dist + self.local_largest_scale = local_largest_scale + + def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale): + """[summary] + + Args: + H ([type]): [description] + scale ([type]): [description] + + Returns: + [type]: [description] + """ + b, h1, w1, d = dense_matches.shape + with torch.no_grad(): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=dense_matches.device + ) + for n in (b, h1, w1) + ] + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(b, h1 * w1, 2) + mask, x2 = warp_kpts( + x1_n.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + ) + prob = mask.float().reshape(b, h1, w1) + gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1) # *scale? + return gd, prob + + def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8): + """[summary] + + Args: + dense_certainty ([type]): [description] + prob ([type]): [description] + eps ([type], optional): [description]. Defaults to 1e-8. 
+ + Returns: + [type]: [description] + """ + smooth_prob = prob + ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], smooth_prob) + depth_loss = gd[prob > 0] + if not torch.any(prob > 0).item(): + depth_loss = (gd * 0.0).mean() # Prevent issues where prob is 0 everywhere + return { + f"ce_loss_{scale}": ce_loss.mean(), + f"depth_loss_{scale}": depth_loss.mean(), + } + + def forward(self, dense_corresps, batch): + """[summary] + + Args: + out ([type]): [description] + batch ([type]): [description] + + Returns: + [type]: [description] + """ + scales = list(dense_corresps.keys()) + tot_loss = 0.0 + prev_gd = 0.0 + for scale in scales: + dense_scale_corresps = dense_corresps[scale] + dense_scale_certainty, dense_scale_coords = ( + dense_scale_corresps["dense_certainty"], + dense_scale_corresps["dense_flow"], + ) + dense_scale_coords = rearrange(dense_scale_coords, "b d h w -> b h w d") + b, h, w, d = dense_scale_coords.shape + gd, prob = self.geometric_dist( + batch["query_depth"], + batch["support_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + dense_scale_coords, + scale, + ) + if ( + scale <= self.local_largest_scale and self.local_loss + ): # Thought here is that fine matching loss should not be punished by coarse mistakes, but should identify wrong matching + prob = prob * ( + F.interpolate(prev_gd[:, None], size=(h, w), mode="nearest")[:, 0] + < (2 / 512) * (self.local_dist * scale) + ) + depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale) + scale_loss = ( + self.ce_weight * depth_losses[f"ce_loss_{scale}"] + + depth_losses[f"depth_loss_{scale}"] + ) # scale ce loss for coarser scales + if self.scale_normalize: + scale_loss = scale_loss * 1 / scale + tot_loss = tot_loss + scale_loss + prev_gd = gd.detach() + return tot_loss diff --git a/imcui/third_party/DKM/dkm/models/__init__.py b/imcui/third_party/DKM/dkm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4fc321ec70fd116beca23e94248cb6bbe771523 --- /dev/null +++ b/imcui/third_party/DKM/dkm/models/__init__.py @@ -0,0 +1,4 @@ +from .model_zoo import ( + DKMv3_outdoor, + DKMv3_indoor, +) diff --git a/imcui/third_party/DKM/dkm/models/deprecated/build_model.py b/imcui/third_party/DKM/dkm/models/deprecated/build_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dd28335f3e348ab6c90b26ba91b95e864b0bbbb9 --- /dev/null +++ b/imcui/third_party/DKM/dkm/models/deprecated/build_model.py @@ -0,0 +1,787 @@ +import torch +import torch.nn as nn +from dkm import * +from .local_corr import LocalCorr +from .corr_channels import NormedCorr +from torchvision.models import resnet as tv_resnet + +dkm_pretrained_urls = { + "DKM": { + "mega_synthetic": "https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth", + "mega": "https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth", + }, + "DKMv2":{ + "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth", + "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth", + } +} + + +def DKM(pretrained=True, version="mega_synthetic", device=None): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + 
pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": ConvRefiner( + 2 * 256, + 512, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": ConvRefiner( + 2 * 64, + 128, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained), + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["DKM"][version] + ) + matcher.load_state_dict(weights) + return matcher + +def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs): + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512+128, + 1024+128, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + ), + "8": ConvRefiner( + 2 * 512+64, + 1024+64, + 3, + 
kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + ), + "4": ConvRefiner( + 2 * 256+32, + 512+32, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + ), + "2": ConvRefiner( + 2 * 64+16, + 128+16, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=16, + ), + "1": ConvRefiner( + 2 * 3+6, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=6, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + if resolution == "low": + h, w = 384, 512 + elif resolution == "high": + h, w = 480, 640 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained), + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs).to(device) + if pretrained: + try: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["DKMv2"][version] + ) + except: + weights = torch.load( + dkm_pretrained_urls["DKMv2"][version] + ) + matcher.load_state_dict(weights) + return matcher + + +def local_corr(pretrained=True, version="mega_synthetic"): + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": LocalCorr( + 81, + 81 * 12, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": LocalCorr( + 81, + 81 * 12, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": LocalCorr( + 81, + 81 * 6, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": LocalCorr( + 81, + 81, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + 
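
In the builders above, the coarse global matcher is a GP regression whose similarity is `CosKernel` with `kernel_temperature = 0.2`. The sketch below illustrates the idea of such a temperature-scaled cosine kernel between flattened query/support feature maps; it is an illustration of the concept, mirroring the normalized-correlation pattern used elsewhere in this package, not a verbatim copy of dkm's `CosKernel`.

```python
# Conceptual sketch of a cosine-similarity kernel with temperature T (assumed
# form; not copied from dkm's CosKernel implementation).
import torch

def cos_kernel(x, y, T=0.2, eps=1e-6):
    # x: (B, N, D) query features, y: (B, M, D) support features
    c = torch.einsum("bnd,bmd->bnm", x, y) / (
        x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
    )
    # Map cosine similarity in [-1, 1] to a positive kernel value in (0, 1].
    return ((c - 1.0) / T).exp()

x = torch.randn(2, 100, 256)
y = torch.randn(2, 120, 256)
K_xy = cos_kernel(x, y)   # (2, 100, 120), largest where features align
```
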
only_attention = False + basis = "fourier" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained) + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["local_corr"][version] + ) + matcher.load_state_dict(weights) + return matcher + + +def corr_channels(pretrained=True, version="mega_synthetic"): + h, w = 384, 512 + gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16) + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim[0] + feat_dim, dfn_dim), + "16": RRB(gp_dim[1] + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": ConvRefiner( + 2 * 256, + 512, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": ConvRefiner( + 2 * 64, + 128, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + gp32 = NormedCorr() + gp16 = NormedCorr() + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained) + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["corr_channels"][version] + ) + matcher.load_state_dict(weights) + return matcher + + +def baseline(pretrained=True, version="mega_synthetic"): + h, w = 384, 512 + gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16) + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, 
feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim[0] + feat_dim, dfn_dim), + "16": RRB(gp_dim[1] + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": LocalCorr( + 81, + 81 * 12, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": LocalCorr( + 81, + 81 * 12, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": LocalCorr( + 81, + 81 * 6, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": LocalCorr( + 81, + 81, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + gp32 = NormedCorr() + gp16 = NormedCorr() + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained) + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["baseline"][version] + ) + matcher.load_state_dict(weights) + return matcher + + +def linear(pretrained=True, version="mega_synthetic"): + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "8": ConvRefiner( + 2 * 512, + 1024, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "4": ConvRefiner( + 2 * 256, + 512, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "2": ConvRefiner( + 2 * 64, + 128, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + "1": ConvRefiner( + 2 * 3, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + 
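
The only substantive difference between this `linear` variant and the default builders above is `basis = "linear"` instead of `"fourier"` in the GP module. As a hedged illustration of that distinction (an assumption about intent, not dkm's actual embedding code): a linear basis feeds normalized coordinates through more or less directly, while a Fourier basis lifts them into sinusoidal features first.

```python
# Hypothetical sketch contrasting a "linear" coordinate basis with a "fourier"
# one; the projection matrix and dimensions are illustrative assumptions only.
import math
import torch

def coordinate_basis(coords, basis="fourier", dim=256, seed=0):
    # coords: (B, N, 2) normalized to [-1, 1]
    if basis == "linear":
        return coords
    g = torch.Generator().manual_seed(seed)
    freqs = torch.randn(2, dim // 2, generator=g) * math.pi  # random Fourier frequencies
    proj = coords @ freqs                                     # (B, N, dim // 2)
    return torch.cat([proj.sin(), proj.cos()], dim=-1)        # (B, N, dim)

xy = torch.rand(1, 8, 2) * 2 - 1
print(coordinate_basis(xy, "linear").shape)    # torch.Size([1, 8, 2])
print(coordinate_basis(xy, "fourier").shape)   # torch.Size([1, 8, 256])
```
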
only_attention = False + basis = "linear" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + h, w = 384, 512 + encoder = Encoder( + tv_resnet.resnet50(pretrained=not pretrained) + ) # only load pretrained weights if not loading a pretrained matcher ;) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device) + if pretrained: + weights = torch.hub.load_state_dict_from_url( + dkm_pretrained_urls["linear"][version] + ) + matcher.load_state_dict(weights) + return matcher diff --git a/imcui/third_party/DKM/dkm/models/deprecated/corr_channels.py b/imcui/third_party/DKM/dkm/models/deprecated/corr_channels.py new file mode 100644 index 0000000000000000000000000000000000000000..8713b0d8c7a0ce91da4d2105ba29097a4969a037 --- /dev/null +++ b/imcui/third_party/DKM/dkm/models/deprecated/corr_channels.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class NormedCorrelationKernel(nn.Module): # similar to softmax kernel + def __init__(self): + super().__init__() + + def __call__(self, x, y, eps=1e-6): + c = torch.einsum("bnd,bmd->bnm", x, y) / ( + x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps + ) + return c + + +class NormedCorr(nn.Module): + def __init__( + self, + ): + super().__init__() + self.corr = NormedCorrelationKernel() + + def reshape(self, x): + return rearrange(x, "b d h w -> b (h w) d") + + def forward(self, x, y, **kwargs): + b, c, h, w = y.shape + assert x.shape == y.shape + x, y = self.reshape(x), self.reshape(y) + corr_xy = self.corr(x, y) + corr_xy_flat = rearrange(corr_xy, "b (h w) c -> b c h w", h=h, w=w) + return corr_xy_flat diff --git a/imcui/third_party/DKM/dkm/models/deprecated/local_corr.py b/imcui/third_party/DKM/dkm/models/deprecated/local_corr.py new file mode 100644 index 0000000000000000000000000000000000000000..681fe4c0079561fa7a4c44e82a8879a4a27273a1 --- /dev/null +++ b/imcui/third_party/DKM/dkm/models/deprecated/local_corr.py @@ -0,0 +1,630 @@ +import torch +import torch.nn.functional as F + +try: + import cupy +except: + print("Cupy not found, local correlation will not work") +import re +from ..dkm import ConvRefiner + + +class Stream: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if device == 'cuda': + stream = torch.cuda.current_stream(device=device).cuda_stream + else: + stream = None + + +kernel_Correlation_rearrange = """ + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + if (intIndex >= n) { + return; + } + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + __syncthreads(); + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + output[(((intSample * SIZE_1(output) * 
SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue; + } +""" + +kernel_Correlation_updateOutput = """ + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + float *patch_data = (float *)patch_data_char; + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + __syncthreads(); + __shared__ float sum[32]; + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + __syncthreads(); + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +""" + +kernel_Correlation_updateGradFirst = """ + #define ROUND_OFF 50000 + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. 
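+    // The x/y ranges computed below select the output positions whose (1x1)
+    // first-image patch reads this input pixel, so the gradient only has to be
+    // accumulated over that small window of gradOutput.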
+ const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +""" + +kernel_Correlation_updateGradSecond = """ + #define ROUND_OFF 50000 + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. 
+ const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +""" + + +def cupy_kernel(strFunction, objectVariables): + strKernel = globals()[strFunction] + + while True: + objectMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel) + + if objectMatch is None: + break + + intArg = int(objectMatch.group(2)) + + strTensor = objectMatch.group(4) + intSizes = objectVariables[strTensor].size() + + strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg])) + + while True: + objectMatch = re.search(r"(VALUE_)([0-4])(\()([^\)]+)(\))", strKernel) + + if objectMatch is None: + break + + intArgs = int(objectMatch.group(2)) + strArgs = objectMatch.group(4).split(",") + + strTensor = strArgs[0] + intStrides = objectVariables[strTensor].stride() + strIndex = [ + "((" + + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + + ")*" + + str(intStrides[intArg]) + + ")" + for intArg in range(intArgs) + ] + + strKernel = strKernel.replace( + objectMatch.group(0), strTensor + "[" + str.join("+", strIndex) + "]" + ) + + return strKernel + + +try: + + @cupy.memoize(for_each_device=True) + def cupy_launch(strFunction, strKernel): + return cupy.RawModule(code=strKernel).get_function(strFunction) + +except: + pass + + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros( + [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)] + ) + rbot1 = first.new_zeros( + [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)] + ) + + self.save_for_backward(first, second, rbot0, rbot1) + + first = first.contiguous() + second = second.contiguous() + + output = first.new_zeros([first.size(0), 81, first.size(2), first.size(3)]) + + if first.is_cuda == True: + n = first.size(2) * first.size(3) + cupy_launch( + 
"kernel_Correlation_rearrange", + cupy_kernel( + "kernel_Correlation_rearrange", {"input": first, "output": rbot0} + ), + )( + grid=tuple([int((n + 16 - 1) / 16), first.size(1), first.size(0)]), + block=tuple([16, 1, 1]), + args=[n, first.data_ptr(), rbot0.data_ptr()], + stream=Stream, + ) + + n = second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_rearrange", + cupy_kernel( + "kernel_Correlation_rearrange", {"input": second, "output": rbot1} + ), + )( + grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]), + block=tuple([16, 1, 1]), + args=[n, second.data_ptr(), rbot1.data_ptr()], + stream=Stream, + ) + + n = output.size(1) * output.size(2) * output.size(3) + cupy_launch( + "kernel_Correlation_updateOutput", + cupy_kernel( + "kernel_Correlation_updateOutput", + {"rbot0": rbot0, "rbot1": rbot1, "top": output}, + ), + )( + grid=tuple([output.size(3), output.size(2), output.size(0)]), + block=tuple([32, 1, 1]), + shared_mem=first.size(1) * 4, + args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()], + stream=Stream, + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + return output + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + gradOutput = gradOutput.contiguous() + + assert gradOutput.is_contiguous() == True + + gradFirst = ( + first.new_zeros( + [first.size(0), first.size(1), first.size(2), first.size(3)] + ) + if self.needs_input_grad[0] == True + else None + ) + gradSecond = ( + first.new_zeros( + [first.size(0), first.size(1), first.size(2), first.size(3)] + ) + if self.needs_input_grad[1] == True + else None + ) + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.size(0)): + n = first.size(1) * first.size(2) * first.size(3) + cupy_launch( + "kernel_Correlation_updateGradFirst", + cupy_kernel( + "kernel_Correlation_updateGradFirst", + { + "rbot0": rbot0, + "rbot1": rbot1, + "gradOutput": gradOutput, + "gradFirst": gradFirst, + "gradSecond": None, + }, + ), + )( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + n, + intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + gradOutput.data_ptr(), + gradFirst.data_ptr(), + None, + ], + stream=Stream, + ) + + if gradSecond is not None: + for intSample in range(first.size(0)): + n = first.size(1) * first.size(2) * first.size(3) + cupy_launch( + "kernel_Correlation_updateGradSecond", + cupy_kernel( + "kernel_Correlation_updateGradSecond", + { + "rbot0": rbot0, + "rbot1": rbot1, + "gradOutput": gradOutput, + "gradFirst": None, + "gradSecond": gradSecond, + }, + ), + )( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + n, + intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + gradOutput.data_ptr(), + None, + gradSecond.data_ptr(), + ], + stream=Stream, + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + return gradFirst, gradSecond + + +class _FunctionCorrelationTranspose(torch.autograd.Function): + @staticmethod + def forward(self, input, second): + rbot0 = second.new_zeros( + [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)] + ) + rbot1 = second.new_zeros( + [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)] + ) + + self.save_for_backward(input, second, rbot0, rbot1) + + input = input.contiguous() + second = second.contiguous() + + output = second.new_zeros( + [second.size(0), second.size(1), second.size(2), second.size(3)] + ) + + if second.is_cuda == True: 
+ n = second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_rearrange", + cupy_kernel( + "kernel_Correlation_rearrange", {"input": second, "output": rbot1} + ), + )( + grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]), + block=tuple([16, 1, 1]), + args=[n, second.data_ptr(), rbot1.data_ptr()], + stream=Stream, + ) + + for intSample in range(second.size(0)): + n = second.size(1) * second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_updateGradFirst", + cupy_kernel( + "kernel_Correlation_updateGradFirst", + { + "rbot0": rbot0, + "rbot1": rbot1, + "gradOutput": input, + "gradFirst": output, + "gradSecond": None, + }, + ), + )( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + n, + intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + input.data_ptr(), + output.data_ptr(), + None, + ], + stream=Stream, + ) + + elif second.is_cuda == False: + raise NotImplementedError() + + return output + + @staticmethod + def backward(self, gradOutput): + input, second, rbot0, rbot1 = self.saved_tensors + + gradOutput = gradOutput.contiguous() + + gradInput = ( + input.new_zeros( + [input.size(0), input.size(1), input.size(2), input.size(3)] + ) + if self.needs_input_grad[0] == True + else None + ) + gradSecond = ( + second.new_zeros( + [second.size(0), second.size(1), second.size(2), second.size(3)] + ) + if self.needs_input_grad[1] == True + else None + ) + + if second.is_cuda == True: + if gradInput is not None or gradSecond is not None: + n = second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_rearrange", + cupy_kernel( + "kernel_Correlation_rearrange", + {"input": gradOutput, "output": rbot0}, + ), + )( + grid=tuple( + [int((n + 16 - 1) / 16), gradOutput.size(1), gradOutput.size(0)] + ), + block=tuple([16, 1, 1]), + args=[n, gradOutput.data_ptr(), rbot0.data_ptr()], + stream=Stream, + ) + + if gradInput is not None: + n = gradInput.size(1) * gradInput.size(2) * gradInput.size(3) + cupy_launch( + "kernel_Correlation_updateOutput", + cupy_kernel( + "kernel_Correlation_updateOutput", + {"rbot0": rbot0, "rbot1": rbot1, "top": gradInput}, + ), + )( + grid=tuple( + [gradInput.size(3), gradInput.size(2), gradInput.size(0)] + ), + block=tuple([32, 1, 1]), + shared_mem=gradOutput.size(1) * 4, + args=[n, rbot0.data_ptr(), rbot1.data_ptr(), gradInput.data_ptr()], + stream=Stream, + ) + + if gradSecond is not None: + for intSample in range(second.size(0)): + n = second.size(1) * second.size(2) * second.size(3) + cupy_launch( + "kernel_Correlation_updateGradSecond", + cupy_kernel( + "kernel_Correlation_updateGradSecond", + { + "rbot0": rbot0, + "rbot1": rbot1, + "gradOutput": input, + "gradFirst": None, + "gradSecond": gradSecond, + }, + ), + )( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + n, + intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + input.data_ptr(), + None, + gradSecond.data_ptr(), + ], + stream=Stream, + ) + + elif second.is_cuda == False: + raise NotImplementedError() + + return gradInput, gradSecond + + +def FunctionCorrelation(reference_features, query_features): + return _FunctionCorrelation.apply(reference_features, query_features) + + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + + def forward(self, tensorFirst, tensorSecond): + return _FunctionCorrelation.apply(tensorFirst, tensorSecond) + + +def FunctionCorrelationTranspose(reference_features, query_features): + return 
_FunctionCorrelationTranspose.apply(reference_features, query_features) + + +class ModuleCorrelationTranspose(torch.nn.Module): + def __init__(self): + super(ModuleCorrelationTranspose, self).__init__() + + def forward(self, tensorFirst, tensorSecond): + return _FunctionCorrelationTranspose.apply(tensorFirst, tensorSecond) + + +class LocalCorr(ConvRefiner): + def forward(self, x, y, flow): + """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them + + Args: + x ([type]): [description] + y ([type]): [description] + flow ([type]): [description] + + Returns: + [type]: [description] + """ + with torch.no_grad(): + x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False) + corr = FunctionCorrelation(x, x_hat) + d = self.block1(corr) + d = self.hidden_blocks(d) + d = self.out_conv(d) + certainty, displacement = d[:, :-2], d[:, -2:] + return certainty, displacement + + +if __name__ == "__main__": + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + x = torch.randn(2, 128, 32, 32).to(device) + y = torch.randn(2, 128, 32, 32).to(device) + local_corr = LocalCorr(in_dim=81, hidden_dim=81 * 4) + z = local_corr(x, y) + print("hej") diff --git a/imcui/third_party/DKM/dkm/models/dkm.py b/imcui/third_party/DKM/dkm/models/dkm.py new file mode 100644 index 0000000000000000000000000000000000000000..27c3f6d59ad3a8e976e3d719868908ddf443883e --- /dev/null +++ b/imcui/third_party/DKM/dkm/models/dkm.py @@ -0,0 +1,759 @@ +import math +import os +import numpy as np +from PIL import Image +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..utils import get_tuple_transform_ops +from einops import rearrange +from ..utils.local_correlation import local_correlation + + +class ConvRefiner(nn.Module): + def __init__( + self, + in_dim=6, + hidden_dim=16, + out_dim=2, + dw=False, + kernel_size=5, + hidden_blocks=3, + displacement_emb = None, + displacement_emb_dim = None, + local_corr_radius = None, + corr_in_other = None, + no_support_fm = False, + ): + super().__init__() + self.block1 = self.create_block( + in_dim, hidden_dim, dw=dw, kernel_size=kernel_size + ) + self.hidden_blocks = nn.Sequential( + *[ + self.create_block( + hidden_dim, + hidden_dim, + dw=dw, + kernel_size=kernel_size, + ) + for hb in range(hidden_blocks) + ] + ) + self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0) + if displacement_emb: + self.has_displacement_emb = True + self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0) + else: + self.has_displacement_emb = False + self.local_corr_radius = local_corr_radius + self.corr_in_other = corr_in_other + self.no_support_fm = no_support_fm + def create_block( + self, + in_dim, + out_dim, + dw=False, + kernel_size=5, + ): + num_groups = 1 if not dw else in_dim + if dw: + assert ( + out_dim % in_dim == 0 + ), "outdim must be divisible by indim for depthwise" + conv1 = nn.Conv2d( + in_dim, + out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=num_groups, + ) + norm = nn.BatchNorm2d(out_dim) + relu = nn.ReLU(inplace=True) + conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0) + return nn.Sequential(conv1, norm, relu, conv2) + + def forward(self, x, y, flow): + """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them + + Args: + x ([type]): [description] + y ([type]): [description] + flow ([type]): [description] + + Returns: + [type]: [description] + """ + device = x.device + b,c,hs,ws = x.shape + with 
torch.no_grad(): + x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False) + if self.has_displacement_emb: + query_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), + ) + ) + query_coords = torch.stack((query_coords[1], query_coords[0])) + query_coords = query_coords[None].expand(b, 2, hs, ws) + in_displacement = flow-query_coords + emb_in_displacement = self.disp_emb(in_displacement) + if self.local_corr_radius: + #TODO: should corr have gradient? + if self.corr_in_other: + # Corr in other means take a kxk grid around the predicted coordinate in other image + local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow) + else: + # Otherwise we use the warp to sample in the first image + # This is actually different operations, especially for large viewpoint changes + local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,) + if self.no_support_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1) + else: + d = torch.cat((x, x_hat, emb_in_displacement), dim=1) + else: + if self.no_support_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat), dim=1) + d = self.block1(d) + d = self.hidden_blocks(d) + d = self.out_conv(d) + certainty, displacement = d[:, :-2], d[:, -2:] + return certainty, displacement + + +class CosKernel(nn.Module): # similar to softmax kernel + def __init__(self, T, learn_temperature=False): + super().__init__() + self.learn_temperature = learn_temperature + if self.learn_temperature: + self.T = nn.Parameter(torch.tensor(T)) + else: + self.T = T + + def __call__(self, x, y, eps=1e-6): + c = torch.einsum("bnd,bmd->bnm", x, y) / ( + x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps + ) + if self.learn_temperature: + T = self.T.abs() + 0.01 + else: + T = torch.tensor(self.T, device=c.device) + K = ((c - 1.0) / T).exp() + return K + + +class CAB(nn.Module): + def __init__(self, in_channels, out_channels): + super(CAB, self).__init__() + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.sigmod = nn.Sigmoid() + + def forward(self, x): + x1, x2 = x # high, low (old, new) + x = torch.cat([x1, x2], dim=1) + x = self.global_pooling(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.sigmod(x) + x2 = x * x2 + res = x2 + x1 + return res + + +class RRB(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(RRB, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + ) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(out_channels) + self.conv3 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + ) + + def forward(self, x): + x = self.conv1(x) + res = self.conv2(x) + res = self.bn(res) + res = self.relu(res) + res = self.conv3(res) + return self.relu(x + res) + + +class DFN(nn.Module): + def __init__( + self, + internal_dim, + feat_input_modules, + pred_input_modules, + rrb_d_dict, + cab_dict, + rrb_u_dict, + use_global_context=False, + global_dim=None, + 
terminal_module=None, + upsample_mode="bilinear", + align_corners=False, + ): + super().__init__() + if use_global_context: + assert ( + global_dim is not None + ), "Global dim must be provided when using global context" + self.align_corners = align_corners + self.internal_dim = internal_dim + self.feat_input_modules = feat_input_modules + self.pred_input_modules = pred_input_modules + self.rrb_d = rrb_d_dict + self.cab = cab_dict + self.rrb_u = rrb_u_dict + self.use_global_context = use_global_context + if use_global_context: + self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.terminal_module = ( + terminal_module if terminal_module is not None else nn.Identity() + ) + self.upsample_mode = upsample_mode + self._scales = [int(key) for key in self.terminal_module.keys()] + + def scales(self): + return self._scales.copy() + + def forward(self, embeddings, feats, context, key): + feats = self.feat_input_modules[str(key)](feats) + embeddings = torch.cat([feats, embeddings], dim=1) + embeddings = self.rrb_d[str(key)](embeddings) + context = self.cab[str(key)]([context, embeddings]) + context = self.rrb_u[str(key)](context) + preds = self.terminal_module[str(key)](context) + pred_coord = preds[:, -2:] + pred_certainty = preds[:, :-2] + return pred_coord, pred_certainty, context + + +class GP(nn.Module): + def __init__( + self, + kernel, + T=1, + learn_temperature=False, + only_attention=False, + gp_dim=64, + basis="fourier", + covar_size=5, + only_nearest_neighbour=False, + sigma_noise=0.1, + no_cov=False, + predict_features = False, + ): + super().__init__() + self.K = kernel(T=T, learn_temperature=learn_temperature) + self.sigma_noise = sigma_noise + self.covar_size = covar_size + self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1) + self.only_attention = only_attention + self.only_nearest_neighbour = only_nearest_neighbour + self.basis = basis + self.no_cov = no_cov + self.dim = gp_dim + self.predict_features = predict_features + + def get_local_cov(self, cov): + K = self.covar_size + b, h, w, h, w = cov.shape + hw = h * w + cov = F.pad(cov, 4 * (K // 2,)) # pad v_q + delta = torch.stack( + torch.meshgrid( + torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1) + ), + dim=-1, + ) + positions = torch.stack( + torch.meshgrid( + torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2) + ), + dim=-1, + ) + neighbours = positions[:, :, None, None, :] + delta[None, :, :] + points = torch.arange(hw)[:, None].expand(hw, K**2) + local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[ + :, + points.flatten(), + neighbours[..., 0].flatten(), + neighbours[..., 1].flatten(), + ].reshape(b, h, w, K**2) + return local_cov + + def reshape(self, x): + return rearrange(x, "b d h w -> b (h w) d") + + def project_to_basis(self, x): + if self.basis == "fourier": + return torch.cos(8 * math.pi * self.pos_conv(x)) + elif self.basis == "linear": + return self.pos_conv(x) + else: + raise ValueError( + "No other bases other than fourier and linear currently supported in public release" + ) + + def get_pos_enc(self, y): + b, c, h, w = y.shape + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device), + ) + ) + + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + coarse_embedded_coords = 
self.project_to_basis(coarse_coords) + return coarse_embedded_coords + + def forward(self, x, y, **kwargs): + b, c, h1, w1 = x.shape + b, c, h2, w2 = y.shape + f = self.get_pos_enc(y) + if self.predict_features: + f = f + y[:,:self.dim] # Stupid way to predict features + b, d, h2, w2 = f.shape + #assert x.shape == y.shape + x, y, f = self.reshape(x), self.reshape(y), self.reshape(f) + K_xx = self.K(x, x) + K_yy = self.K(y, y) + K_xy = self.K(x, y) + K_yx = K_xy.permute(0, 2, 1) + sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :] + # Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large + if len(K_yy[0]) > 2000: + K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)]) + else: + K_yy_inv = torch.linalg.inv(K_yy + sigma_noise) + + mu_x = K_xy.matmul(K_yy_inv.matmul(f)) + mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1) + if not self.no_cov: + cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx)) + cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1) + local_cov_x = self.get_local_cov(cov_x) + local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w") + gp_feats = torch.cat((mu_x, local_cov_x), dim=1) + else: + gp_feats = mu_x + return gp_feats + + +class Encoder(nn.Module): + def __init__(self, resnet): + super().__init__() + self.resnet = resnet + def forward(self, x): + x0 = x + b, c, h, w = x.shape + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x1 = self.resnet.relu(x) + + x = self.resnet.maxpool(x1) + x2 = self.resnet.layer1(x) + + x3 = self.resnet.layer2(x2) + + x4 = self.resnet.layer3(x3) + + x5 = self.resnet.layer4(x4) + feats = {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0} + return feats + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + + +class Decoder(nn.Module): + def __init__( + self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None, + ): + super().__init__() + self.embedding_decoder = embedding_decoder + self.gps = gps + self.proj = proj + self.conv_refiner = conv_refiner + self.detach = detach + if scales == "all": + self.scales = ["32", "16", "8", "4", "2", "1"] + else: + self.scales = scales + + def upsample_preds(self, flow, certainty, query, support): + b, hs, ws, d = flow.shape + b, c, h, w = query.shape + flow = flow.permute(0, 3, 1, 2) + certainty = F.interpolate( + certainty, size=(h, w), align_corners=False, mode="bilinear" + ) + flow = F.interpolate( + flow, size=(h, w), align_corners=False, mode="bilinear" + ) + delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow) + flow = torch.stack( + ( + flow[:, 0] + delta_flow[:, 0] / (4 * w), + flow[:, 1] + delta_flow[:, 1] / (4 * h), + ), + dim=1, + ) + flow = flow.permute(0, 2, 3, 1) + certainty = certainty + delta_certainty + return flow, certainty + + def get_placeholder_flow(self, b, h, w, device): + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + ) + ) + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + return coarse_coords + + + def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None): + coarse_scales = self.embedding_decoder.scales() + all_scales 
= self.scales if not upsample else ["8", "4", "2", "1"] + sizes = {scale: f1[scale].shape[-2:] for scale in f1} + h, w = sizes[1] + b = f1[1].shape[0] + device = f1[1].device + coarsest_scale = int(all_scales[0]) + old_stuff = torch.zeros( + b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device + ) + dense_corresps = {} + if not upsample: + dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device) + dense_certainty = 0.0 + else: + dense_flow = F.interpolate( + dense_flow, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + dense_certainty = F.interpolate( + dense_certainty, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + for new_scale in all_scales: + ins = int(new_scale) + f1_s, f2_s = f1[ins], f2[ins] + if new_scale in self.proj: + f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s) + b, c, hs, ws = f1_s.shape + if ins in coarse_scales: + old_stuff = F.interpolate( + old_stuff, size=sizes[ins], mode="bilinear", align_corners=False + ) + new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow) + dense_flow, dense_certainty, old_stuff = self.embedding_decoder( + new_stuff, f1_s, old_stuff, new_scale + ) + + if new_scale in self.conv_refiner: + delta_certainty, displacement = self.conv_refiner[new_scale]( + f1_s, f2_s, dense_flow + ) + dense_flow = torch.stack( + ( + dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w), + dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h), + ), + dim=1, + ) + dense_certainty = ( + dense_certainty + delta_certainty + ) # predict both certainty and displacement + + dense_corresps[ins] = { + "dense_flow": dense_flow, + "dense_certainty": dense_certainty, + } + + if new_scale != "1": + dense_flow = F.interpolate( + dense_flow, + size=sizes[ins // 2], + align_corners=False, + mode="bilinear", + ) + + dense_certainty = F.interpolate( + dense_certainty, + size=sizes[ins // 2], + align_corners=False, + mode="bilinear", + ) + if self.detach: + dense_flow = dense_flow.detach() + dense_certainty = dense_certainty.detach() + return dense_corresps + + +class RegressionMatcher(nn.Module): + def __init__( + self, + encoder, + decoder, + h=384, + w=512, + use_contrastive_loss = False, + alpha = 1, + beta = 0, + sample_mode = "threshold", + upsample_preds = False, + symmetric = False, + name = None, + use_soft_mutual_nearest_neighbours = False, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.w_resized = w + self.h_resized = h + self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True) + self.use_contrastive_loss = use_contrastive_loss + self.alpha = alpha + self.beta = beta + self.sample_mode = sample_mode + self.upsample_preds = upsample_preds + self.symmetric = symmetric + self.name = name + self.sample_thresh = 0.05 + self.upsample_res = (864,1152) + if use_soft_mutual_nearest_neighbours: + assert symmetric, "MNS requires symmetric inference" + self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours + + def extract_backbone_features(self, batch, batched = True, upsample = True): + #TODO: only extract stride [1,2,4,8] for upsample = True + x_q = batch["query"] + x_s = batch["support"] + if batched: + X = torch.cat((x_q, x_s)) + feature_pyramid = self.encoder(X) + else: + feature_pyramid = self.encoder(x_q), self.encoder(x_s) + return feature_pyramid + + def sample( + self, + dense_matches, + dense_certainty, + num=10000, + ): + if "threshold" in 
self.sample_mode: + upper_thresh = self.sample_thresh + dense_certainty = dense_certainty.clone() + dense_certainty[dense_certainty > upper_thresh] = 1 + elif "pow" in self.sample_mode: + dense_certainty = dense_certainty**(1/3) + elif "naive" in self.sample_mode: + dense_certainty = torch.ones_like(dense_certainty) + matches, certainty = ( + dense_matches.reshape(-1, 4), + dense_certainty.reshape(-1), + ) + expansion_factor = 4 if "balanced" in self.sample_mode else 1 + good_samples = torch.multinomial(certainty, + num_samples = min(expansion_factor*num, len(certainty)), + replacement=False) + good_matches, good_certainty = matches[good_samples], certainty[good_samples] + if "balanced" not in self.sample_mode: + return good_matches, good_certainty + + from ..utils.kde import kde + density = kde(good_matches, std=0.1) + p = 1 / (density+1) + p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones + balanced_samples = torch.multinomial(p, + num_samples = min(num,len(good_certainty)), + replacement=False) + return good_matches[balanced_samples], good_certainty[balanced_samples] + + def forward(self, batch, batched = True): + feature_pyramid = self.extract_backbone_features(batch, batched=batched) + if batched: + f_q_pyramid = { + scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items() + } + f_s_pyramid = { + scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items() + } + else: + f_q_pyramid, f_s_pyramid = feature_pyramid + dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid) + if self.training and self.use_contrastive_loss: + return dense_corresps, (f_q_pyramid, f_s_pyramid) + else: + return dense_corresps + + def forward_symmetric(self, batch, upsample = False, batched = True): + feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched) + f_q_pyramid = feature_pyramid + f_s_pyramid = { + scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0])) + for scale, f_scale in feature_pyramid.items() + } + dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {})) + return dense_corresps + + def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B): + kpts_A, kpts_B = matches[...,:2], matches[...,2:] + kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1) + kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1) + return kpts_A, kpts_B + + def match( + self, + im1_path, + im2_path, + *args, + batched=False, + device = None + ): + assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). 
You can turn off upsample_preds by model.upsample_preds = False " + if isinstance(im1_path, (str, os.PathLike)): + im1, im2 = Image.open(im1_path), Image.open(im2_path) + else: # assume it is a PIL Image + im1, im2 = im1_path, im2_path + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + symmetric = self.symmetric + self.train(False) + with torch.no_grad(): + if not batched: + b = 1 + w, h = im1.size + w2, h2 = im2.size + # Get images in good format + ws = self.w_resized + hs = self.h_resized + + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True + ) + query, support = test_transform((im1, im2)) + batch = {"query": query[None].to(device), "support": support[None].to(device)} + else: + b, c, h, w = im1.shape + b, c, h2, w2 = im2.shape + assert w == w2 and h == h2, "For batched images we assume same size" + batch = {"query": im1.to(device), "support": im2.to(device)} + hs, ws = self.h_resized, self.w_resized + finest_scale = 1 + # Run matcher + if symmetric: + dense_corresps = self.forward_symmetric(batch, batched = True) + else: + dense_corresps = self.forward(batch, batched = True) + + if self.upsample_preds: + hs, ws = self.upsample_res + low_res_certainty = F.interpolate( + dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + ) + cert_clamp = 0 + factor = 0.5 + low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp) + + if self.upsample_preds: + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True + ) + query, support = test_transform((im1, im2)) + query, support = query[None].to(device), support[None].to(device) + batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]} + if symmetric: + dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True) + else: + dense_corresps = self.forward(batch, batched = True, upsample=True) + query_to_support = dense_corresps[finest_scale]["dense_flow"] + dense_certainty = dense_corresps[finest_scale]["dense_certainty"] + + # Get certainty interpolation + dense_certainty = dense_certainty - low_res_certainty + query_to_support = query_to_support.permute( + 0, 2, 3, 1 + ) + # Create im1 meshgrid + query_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), + ) + ) + query_coords = torch.stack((query_coords[1], query_coords[0])) + query_coords = query_coords[None].expand(b, 2, hs, ws) + dense_certainty = dense_certainty.sigmoid() # logits -> probs + query_coords = query_coords.permute(0, 2, 3, 1) + if (query_to_support.abs() > 1).any() and True: + wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0 + dense_certainty[wrong[:,None]] = 0 + + query_to_support = torch.clamp(query_to_support, -1, 1) + if symmetric: + support_coords = query_coords + qts, stq = query_to_support.chunk(2) + q_warp = torch.cat((query_coords, qts), dim=-1) + s_warp = torch.cat((stq, support_coords), dim=-1) + warp = torch.cat((q_warp, s_warp),dim=2) + dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0] + else: + warp = torch.cat((query_coords, query_to_support), dim=-1) + if batched: + return ( + warp, + dense_certainty + ) + else: + return ( + warp[0], + dense_certainty[0], + ) diff --git a/imcui/third_party/DKM/dkm/models/encoders.py b/imcui/third_party/DKM/dkm/models/encoders.py new file mode 100644 index 
0000000000000000000000000000000000000000..29077e1797196611e9b59a753130a5b153e0aa05 --- /dev/null +++ b/imcui/third_party/DKM/dkm/models/encoders.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as tvm + +class ResNet18(nn.Module): + def __init__(self, pretrained=False) -> None: + super().__init__() + self.net = tvm.resnet18(pretrained=pretrained) + def forward(self, x): + self = self.net + x1 = x + x = self.conv1(x1) + x = self.bn1(x) + x2 = self.relu(x) + x = self.maxpool(x2) + x4 = self.layer1(x) + x8 = self.layer2(x4) + x16 = self.layer3(x8) + x32 = self.layer4(x16) + return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1} + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + +class ResNet50(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None: + super().__init__() + if dilation is None: + dilation = [False,False,False] + if anti_aliased: + pass + else: + if weights is not None: + self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation) + else: + self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation) + + self.high_res = high_res + self.freeze_bn = freeze_bn + def forward(self, x): + net = self.net + feats = {1:x} + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x + x = net.layer2(x) + feats[8] = x + x = net.layer3(x) + feats[16] = x + x = net.layer4(x) + feats[32] = x + return feats + + def train(self, mode=True): + super().train(mode) + if self.freeze_bn: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + + + + +class ResNet101(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None) -> None: + super().__init__() + if weights is not None: + self.net = tvm.resnet101(weights = weights) + else: + self.net = tvm.resnet101(pretrained=pretrained) + self.high_res = high_res + self.scale_factor = 1 if not high_res else 1.5 + def forward(self, x): + net = self.net + feats = {1:x} + sf = self.scale_factor + if self.high_res: + x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer2(x) + feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer3(x) + feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer4(x) + feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + return feats + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + + +class WideResNet50(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None) -> None: + super().__init__() + if weights is not None: + self.net = tvm.wide_resnet50_2(weights = weights) + else: + self.net = tvm.wide_resnet50_2(pretrained=pretrained) + self.high_res = high_res + self.scale_factor = 
1 if not high_res else 1.5 + def forward(self, x): + net = self.net + feats = {1:x} + sf = self.scale_factor + if self.high_res: + x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer2(x) + feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer3(x) + feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer4(x) + feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + return feats + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass \ No newline at end of file diff --git a/imcui/third_party/DKM/dkm/models/model_zoo/DKMv3.py b/imcui/third_party/DKM/dkm/models/model_zoo/DKMv3.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4c9ede3863d778f679a033d8d2287b8776e894 --- /dev/null +++ b/imcui/third_party/DKM/dkm/models/model_zoo/DKMv3.py @@ -0,0 +1,150 @@ +import torch + +from torch import nn +from ..dkm import * +from ..encoders import * + + +def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", device = None, **kwargs): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512+128+(2*7+1)**2, + 2 * 512+128+(2*7+1)**2, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + local_corr_radius = 7, + corr_in_other = True, + ), + "8": ConvRefiner( + 2 * 512+64+(2*3+1)**2, + 2 * 512+64+(2*3+1)**2, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + local_corr_radius = 3, + corr_in_other = True, + ), + "4": ConvRefiner( + 2 * 256+32+(2*2+1)**2, + 2 * 256+32+(2*2+1)**2, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + local_corr_radius = 2, + corr_in_other = True, + ), + "2": ConvRefiner( + 2 * 64+16, + 128+16, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + 
displacement_emb_dim=16, + ), + "1": ConvRefiner( + 2 * 3+6, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=6, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + + encoder = ResNet50(pretrained = False, high_res = False, freeze_bn=False) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w, name = "DKMv3", sample_mode=sample_mode, symmetric = symmetric, **kwargs).to(device) + res = matcher.load_state_dict(weights) + return matcher diff --git a/imcui/third_party/DKM/dkm/models/model_zoo/__init__.py b/imcui/third_party/DKM/dkm/models/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c85da2920c1acfac140ada2d87623203607d42ca --- /dev/null +++ b/imcui/third_party/DKM/dkm/models/model_zoo/__init__.py @@ -0,0 +1,39 @@ +weight_urls = { + "DKMv3": { + "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth", + "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth", + }, +} +import torch +from .DKMv3 import DKMv3 + + +def DKMv3_outdoor(path_to_weights = None, device=None): + """ + Loads DKMv3 outdoor weights, uses internal resolution of (540, 720) by default + resolution can be changed by setting model.h_resized, model.w_resized later. + Additionally upsamples preds to fixed resolution of (864, 1152), + can be turned off by model.upsample_preds = False + """ + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if path_to_weights is not None: + weights = torch.load(path_to_weights, map_location='cpu') + else: + weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["outdoor"], + map_location='cpu') + return DKMv3(weights, 540, 720, upsample_preds = True, device=device) + +def DKMv3_indoor(path_to_weights = None, device=None): + """ + Loads DKMv3 indoor weights, uses internal resolution of (480, 640) by default + Resolution can be changed by setting model.h_resized, model.w_resized later. 
+ """ + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if path_to_weights is not None: + weights = torch.load(path_to_weights, map_location=device) + else: + weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["indoor"], + map_location=device) + return DKMv3(weights, 480, 640, upsample_preds = False, device=device) diff --git a/imcui/third_party/DKM/dkm/train/__init__.py b/imcui/third_party/DKM/dkm/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90269dc0f345a575e0ba21f5afa34202c7e6b433 --- /dev/null +++ b/imcui/third_party/DKM/dkm/train/__init__.py @@ -0,0 +1 @@ +from .train import train_k_epochs diff --git a/imcui/third_party/DKM/dkm/train/train.py b/imcui/third_party/DKM/dkm/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b580221f56a2667784836f0237955cc75131b88c --- /dev/null +++ b/imcui/third_party/DKM/dkm/train/train.py @@ -0,0 +1,67 @@ +from tqdm import tqdm +from dkm.utils.utils import to_cuda + + +def train_step(train_batch, model, objective, optimizer, **kwargs): + optimizer.zero_grad() + out = model(train_batch) + l = objective(out, train_batch) + l.backward() + optimizer.step() + return {"train_out": out, "train_loss": l.item()} + + +def train_k_steps( + n_0, k, dataloader, model, objective, optimizer, lr_scheduler, progress_bar=True +): + for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar): + batch = next(dataloader) + model.train(True) + batch = to_cuda(batch) + train_step( + train_batch=batch, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + n=n, + ) + lr_scheduler.step() + + +def train_epoch( + dataloader=None, + model=None, + objective=None, + optimizer=None, + lr_scheduler=None, + epoch=None, +): + model.train(True) + print(f"At epoch {epoch}") + for batch in tqdm(dataloader, mininterval=5.0): + batch = to_cuda(batch) + train_step( + train_batch=batch, model=model, objective=objective, optimizer=optimizer + ) + lr_scheduler.step() + return { + "model": model, + "optimizer": optimizer, + "lr_scheduler": lr_scheduler, + "epoch": epoch, + } + + +def train_k_epochs( + start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler +): + for epoch in range(start_epoch, end_epoch + 1): + train_epoch( + dataloader=dataloader, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + ) diff --git a/imcui/third_party/DKM/dkm/utils/__init__.py b/imcui/third_party/DKM/dkm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05367ac9521664992f587738caa231f32ae2e81c --- /dev/null +++ b/imcui/third_party/DKM/dkm/utils/__init__.py @@ -0,0 +1,13 @@ +from .utils import ( + pose_auc, + get_pose, + compute_relative_pose, + compute_pose_error, + estimate_pose, + rotate_intrinsic, + get_tuple_transform_ops, + get_depth_tuple_transform_ops, + warp_kpts, + numpy_to_pil, + tensor_to_pil, +) diff --git a/imcui/third_party/DKM/dkm/utils/kde.py b/imcui/third_party/DKM/dkm/utils/kde.py new file mode 100644 index 0000000000000000000000000000000000000000..fa392455e70fda4c9c77c28bda76bcb7ef9045b0 --- /dev/null +++ b/imcui/third_party/DKM/dkm/utils/kde.py @@ -0,0 +1,26 @@ +import torch +import torch.nn.functional as F +import numpy as np + +def fast_kde(x, std = 0.1, kernel_size = 9, dilation = 3, padding = 9//2, stride = 1): + raise NotImplementedError("WIP, use at your own risk.") + # Note: when doing symmetric matching this 
might not be very exact, since we only check neighbours on the grid + x = x.permute(0,3,1,2) + B,C,H,W = x.shape + K = kernel_size ** 2 + unfolded_x = F.unfold(x,kernel_size=kernel_size, dilation = dilation, padding = padding, stride = stride).reshape(B, C, K, H, W) + scores = (-(unfolded_x - x[:,:,None]).sum(dim=1)**2/(2*std**2)).exp() + density = scores.sum(dim=1) + return density + + +def kde(x, std = 0.1, device=None): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + # use a gaussian kernel to estimate density + x = x.to(device) + scores = (-torch.cdist(x,x)**2/(2*std**2)).exp() + density = scores.sum(dim=-1) + return density diff --git a/imcui/third_party/DKM/dkm/utils/local_correlation.py b/imcui/third_party/DKM/dkm/utils/local_correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c1c06291d0b760376a2b2162bcf49d6eb1303c --- /dev/null +++ b/imcui/third_party/DKM/dkm/utils/local_correlation.py @@ -0,0 +1,40 @@ +import torch +import torch.nn.functional as F + + +def local_correlation( + feature0, + feature1, + local_radius, + padding_mode="zeros", + flow = None +): + device = feature0.device + b, c, h, w = feature0.size() + if flow is None: + # If flow is None, assume feature0 and feature1 are aligned + coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + )) + coords = torch.stack((coords[1], coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + else: + coords = flow.permute(0,2,3,1) # If using flow, sample around flow target. + r = local_radius + local_window = torch.meshgrid( + ( + torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=device), + torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=device), + )) + local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[ + None + ].expand(b, 2*r+1, 2*r+1, 2).reshape(b, (2*r+1)**2, 2) + coords = (coords[:,:,:,None]+local_window[:,None,None]).reshape(b,h,w*(2*r+1)**2,2) + window_feature = F.grid_sample( + feature1, coords, padding_mode=padding_mode, align_corners=False + )[...,None].reshape(b,c,h,w,(2*r+1)**2) + corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature)/(c**.5) + return corr diff --git a/imcui/third_party/DKM/dkm/utils/transforms.py b/imcui/third_party/DKM/dkm/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..754d853fda4cbcf89d2111bed4f44b0ca84f0518 --- /dev/null +++ b/imcui/third_party/DKM/dkm/utils/transforms.py @@ -0,0 +1,104 @@ +from typing import Dict +import numpy as np +import torch +import kornia.augmentation as K +from kornia.geometry.transform import warp_perspective + +# Adapted from Kornia +class GeometricSequential: + def __init__(self, *transforms, align_corners=True) -> None: + self.transforms = transforms + self.align_corners = align_corners + + def __call__(self, x, mode="bilinear"): + b, c, h, w = x.shape + M = torch.eye(3, device=x.device)[None].expand(b, 3, 3) + for t in self.transforms: + if np.random.rand() < t.p: + M = M.matmul( + t.compute_transformation(x, t.generate_parameters((b, c, h, w))) + ) + return ( + warp_perspective( + x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners + ), + M, + ) + + def apply_transform(self, x, M, mode="bilinear"): + b, c, h, w = x.shape + return warp_perspective( + x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode + ) + + 
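+# Illustrative usage sketch (the RandomPerspective subclass is defined just below;
+# exact behaviour may vary with the installed kornia version, and the tensor sizes
+# here are arbitrary):
+#
+#   aug = GeometricSequential(RandomPerspective(0.4, p=1.0))
+#   warped, M = aug(torch.rand(1, 3, 128, 128))   # warped batch and the 3x3 homographies used
+#   replay = aug.apply_transform(torch.rand(1, 3, 128, 128), M)  # re-apply the same warp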
+class RandomPerspective(K.RandomPerspective): + def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]: + distortion_scale = torch.as_tensor( + self.distortion_scale, device=self._device, dtype=self._dtype + ) + return self.random_perspective_generator( + batch_shape[0], + batch_shape[-2], + batch_shape[-1], + distortion_scale, + self.same_on_batch, + self.device, + self.dtype, + ) + + def random_perspective_generator( + self, + batch_size: int, + height: int, + width: int, + distortion_scale: torch.Tensor, + same_on_batch: bool = False, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ) -> Dict[str, torch.Tensor]: + r"""Get parameters for ``perspective`` for a random perspective transform. + + Args: + batch_size (int): the tensor batch size. + height (int) : height of the image. + width (int): width of the image. + distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1. + same_on_batch (bool): apply the same transformation across the batch. Default: False. + device (torch.device): the device on which the random numbers will be generated. Default: cpu. + dtype (torch.dtype): the data type of the generated random numbers. Default: float32. + + Returns: + params Dict[str, torch.Tensor]: parameters to be passed for transformation. + - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2). + - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2). + + Note: + The generated random numbers are not reproducible across different devices and dtypes. + """ + if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1): + raise AssertionError( + f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}." + ) + if not ( + type(height) is int and height > 0 and type(width) is int and width > 0 + ): + raise AssertionError( + f"'height' and 'width' must be integers. Got {height}, {width}." 
+ ) + + start_points: torch.Tensor = torch.tensor( + [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]], + device=distortion_scale.device, + dtype=distortion_scale.dtype, + ).expand(batch_size, -1, -1) + + # generate random offset not larger than half of the image + fx = distortion_scale * width / 2 + fy = distortion_scale * height / 2 + + factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2) + offset = (torch.rand_like(start_points) - 0.5) * 2 + end_points = start_points + factor * offset + + return dict(start_points=start_points, end_points=end_points) diff --git a/imcui/third_party/DKM/dkm/utils/utils.py b/imcui/third_party/DKM/dkm/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..46bbe60260930aed184c6fa5907c837c0177b304 --- /dev/null +++ b/imcui/third_party/DKM/dkm/utils/utils.py @@ -0,0 +1,341 @@ +import numpy as np +import cv2 +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import torch.nn.functional as F +from PIL import Image + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py +# --- GEOMETRY --- +def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): + if len(kpts0) < 5: + return None + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC + ) + + ret = None + if E is not None: + best_num_inliers = 0 + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + + +def rotate_intrinsic(K, n): + base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + rot = np.linalg.matrix_power(base_rot, n) + return rot @ K + + +def rotate_pose_inplane(i_T_w, rot): + rotation_matrices = [ + np.array( + [ + [np.cos(r), -np.sin(r), 0.0, 0.0], + [np.sin(r), np.cos(r), 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] + ] + return np.dot(rotation_matrices[rot], i_T_w) + + +def scale_intrinsics(K, scales): + scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0]) + return np.dot(scales, K) + + +def to_homogeneous(points): + return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) + + +def angle_error_mat(R1, R2): + cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 + cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds + return np.rad2deg(np.abs(np.arccos(cos))) + + +def angle_error_vec(v1, v2): + n = np.linalg.norm(v1) * np.linalg.norm(v2) + return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) + + +def compute_pose_error(T_0to1, R, t): + R_gt = T_0to1[:3, :3] + t_gt = T_0to1[:3, 3] + error_t = angle_error_vec(t.squeeze(), t_gt) + error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation + error_R = angle_error_mat(R, R_gt) + return error_t, error_R + + +def pose_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 1) / len(errors) + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] + aucs = [] + 
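    # For each threshold t: integrate the recall-vs-error curve up to t (trapezoidal rule) and normalise by t, i.e. the standard pose-AUC metric.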
for t in thresholds: + last_index = np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.trapz(r, x=e) / t) + return aucs + + +# From Patch2Pix https://github.com/GrumpyZhou/patch2pix +def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR)) + return TupleCompose(ops) + + +def get_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize)) + if normalize: + ops.append(TupleToTensorScaled()) + ops.append( + TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) # Imagenet mean/std + else: + if unscale: + ops.append(TupleToTensorUnscaled()) + else: + ops.append(TupleToTensorScaled()) + return TupleCompose(ops) + + +class ToTensorScaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" + + def __call__(self, im): + if not isinstance(im, torch.Tensor): + im = np.array(im, dtype=np.float32).transpose((2, 0, 1)) + im /= 255.0 + return torch.from_numpy(im) + else: + return im + + def __repr__(self): + return "ToTensorScaled(./255)" + + +class TupleToTensorScaled(object): + def __init__(self): + self.to_tensor = ToTensorScaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorScaled(./255)" + + +class ToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __call__(self, im): + return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1))) + + def __repr__(self): + return "ToTensorUnscaled()" + + +class TupleToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __init__(self): + self.to_tensor = ToTensorUnscaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorUnscaled()" + + +class TupleResize(object): + def __init__(self, size, mode=InterpolationMode.BICUBIC): + self.size = size + self.resize = transforms.Resize(size, mode) + + def __call__(self, im_tuple): + return [self.resize(im) for im in im_tuple] + + def __repr__(self): + return "TupleResize(size={})".format(self.size) + + +class TupleNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + self.normalize = transforms.Normalize(mean=mean, std=std) + + def __call__(self, im_tuple): + return [self.normalize(im) for im in im_tuple] + + def __repr__(self): + return "TupleNormalize(mean={}, std={})".format(self.mean, self.std) + + +class TupleCompose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, im_tuple): + for t in self.transforms: + im_tuple = t(im_tuple) + return im_tuple + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). 
+ # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here + Args: + kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1) + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + ( + n, + h, + w, + ) = depth0.shape + kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode="bilinear")[ + :, 0, :, 0 + ] + kpts0 = torch.stack( + (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + # Sample depth, get calculable_mask on depth != 0 + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) + kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + kpts0_cam = kpts0_n + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) + w_kpts0 = torch.stack( + (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 + ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] + # w_kpts0[~covisible_mask, :] = -5 # xd + + w_kpts0_depth = F.grid_sample( + depth1[:, None], w_kpts0[:, :, None], mode="bilinear" + )[:, 0, :, 0] + consistent_mask = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() < 0.05 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 + + +imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) +imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + + +def numpy_to_pil(x: np.ndarray): + """ + Args: + x: Assumed to be of shape (h,w,c) + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if x.max() <= 1.01: + x *= 255 + x = x.astype(np.uint8) + return Image.fromarray(x) + + +def tensor_to_pil(x, unnormalize=False): + if unnormalize: + x = x * imagenet_std[:, None, None] + imagenet_mean[:, None, None] + x = x.detach().permute(1, 2, 0).cpu().numpy() + x = np.clip(x, 0.0, 1.0) + return numpy_to_pil(x) + + +def to_cuda(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(device) + return batch + + +def to_cpu(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cpu() + return batch + + +def get_pose(calib): + w, h = np.array(calib["imsize"])[0] + return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w + + +def compute_relative_pose(R1, t1, R2, t2): + rots = R2 @ (R1.T) + trans = -rots @ t1 + t2 + return rots, trans diff --git a/imcui/third_party/DKM/setup.py b/imcui/third_party/DKM/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..73ae664126066249200d72c1ec3166f4f2d76b10 --- /dev/null +++ b/imcui/third_party/DKM/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup, find_packages + +setup( + name="dkm", + 
packages=find_packages(include=("dkm*",)), + version="0.3.0", + author="Johan Edstedt", + install_requires=open("requirements.txt", "r").read().split("\n"), +) diff --git a/imcui/third_party/DarkFeat/configs/config.yaml b/imcui/third_party/DarkFeat/configs/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ffead73fc3eac520aa7aa4bf3811c5069a4c149 --- /dev/null +++ b/imcui/third_party/DarkFeat/configs/config.yaml @@ -0,0 +1,24 @@ +training: + optimizer: 'SGD' + lr: 0.01 + momentum: 0.9 + weight_decay: 0.0001 + lr_gamma: 0.1 + lr_step: 200000 +network: + input_type: 'raw-demosaic' + noise: true + noise_maxstep: 1 + model: 'Quad_L2Net' + loss_type: 'HARD_CONTRASTIVE' + photaug: true + resize: 480 + use_corr_n: 512 + det: + corr_weight: true + safe_radius: 12 + kpt_n: 512 + score_thld: -1 + edge_thld: 10 + nms_size: 3 + eof_size: 5 \ No newline at end of file diff --git a/imcui/third_party/DarkFeat/configs/config_stage1.yaml b/imcui/third_party/DarkFeat/configs/config_stage1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f94e1da377bf8f507d6fa6db394b1016227d0e25 --- /dev/null +++ b/imcui/third_party/DarkFeat/configs/config_stage1.yaml @@ -0,0 +1,24 @@ +training: + optimizer: 'SGD' + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0001 + lr_gamma: 0.1 + lr_step: 200000 +network: + input_type: 'raw-demosaic' + noise: true + noise_maxstep: 1 + model: 'Quad_L2Net' + loss_type: 'HARD_CONTRASTIVE' + photaug: true + resize: 480 + use_corr_n: 512 + det: + corr_weight: true + safe_radius: 12 + kpt_n: 512 + score_thld: -1 + edge_thld: 10 + nms_size: 3 + eof_size: 5 \ No newline at end of file diff --git a/imcui/third_party/DarkFeat/darkfeat.py b/imcui/third_party/DarkFeat/darkfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..d146e2b41f5399ff3fc2f52ec5daff1c56e491c0 --- /dev/null +++ b/imcui/third_party/DarkFeat/darkfeat.py @@ -0,0 +1,359 @@ +import torch +from torch import nn +from torch.nn.parameter import Parameter +import torchvision.transforms as tvf +import torch.nn.functional as F +import numpy as np + + +def gather_nd(params, indices): + orig_shape = list(indices.shape) + num_samples = np.prod(orig_shape[:-1]) + m = orig_shape[-1] + n = len(params.shape) + + if m <= n: + out_shape = orig_shape[:-1] + list(params.shape)[m:] + else: + raise ValueError( + f'the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}' + ) + + indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist() + output = params[indices] # (num_samples, ...) 
+ return output.reshape(out_shape).contiguous() + + +# input: pos [kpt_n, 2]; inputs [H, W, 128] / [H, W] +# output: [kpt_n, 128] / [kpt_n] +def interpolate(pos, inputs, nd=True): + h = inputs.shape[0] + w = inputs.shape[1] + + i = pos[:, 0] + j = pos[:, 1] + + i_top_left = torch.clamp(torch.floor(i).int(), 0, h - 1) + j_top_left = torch.clamp(torch.floor(j).int(), 0, w - 1) + + i_top_right = torch.clamp(torch.floor(i).int(), 0, h - 1) + j_top_right = torch.clamp(torch.ceil(j).int(), 0, w - 1) + + i_bottom_left = torch.clamp(torch.ceil(i).int(), 0, h - 1) + j_bottom_left = torch.clamp(torch.floor(j).int(), 0, w - 1) + + i_bottom_right = torch.clamp(torch.ceil(i).int(), 0, h - 1) + j_bottom_right = torch.clamp(torch.ceil(j).int(), 0, w - 1) + + dist_i_top_left = i - i_top_left.float() + dist_j_top_left = j - j_top_left.float() + w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) + w_top_right = (1 - dist_i_top_left) * dist_j_top_left + w_bottom_left = dist_i_top_left * (1 - dist_j_top_left) + w_bottom_right = dist_i_top_left * dist_j_top_left + + if nd: + w_top_left = w_top_left[..., None] + w_top_right = w_top_right[..., None] + w_bottom_left = w_bottom_left[..., None] + w_bottom_right = w_bottom_right[..., None] + + interpolated_val = ( + w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) + + w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) + + w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + + w_bottom_right * + gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) + ) + + return interpolated_val + + +def edge_mask(inputs, n_channel, dilation=1, edge_thld=5): + b, c, h, w = inputs.size() + device = inputs.device + + dii_filter = torch.tensor( + [[0, 1., 0], [0, -2., 0], [0, 1., 0]] + ).view(1, 1, 3, 3) + dij_filter = 0.25 * torch.tensor( + [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]] + ).view(1, 1, 3, 3) + djj_filter = torch.tensor( + [[0, 0, 0], [1., -2., 1.], [0, 0, 0]] + ).view(1, 1, 3, 3) + + dii = F.conv2d( + inputs.view(-1, 1, h, w), dii_filter.to(device), padding=dilation, dilation=dilation + ).view(b, c, h, w) + dij = F.conv2d( + inputs.view(-1, 1, h, w), dij_filter.to(device), padding=dilation, dilation=dilation + ).view(b, c, h, w) + djj = F.conv2d( + inputs.view(-1, 1, h, w), djj_filter.to(device), padding=dilation, dilation=dilation + ).view(b, c, h, w) + + det = dii * djj - dij * dij + tr = dii + djj + del dii, dij, djj + + threshold = (edge_thld + 1) ** 2 / edge_thld + is_not_edge = torch.min(tr * tr / det <= threshold, det > 0) + + return is_not_edge + + +# input: score_map [batch_size, 1, H, W] +# output: indices [2, k, 2], scores [2, k] +def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_size=5): + h = score_map.shape[2] + w = score_map.shape[3] + + mask = score_map > score_thld + if nms_size > 0: + nms_mask = F.max_pool2d(score_map, kernel_size=nms_size, stride=1, padding=nms_size//2) + nms_mask = torch.eq(score_map, nms_mask) + mask = torch.logical_and(nms_mask, mask) + if eof_size > 0: + eof_mask = torch.ones((1, 1, h - 2 * eof_size, w - 2 * eof_size), dtype=torch.float32, device=score_map.device) + eof_mask = F.pad(eof_mask, [eof_size] * 4, value=0) + eof_mask = eof_mask.bool() + mask = torch.logical_and(eof_mask, mask) + if edge_thld > 0: + non_edge_mask = edge_mask(score_map, 1, dilation=3, edge_thld=edge_thld) + mask = torch.logical_and(non_edge_mask, mask) + + bs = score_map.shape[0] + if bs is None: + indices = 
torch.nonzero(mask)[0] + scores = gather_nd(score_map, indices)[0] + sample = torch.sort(scores, descending=True)[1][0:k] + indices = indices[sample].unsqueeze(0) + scores = scores[sample].unsqueeze(0) + else: + indices = [] + scores = [] + for i in range(bs): + tmp_mask = mask[i][0] + tmp_score_map = score_map[i][0] + tmp_indices = torch.nonzero(tmp_mask) + tmp_scores = gather_nd(tmp_score_map, tmp_indices) + tmp_sample = torch.sort(tmp_scores, descending=True)[1][0:k] + tmp_indices = tmp_indices[tmp_sample] + tmp_scores = tmp_scores[tmp_sample] + indices.append(tmp_indices) + scores.append(tmp_scores) + try: + indices = torch.stack(indices, dim=0) + scores = torch.stack(scores, dim=0) + except: + min_num = np.min([len(i) for i in indices]) + indices = torch.stack([i[:min_num] for i in indices], dim=0) + scores = torch.stack([i[:min_num] for i in scores], dim=0) + return indices, scores + + +# input: [batch_size, C, H, W] +# output: [batch_size, C, H, W], [batch_size, C, H, W] +def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1): + inputs = inputs / moving_instance_max + + batch_size, C, H, W = inputs.shape + + pad_size = ksize // 2 + (dilation - 1) + kernel = torch.ones([C, 1, ksize, ksize], device=inputs.device) / (ksize * ksize) + + pad_inputs = F.pad(inputs, [pad_size] * 4, mode='reflect') + + avg_spatial_inputs = F.conv2d( + pad_inputs, + kernel, + stride=1, + dilation=dilation, + padding=0, + groups=C + ) + avg_channel_inputs = torch.mean(inputs, axis=1, keepdim=True) # channel dimension is 1 + # print(avg_spatial_inputs.shape) + + alpha = F.softplus(inputs - avg_spatial_inputs) + beta = F.softplus(inputs - avg_channel_inputs) + + return alpha, beta + + +class DarkFeat(nn.Module): + default_config = { + 'model_path': '', + 'input_type': 'raw-demosaic', + 'kpt_n': 5000, + 'kpt_refinement': True, + 'score_thld': 0.5, + 'edge_thld': 10, + 'multi_scale': False, + 'multi_level': True, + 'nms_size': 3, + 'eof_size': 5, + 'need_norm': True, + 'use_peakiness': True + } + + def __init__(self, model_path='', inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False): + super(DarkFeat, self).__init__() + inchan = 3 if self.default_config['input_type'] == 'rgb' or self.default_config['input_type'] == 'raw-demosaic' else 1 + self.config = {**self.default_config} + + self.inchan = inchan + self.curchan = inchan + self.dilated = dilated + self.dilation = dilation + self.bn = bn + self.bn_affine = bn_affine + self.config['model_path'] = model_path + + dim = 128 + mchan = 4 + + self.conv0 = self._add_conv( 8*mchan) + self.conv1 = self._add_conv( 8*mchan, bn=False) + self.bn1 = self._make_bn(8*mchan) + self.conv2 = self._add_conv( 16*mchan, stride=2) + self.conv3 = self._add_conv( 16*mchan, bn=False) + self.bn3 = self._make_bn(16*mchan) + self.conv4 = self._add_conv( 32*mchan, stride=2) + self.conv5 = self._add_conv( 32*mchan) + # replace last 8x8 convolution with 3 3x3 convolutions + self.conv6_0 = self._add_conv( 32*mchan) + self.conv6_1 = self._add_conv( 32*mchan) + self.conv6_2 = self._add_conv(dim, bn=False, relu=False) + self.out_dim = dim + + self.moving_avg_params = nn.ParameterList([ + Parameter(torch.tensor(1.), requires_grad=False), + Parameter(torch.tensor(1.), requires_grad=False), + Parameter(torch.tensor(1.), requires_grad=False) + ]) + self.clf = nn.Conv2d(128, 2, kernel_size=1) + + state_dict = torch.load(self.config["model_path"], map_location="cpu") + new_state_dict = {} + + for key in state_dict: + if 'running_mean' not in key and 'running_var' not in key and 
'num_batches_tracked' not in key: + new_state_dict[key] = state_dict[key] + + self.load_state_dict(new_state_dict) + print('Loaded DarkFeat model') + + def _make_bn(self, outd): + return nn.BatchNorm2d(outd, affine=self.bn_affine, track_running_stats=False) + + def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max', bias=False): + d = self.dilation * dilation + conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride, bias=bias) + + ops = nn.ModuleList([]) + + ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) ) + if bn and self.bn: ops.append( self._make_bn(outd) ) + if relu: ops.append( nn.ReLU(inplace=True) ) + self.curchan = outd + + if k_pool > 1: + if pool_type == 'avg': + ops.append(torch.nn.AvgPool2d(kernel_size=k_pool)) + elif pool_type == 'max': + ops.append(torch.nn.MaxPool2d(kernel_size=k_pool)) + else: + print(f"Error, unknown pooling type {pool_type}...") + + return nn.Sequential(*ops) + + def forward(self, input): + """ Compute keypoints, scores, descriptors for image """ + data = input['image'] + H, W = data.shape[2:] + + if self.config['input_type'] == 'rgb': + # 3-channel rgb + RGB_mean = [0.485, 0.456, 0.406] + RGB_std = [0.229, 0.224, 0.225] + norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std) + data = norm_RGB(data) + + elif self.config['input_type'] == 'gray': + # 1-channel + data = torch.mean(data, dim=1, keepdim=True) + norm_gray0 = tvf.Normalize(mean=data.mean(), std=data.std()) + data = norm_gray0(data) + + elif self.config['input_type'] == 'raw': + # 4-channel + pass + elif self.config['input_type'] == 'raw-demosaic': + # 3-channel + pass + else: + raise NotImplementedError() + + # x: [N, C, H, W] + x0 = self.conv0(data) + x1 = self.conv1(x0) + x1_bn = self.bn1(x1) + x2 = self.conv2(x1_bn) + x3 = self.conv3(x2) + x3_bn = self.bn3(x3) + x4 = self.conv4(x3_bn) + x5 = self.conv5(x4) + x6_0 = self.conv6_0(x5) + x6_1 = self.conv6_1(x6_0) + x6_2 = self.conv6_2(x6_1) + + comb_weights = torch.tensor([1., 2., 3.], device=data.device) + comb_weights /= torch.sum(comb_weights) + ksize = [3, 2, 1] + det_score_maps = [] + + for idx, xx in enumerate([x1, x3, x6_2]): + alpha, beta = peakiness_score(xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx]) + score_vol = alpha * beta + det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0] + det_score_map = F.interpolate(det_score_map, size=data.shape[2:], mode='bilinear', align_corners=True) + det_score_map = comb_weights[idx] * det_score_map + det_score_maps.append(det_score_map) + + det_score_map = torch.sum(torch.stack(det_score_maps, dim=0), dim=0) + + desc = x6_2 + score_map = det_score_map + conf = F.softmax(self.clf((desc)**2), dim=1)[:,1:2] + score_map = score_map * F.interpolate(conf, size=score_map.shape[2:], mode='bilinear', align_corners=True) + + kpt_inds, kpt_score = extract_kpts( + score_map, + k=self.config['kpt_n'], + score_thld=self.config['score_thld'], + nms_size=self.config['nms_size'], + eof_size=self.config['eof_size'], + edge_thld=self.config['edge_thld'] + ) + + descs = F.normalize( + interpolate(kpt_inds.squeeze(0) / 4, desc.squeeze(0).permute(1, 2, 0)), + p=2, + dim=-1 + ).detach().cpu().numpy(), + kpts = np.squeeze(torch.stack([kpt_inds[:, :, 1], kpt_inds[:, :, 0]], dim=-1).cpu(), axis=0) \ + * np.array([W / data.shape[3], H / data.shape[2]], dtype=np.float32) + scores = np.squeeze(kpt_score.detach().cpu().numpy(), axis=0) + + idxs = np.negative(scores).argsort()[0:self.config['kpt_n']] + descs = 
descs[0][idxs] + kpts = kpts[idxs] + scores = scores[idxs] + + return { + 'keypoints': kpts, + 'scores': torch.from_numpy(scores), + 'descriptors': torch.from_numpy(descs.T), + } diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/__init__.py b/imcui/third_party/DarkFeat/datasets/InvISP/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/cal_metrics.py b/imcui/third_party/DarkFeat/datasets/InvISP/cal_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3e501664487de4c08ab8c89328dd266fba2868 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/cal_metrics.py @@ -0,0 +1,114 @@ +import cv2 +import numpy as np +import math +# from skimage.metrics import structural_similarity as ssim +from skimage.measure import compare_ssim +from scipy.misc import imread +from glob import glob + +import argparse + +parser = argparse.ArgumentParser(description="evaluation codes") + +parser.add_argument("--path", type=str, help="Path to evaluate images.") + +args = parser.parse_args() + +def psnr(img1, img2): + mse = np.mean( (img1/255. - img2/255.) ** 2 ) + if mse < 1.0e-10: + return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + +def psnr_raw(img1, img2): + mse = np.mean( (img1 - img2) ** 2 ) + if mse < 1.0e-10: + return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + + +def my_ssim(img1, img2): + return compare_ssim(img1, img2, data_range=img1.max() - img1.min(), multichannel=True) + + +def quan_eval(path, suffix="jpg"): + # path: /disk2/yazhou/projects/IISP/exps/test_final_unet_globalEDV2/ + # ours + gt_imgs = sorted(glob(path+"tar*.%s"%suffix)) + pred_imgs = sorted(glob(path+"pred*.%s"%suffix)) + + # with open(split_path + "test_gt.txt", 'r') as f_gt, open(split_path+"test_rgb.txt","r") as f_rgb: + # gt_imgs = [line.rstrip() for line in f_gt.readlines()] + # pred_imgs = [line.rstrip() for line in f_rgb.readlines()] + + assert len(gt_imgs) == len(pred_imgs) + + psnr_avg = 0. + ssim_avg = 0. + for i in range(len(gt_imgs)): + gt = imread(gt_imgs[i]) + pred = imread(pred_imgs[i]) + psnr_temp = psnr(gt, pred) + psnr_avg += psnr_temp + ssim_temp = my_ssim(gt, pred) + ssim_avg += ssim_temp + + print("psnr: ", psnr_temp) + print("ssim: ", ssim_temp) + + psnr_avg /= float(len(gt_imgs)) + ssim_avg /= float(len(gt_imgs)) + + print("psnr_avg: ", psnr_avg) + print("ssim_avg: ", ssim_avg) + + return psnr_avg, ssim_avg + +def mse(gt, pred): + return np.mean((gt-pred)**2) + +def mse_raw(path, suffix="npy"): + gt_imgs = sorted(glob(path+"raw_tar*.%s"%suffix)) + pred_imgs = sorted(glob(path+"raw_pred*.%s"%suffix)) + + # with open(split_path + "test_gt.txt", 'r') as f_gt, open(split_path+"test_rgb.txt","r") as f_rgb: + # gt_imgs = [line.rstrip() for line in f_gt.readlines()] + # pred_imgs = [line.rstrip() for line in f_rgb.readlines()] + + assert len(gt_imgs) == len(pred_imgs) + + mse_avg = 0. + psnr_avg = 0. 
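+    # Accumulate raw-domain MSE and PSNR over every (ground-truth, prediction) .npy pair.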
+ for i in range(len(gt_imgs)): + gt = np.load(gt_imgs[i]) + pred = np.load(pred_imgs[i]) + mse_temp = mse(gt, pred) + mse_avg += mse_temp + psnr_temp = psnr_raw(gt, pred) + psnr_avg += psnr_temp + + print("mse: ", mse_temp) + print("psnr: ", psnr_temp) + + mse_avg /= float(len(gt_imgs)) + psnr_avg /= float(len(gt_imgs)) + + print("mse_avg: ", mse_avg) + print("psnr_avg: ", psnr_avg) + + return mse_avg, psnr_avg + +test_full = False + +# if test_full: +# psnr_avg, ssim_avg = quan_eval(ROOT_PATH+"%s/vis_%s_full/"%(args.task, args.ckpt), "jpeg") +# mse_avg, psnr_avg_raw = mse_raw(ROOT_PATH+"%s/vis_%s_full/"%(args.task, args.ckpt)) +# else: +psnr_avg, ssim_avg = quan_eval(args.path, "jpg") +mse_avg, psnr_avg_raw = mse_raw(args.path) + +print("pnsr: {}, ssim: {}, mse: {}, psnr raw: {}".format(psnr_avg, ssim_avg, mse_avg, psnr_avg_raw)) + + diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/config/config.py b/imcui/third_party/DarkFeat/datasets/InvISP/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..dc42182ecf7464cc85ed5c77b7aeb9ee4e3ecd74 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/config/config.py @@ -0,0 +1,21 @@ +import argparse + +BATCH_SIZE = 1 + +DATA_PATH = "./data/" + + + +def get_arguments(): + parser = argparse.ArgumentParser(description="training codes") + + parser.add_argument("--task", type=str, help="Name of this training") + parser.add_argument("--data_path", type=str, default=DATA_PATH, help="Dataset root path.") + parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size for training. ") + parser.add_argument("--debug_mode", dest='debug_mode', action='store_true', help="If debug mode, load less data.") + parser.add_argument("--gamma", dest='gamma', action='store_true', help="Use gamma compression for raw data.") + parser.add_argument("--camera", type=str, default="NIKON_D700", choices=["NIKON_D700", "Canon_EOS_5D"], help="Choose which camera to use. ") + parser.add_argument("--rgb_weight", type=float, default=1, help="Weight for rgb loss. 
") + + + return parser diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py b/imcui/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..62271771a17a4863b730136d49f2a23aed0e49b2 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py @@ -0,0 +1,56 @@ +import rawpy +import numpy as np +import glob, os +import colour_demosaicing +import imageio +import argparse +from PIL import Image as PILImage +import scipy.io as scio + +parser = argparse.ArgumentParser(description="data preprocess") + +parser.add_argument("--camera", type=str, default="NIKON_D700", help="Camera Name") +parser.add_argument("--Bayer_Pattern", type=str, default="RGGB", help="Bayer Pattern of RAW") +parser.add_argument("--JPEG_Quality", type=int, default=90, help="Jpeg Quality of the ground truth.") + +args = parser.parse_args() +camera_name = args.camera +Bayer_Pattern = args.Bayer_Pattern +JPEG_Quality = args.JPEG_Quality + +dng_path = sorted(glob.glob('/mnt/nvme2n1/hyz/data/' + camera_name + '/DNG/*.cr2')) +rgb_target_path = '/mnt/nvme2n1/hyz/data/'+ camera_name + '/RGB/' +raw_input_path = '/mnt/nvme2n1/hyz/data/' + camera_name + '/RAW/' +if not os.path.isdir(rgb_target_path): + os.mkdir(rgb_target_path) +if not os.path.isdir(raw_input_path): + os.mkdir(raw_input_path) + +def flip(raw_img, flip): + if flip == 3: + raw_img = np.rot90(raw_img, k=2) + elif flip == 5: + raw_img = np.rot90(raw_img, k=1) + elif flip == 6: + raw_img = np.rot90(raw_img, k=3) + else: + pass + return raw_img + + + +for path in dng_path: + print("Start Processing %s" % os.path.basename(path)) + raw = rawpy.imread(path) + file_name = path.split('/')[-1].split('.')[0] + im = raw.postprocess(use_camera_wb=True,no_auto_bright=True) + flip_val = raw.sizes.flip + cwb = raw.camera_whitebalance + raw_img = raw.raw_image_visible + if camera_name == 'Canon_EOS_5D': + raw_img = np.maximum(raw_img - 127.0, 0) + de_raw = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw_img, Bayer_Pattern) + de_raw = flip(de_raw, flip_val) + rgb_img = PILImage.fromarray(im).save(rgb_target_path + file_name + '.jpg', quality = JPEG_Quality, subsampling = 1) + np.savez(raw_input_path + file_name + '.npz', raw=de_raw, wb=cwb) + diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4c71bd3b4162bd21761983deef6b94fa46a364f6 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py @@ -0,0 +1,132 @@ +from __future__ import print_function, division +import os, random, time +import torch +import numpy as np +from torch.utils.data import Dataset +from torchvision import transforms, utils +import rawpy +from glob import glob +from PIL import Image as PILImage +import numbers +from scipy.misc import imread +from .base_dataset import BaseDataset + + +class FiveKDatasetTrain(BaseDataset): + def __init__(self, opt): + super().__init__(opt=opt) + self.patch_size = 256 + input_RAWs_WBs, target_RGBs = self.load(is_train=True) + assert len(input_RAWs_WBs) == len(target_RGBs) + self.data = {'input_RAWs_WBs':input_RAWs_WBs, 'target_RGBs':target_RGBs} + + def random_flip(self, input_raw, target_rgb): + idx = np.random.randint(2) + input_raw = np.flip(input_raw,axis=idx).copy() + target_rgb = np.flip(target_rgb,axis=idx).copy() + + return input_raw, target_rgb + + def 
random_rotate(self, input_raw, target_rgb): + idx = np.random.randint(4) + input_raw = np.rot90(input_raw,k=idx) + target_rgb = np.rot90(target_rgb,k=idx) + + return input_raw, target_rgb + + def random_crop(self, patch_size, input_raw, target_rgb,flow=False,demos=False): + H, W, _ = input_raw.shape + rnd_h = random.randint(0, max(0, H - patch_size)) + rnd_w = random.randint(0, max(0, W - patch_size)) + + patch_input_raw = input_raw[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :] + if flow or demos: + patch_target_rgb = target_rgb[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :] + else: + patch_target_rgb = target_rgb[rnd_h*2:rnd_h*2 + patch_size*2, rnd_w*2:rnd_w*2 + patch_size*2, :] + + return patch_input_raw, patch_target_rgb + + def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False): + input_raw, target_rgb = self.random_crop(patch_size, input_raw,target_rgb,flow=flow, demos=demos) + input_raw, target_rgb = self.random_rotate(input_raw,target_rgb) + input_raw, target_rgb = self.random_flip(input_raw,target_rgb) + + return input_raw, target_rgb + + def __len__(self): + return len(self.data['input_RAWs_WBs']) + + def __getitem__(self, idx): + input_raw_wb_path = self.data['input_RAWs_WBs'][idx] + target_rgb_path = self.data['target_RGBs'][idx] + + target_rgb_img = imread(target_rgb_path) + input_raw_wb = np.load(input_raw_wb_path) + input_raw_img = input_raw_wb['raw'] + wb = input_raw_wb['wb'] + wb = wb / wb.max() + input_raw_img = input_raw_img * wb[:-1] + + self.patch_size = 256 + input_raw_img, target_rgb_img = self.aug(self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True) + + if self.gamma: + norm_value = np.power(4095, 1/2.2) if self.camera_name=='Canon_EOS_5D' else np.power(16383, 1/2.2) + input_raw_img = np.power(input_raw_img, 1/2.2) + else: + norm_value = 4095 if self.camera_name=='Canon_EOS_5D' else 16383 + + target_rgb_img = self.norm_img(target_rgb_img, max_value=255) + input_raw_img = self.norm_img(input_raw_img, max_value=norm_value) + target_raw_img = input_raw_img.copy() + + input_raw_img = self.np2tensor(input_raw_img).float() + target_rgb_img = self.np2tensor(target_rgb_img).float() + target_raw_img = self.np2tensor(target_raw_img).float() + + sample = {'input_raw':input_raw_img, 'target_rgb':target_rgb_img, 'target_raw':target_raw_img, + 'file_name':input_raw_wb_path.split("/")[-1].split(".")[0]} + return sample + +class FiveKDatasetTest(BaseDataset): + def __init__(self, opt): + super().__init__(opt=opt) + self.patch_size = 256 + + input_RAWs_WBs, target_RGBs = self.load(is_train=False) + assert len(input_RAWs_WBs) == len(target_RGBs) + self.data = {'input_RAWs_WBs':input_RAWs_WBs, 'target_RGBs':target_RGBs} + + def __len__(self): + return len(self.data['input_RAWs_WBs']) + + def __getitem__(self, idx): + input_raw_wb_path = self.data['input_RAWs_WBs'][idx] + target_rgb_path = self.data['target_RGBs'][idx] + + target_rgb_img = imread(target_rgb_path) + input_raw_wb = np.load(input_raw_wb_path) + input_raw_img = input_raw_wb['raw'] + wb = input_raw_wb['wb'] + wb = wb / wb.max() + input_raw_img = input_raw_img * wb[:-1] + + if self.gamma: + norm_value = np.power(4095, 1/2.2) if self.camera_name=='Canon_EOS_5D' else np.power(16383, 1/2.2) + input_raw_img = np.power(input_raw_img, 1/2.2) + else: + norm_value = 4095 if self.camera_name=='Canon_EOS_5D' else 16383 + + target_rgb_img = self.norm_img(target_rgb_img, max_value=255) + input_raw_img = self.norm_img(input_raw_img, max_value=norm_value) + target_raw_img = 
input_raw_img.copy() + + input_raw_img = self.np2tensor(input_raw_img).float() + target_rgb_img = self.np2tensor(target_rgb_img).float() + target_raw_img = self.np2tensor(target_raw_img).float() + + sample = {'input_raw':input_raw_img, 'target_rgb':target_rgb_img, 'target_raw':target_raw_img, + 'file_name':input_raw_wb_path.split("/")[-1].split(".")[0]} + return sample + diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/dataset/__init__.py b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..34c5de9f75dbfb5323c2cdad532cb0a42c09df22 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py @@ -0,0 +1,84 @@ +from __future__ import print_function, division +import numpy as np +from torch.utils.data import Dataset +import torch + +class BaseDataset(Dataset): + def __init__(self, opt): + self.crop_size = 512 + self.debug_mode = opt.debug_mode + self.data_path = opt.data_path # dataset path. e.g., ./data/ + self.camera_name = opt.camera + self.gamma = opt.gamma + + def norm_img(self, img, max_value): + img = img / float(max_value) + return img + + def pack_raw(self, raw): + # pack Bayer image to 4 channels + im = np.expand_dims(raw, axis=2) + H, W = raw.shape[0], raw.shape[1] + # RGBG + out = np.concatenate((im[0:H:2, 0:W:2, :], + im[0:H:2, 1:W:2, :], + im[1:H:2, 1:W:2, :], + im[1:H:2, 0:W:2, :]), axis=2) + return out + + def np2tensor(self, array): + return torch.Tensor(array).permute(2,0,1) + + def center_crop(self, img, crop_size=None): + H = img.shape[0] + W = img.shape[1] + + if crop_size is not None: + th, tw = crop_size[0], crop_size[1] + else: + th, tw = self.crop_size, self.crop_size + x1_img = int(round((W - tw) / 2.)) + y1_img = int(round((H - th) / 2.)) + if img.ndim == 3: + input_patch = img[y1_img:y1_img + th, x1_img:x1_img + tw, :] + else: + input_patch = img[y1_img:y1_img + th, x1_img:x1_img + tw] + + return input_patch + + def load(self, is_train=True): + # ./data + # ./data/NIKON D700/RAW, ./data/NIKON D700/RGB + # ./data/Canon EOS 5D/RAW, ./data/Canon EOS 5D/RGB + # ./data/NIKON D700_train.txt, ./data/NIKON D700_test.txt + # ./data/NIKON D700_train.txt: a0016, ... 
+ input_RAWs_WBs = [] + target_RGBs = [] + + data_path = self.data_path # ./data/ + if is_train: + txt_path = data_path + self.camera_name + "_train.txt" + else: + txt_path = data_path + self.camera_name + "_test.txt" + + with open(txt_path, "r") as f_read: + # valid_camera_list = [os.path.basename(line.strip()).split('.')[0] for line in f_read.readlines()] + valid_camera_list = [line.strip() for line in f_read.readlines()] + + if self.debug_mode: + valid_camera_list = valid_camera_list[:10] + + for i,name in enumerate(valid_camera_list): + full_name = data_path + self.camera_name + input_RAWs_WBs.append(full_name + "/RAW/" + name + ".npz") + target_RGBs.append(full_name + "/RGB/" + name + ".jpg") + + return input_RAWs_WBs, target_RGBs + + + def __len__(self): + return 0 + + def __getitem__(self, idx): + + return None diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/environment.yml b/imcui/third_party/DarkFeat/datasets/InvISP/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..20a58415354b80fb01f72fbbeb8d55edee6067ce --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/environment.yml @@ -0,0 +1,56 @@ +name: invertible-isp +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _pytorch_select=0.2=gpu_0 + - blas=1.0=mkl + - ca-certificates=2021.1.19=h06a4308_1 + - certifi=2020.12.5=py36h06a4308_0 + - cffi=1.14.5=py36h261ae71_0 + - cudatoolkit=10.1.243=h6bb024c_0 + - cudnn=7.6.5=cuda10.1_0 + - freetype=2.10.4=h5ab3b9f_0 + - intel-openmp=2020.2=254 + - jpeg=9b=h024ee3a_2 + - lcms2=2.11=h396b838_0 + - ld_impl_linux-64=2.33.1=h53a641e_7 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.1.0=hdf63c60_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.1.0=h2733197_1 + - lz4-c=1.9.3=h2531618_0 + - mkl=2020.2=256 + - mkl-service=2.3.0=py36he8ac12f_0 + - mkl_fft=1.3.0=py36h54f3939_0 + - mkl_random=1.1.1=py36h0573a6f_0 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.2=py36hff7bd54_0 + - numpy=1.19.2=py36h54aff64_0 + - numpy-base=1.19.2=py36hfa32c7d_0 + - olefile=0.46=py36_0 + - openssl=1.1.1k=h27cfd23_0 + - pillow=8.2.0=py36he98fc37_0 + - pip=21.0.1=py36h06a4308_0 + - pycparser=2.20=py_2 + - python=3.6.13=hdb3f193_0 + - pytorch=1.4.0=cuda101py36h02f0884_0 + - readline=8.1=h27cfd23_0 + - setuptools=52.0.0=py36h06a4308_0 + - six=1.15.0=py36h06a4308_0 + - sqlite=3.35.3=hdfb4753_0 + - tk=8.6.10=hbc83047_0 + - torchvision=0.2.1=py36_0 + - wheel=0.36.2=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 + - pip: + - colour-demosaicing==0.1.6 + - colour-science==0.3.16 + - imageio==2.9.0 + - rawpy==0.16.0 + - scipy==1.2.0 + - tqdm==4.59.0 + diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/__init__.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/loss.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..abe8b599d5402c367bb7c84b7e370964d8273518 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/model/loss.py @@ -0,0 +1,15 @@ +import torch.nn.functional as F +import torch + + +def l1_loss(output, target_rgb, target_raw, weight=1.): + raw_loss = F.l1_loss(output['reconstruct_raw'], target_raw) + rgb_loss = F.l1_loss(output['reconstruct_rgb'], target_rgb) + total_loss = raw_loss + weight * rgb_loss + return 
total_loss, raw_loss, rgb_loss + +def l2_loss(output, target_rgb, target_raw, weight=1.): + raw_loss = F.mse_loss(output['reconstruct_raw'], target_raw) + rgb_loss = F.mse_loss(output['reconstruct_rgb'], target_rgb) + total_loss = raw_loss + weight * rgb_loss + return total_loss, raw_loss, rgb_loss \ No newline at end of file diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/model.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd0e33cee8ebb26d621ece84622bd2611b33a60 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/model/model.py @@ -0,0 +1,179 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import torch.nn.init as init + +from .modules import InvertibleConv1x1 + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def initialize_weights_xavier(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight) + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.xavier_normal_(m.weight) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +class DenseBlock(nn.Module): + def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True): + super(DenseBlock, self).__init__() + self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + if init == 'xavier': + initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) + else: + initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) + initialize_weights(self.conv5, 0) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + + return x5 + +def subnet(net_structure, init='xavier'): + def constructor(channel_in, channel_out): + if net_structure == 'DBNet': + if init == 'xavier': + return DenseBlock(channel_in, channel_out, init) + else: + return DenseBlock(channel_in, channel_out) + # return UNetBlock(channel_in, channel_out) + else: + return None + + return constructor + + +class InvBlock(nn.Module): + def __init__(self, subnet_constructor, channel_num, channel_split_num, clamp=0.8): + super(InvBlock, self).__init__() + # 
channel_num: 3 + # channel_split_num: 1 + + self.split_len1 = channel_split_num # 1 + self.split_len2 = channel_num - channel_split_num # 2 + + self.clamp = clamp + + self.F = subnet_constructor(self.split_len2, self.split_len1) + self.G = subnet_constructor(self.split_len1, self.split_len2) + self.H = subnet_constructor(self.split_len1, self.split_len2) + + in_channels = 3 + self.invconv = InvertibleConv1x1(in_channels, LU_decomposed=True) + self.flow_permutation = lambda z, logdet, rev: self.invconv(z, logdet, rev) + + def forward(self, x, rev=False): + if not rev: + # invert1x1conv + x, logdet = self.flow_permutation(x, logdet=0, rev=False) + + # split to 1 channel and 2 channel. + x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2)) + + y1 = x1 + self.F(x2) # 1 channel + self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1) + y2 = x2.mul(torch.exp(self.s)) + self.G(y1) # 2 channel + out = torch.cat((y1, y2), 1) + else: + # split. + x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2)) + self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1) + y2 = (x2 - self.G(x1)).div(torch.exp(self.s)) + y1 = x1 - self.F(y2) + + x = torch.cat((y1, y2), 1) + + # inv permutation + out, logdet = self.flow_permutation(x, logdet=0, rev=True) + + return out + +class InvISPNet(nn.Module): + def __init__(self, channel_in=3, channel_out=3, subnet_constructor=subnet('DBNet'), block_num=8): + super(InvISPNet, self).__init__() + operations = [] + + current_channel = channel_in + channel_num = channel_in + channel_split_num = 1 + + for j in range(block_num): + b = InvBlock(subnet_constructor, channel_num, channel_split_num) # one block is one flow step. + operations.append(b) + + self.operations = nn.ModuleList(operations) + + self.initialize() + + def initialize(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight) + m.weight.data *= 1. # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.xavier_normal_(m.weight) + m.weight.data *= 1. 
+ if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + def forward(self, x, rev=False): + out = x # x: [N,3,H,W] + + if not rev: + for op in self.operations: + out = op.forward(out, rev) + else: + for op in reversed(self.operations): + out = op.forward(out, rev) + + return out + diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/modules.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..88244c0b211860d97be78ba4f60f4743228171a7 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/model/modules.py @@ -0,0 +1,387 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import split_feature, compute_same_pad + + +def gaussian_p(mean, logs, x): + """ + lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) } + k = 1 (Independent) + Var = logs ** 2 + """ + c = math.log(2 * math.pi) + return -0.5 * (logs * 2.0 + ((x - mean) ** 2) / torch.exp(logs * 2.0) + c) + + +def gaussian_likelihood(mean, logs, x): + p = gaussian_p(mean, logs, x) + return torch.sum(p, dim=[1, 2, 3]) + + +def gaussian_sample(mean, logs, temperature=1): + # Sample from Gaussian with temperature + z = torch.normal(mean, torch.exp(logs) * temperature) + + return z + + +def squeeze2d(input, factor): + if factor == 1: + return input + + B, C, H, W = input.size() + + assert H % factor == 0 and W % factor == 0, "H or W modulo factor is not 0" + + x = input.view(B, C, H // factor, factor, W // factor, factor) + x = x.permute(0, 1, 3, 5, 2, 4).contiguous() + x = x.view(B, C * factor * factor, H // factor, W // factor) + + return x + + +def unsqueeze2d(input, factor): + if factor == 1: + return input + + factor2 = factor ** 2 + + B, C, H, W = input.size() + + assert C % (factor2) == 0, "C module factor squared is not 0" + + x = input.view(B, C // factor2, factor, factor, H, W) + x = x.permute(0, 1, 4, 2, 5, 3).contiguous() + x = x.view(B, C // (factor2), H * factor, W * factor) + + return x + + +class _ActNorm(nn.Module): + """ + Activation Normalization + Initialize the bias and scale with a given minibatch, + so that the output per-channel have zero mean and unit variance for that. + + After initialization, `bias` and `logs` will be trained as parameters. 
+ """ + + def __init__(self, num_features, scale=1.0): + super().__init__() + # register mean and scale + size = [1, num_features, 1, 1] + self.bias = nn.Parameter(torch.zeros(*size)) + self.logs = nn.Parameter(torch.zeros(*size)) + self.num_features = num_features + self.scale = scale + self.inited = False + + def initialize_parameters(self, input): + if not self.training: + raise ValueError("In Eval mode, but ActNorm not inited") + + with torch.no_grad(): + bias = -torch.mean(input.clone(), dim=[0, 2, 3], keepdim=True) + vars = torch.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) + logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) + + self.bias.data.copy_(bias.data) + self.logs.data.copy_(logs.data) + + self.inited = True + + def _center(self, input, reverse=False): + if reverse: + return input - self.bias + else: + return input + self.bias + + def _scale(self, input, logdet=None, reverse=False): + + if reverse: + input = input * torch.exp(-self.logs) + else: + input = input * torch.exp(self.logs) + + if logdet is not None: + """ + logs is log_std of `mean of channels` + so we need to multiply by number of pixels + """ + b, c, h, w = input.shape + + dlogdet = torch.sum(self.logs) * h * w + + if reverse: + dlogdet *= -1 + + logdet = logdet + dlogdet + + return input, logdet + + def forward(self, input, logdet=None, reverse=False): + self._check_input_dim(input) + + if not self.inited: + self.initialize_parameters(input) + + if reverse: + input, logdet = self._scale(input, logdet, reverse) + input = self._center(input, reverse) + else: + input = self._center(input, reverse) + input, logdet = self._scale(input, logdet, reverse) + + return input, logdet + + +class ActNorm2d(_ActNorm): + def __init__(self, num_features, scale=1.0): + super().__init__(num_features, scale) + + def _check_input_dim(self, input): + assert len(input.size()) == 4 + assert input.size(1) == self.num_features, ( + "[ActNorm]: input should be in shape as `BCHW`," + " channels should be {} rather than {}".format( + self.num_features, input.size() + ) + ) + + +class LinearZeros(nn.Module): + def __init__(self, in_channels, out_channels, logscale_factor=3): + super().__init__() + + self.linear = nn.Linear(in_channels, out_channels) + self.linear.weight.data.zero_() + self.linear.bias.data.zero_() + + self.logscale_factor = logscale_factor + + self.logs = nn.Parameter(torch.zeros(out_channels)) + + def forward(self, input): + output = self.linear(input) + return output * torch.exp(self.logs * self.logscale_factor) + + +class Conv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding="same", + do_actnorm=True, + weight_std=0.05, + ): + super().__init__() + + if padding == "same": + padding = compute_same_pad(kernel_size, stride) + elif padding == "valid": + padding = 0 + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=(not do_actnorm), + ) + + # init weight with std + self.conv.weight.data.normal_(mean=0.0, std=weight_std) + + if not do_actnorm: + self.conv.bias.data.zero_() + else: + self.actnorm = ActNorm2d(out_channels) + + self.do_actnorm = do_actnorm + + def forward(self, input): + x = self.conv(input) + if self.do_actnorm: + x, _ = self.actnorm(x) + return x + + +class Conv2dZeros(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding="same", + logscale_factor=3, + ): + super().__init__() + + if padding == "same": + 
padding = compute_same_pad(kernel_size, stride) + elif padding == "valid": + padding = 0 + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) + + self.conv.weight.data.zero_() + self.conv.bias.data.zero_() + + self.logscale_factor = logscale_factor + self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1)) + + def forward(self, input): + output = self.conv(input) + return output * torch.exp(self.logs * self.logscale_factor) + + +class Permute2d(nn.Module): + def __init__(self, num_channels, shuffle): + super().__init__() + self.num_channels = num_channels + self.indices = torch.arange(self.num_channels - 1, -1, -1, dtype=torch.long) + self.indices_inverse = torch.zeros((self.num_channels), dtype=torch.long) + + for i in range(self.num_channels): + self.indices_inverse[self.indices[i]] = i + + if shuffle: + self.reset_indices() + + def reset_indices(self): + shuffle_idx = torch.randperm(self.indices.shape[0]) + self.indices = self.indices[shuffle_idx] + + for i in range(self.num_channels): + self.indices_inverse[self.indices[i]] = i + + def forward(self, input, reverse=False): + assert len(input.size()) == 4 + + if not reverse: + input = input[:, self.indices, :, :] + return input + else: + return input[:, self.indices_inverse, :, :] + + +class Split2d(nn.Module): + def __init__(self, num_channels): + super().__init__() + self.conv = Conv2dZeros(num_channels // 2, num_channels) + + def split2d_prior(self, z): + h = self.conv(z) + return split_feature(h, "cross") + + def forward(self, input, logdet=0.0, reverse=False, temperature=None): + if reverse: + z1 = input + mean, logs = self.split2d_prior(z1) + z2 = gaussian_sample(mean, logs, temperature) + z = torch.cat((z1, z2), dim=1) + return z, logdet + else: + z1, z2 = split_feature(input, "split") + mean, logs = self.split2d_prior(z1) + logdet = gaussian_likelihood(mean, logs, z2) + logdet + return z1, logdet + + +class SqueezeLayer(nn.Module): + def __init__(self, factor): + super().__init__() + self.factor = factor + + def forward(self, input, logdet=None, reverse=False): + if reverse: + output = unsqueeze2d(input, self.factor) + else: + output = squeeze2d(input, self.factor) + + return output, logdet + + +class InvertibleConv1x1(nn.Module): + def __init__(self, num_channels, LU_decomposed): + super().__init__() + w_shape = [num_channels, num_channels] + w_init = torch.linalg.qr(torch.randn(*w_shape))[0] + + if not LU_decomposed: + self.weight = nn.Parameter(torch.Tensor(w_init)) + else: + p, lower, upper = torch.lu_unpack(*torch.lu(w_init)) + s = torch.diag(upper) + sign_s = torch.sign(s) + log_s = torch.log(torch.abs(s)) + upper = torch.triu(upper, 1) + l_mask = torch.tril(torch.ones(w_shape), -1) + eye = torch.eye(*w_shape) + + self.register_buffer("p", p) + self.register_buffer("sign_s", sign_s) + self.lower = nn.Parameter(lower) + self.log_s = nn.Parameter(log_s) + self.upper = nn.Parameter(upper) + self.l_mask = l_mask + self.eye = eye + + self.w_shape = w_shape + self.LU_decomposed = LU_decomposed + + def get_weight(self, input, reverse): + b, c, h, w = input.shape + + if not self.LU_decomposed: + dlogdet = torch.slogdet(self.weight)[1] * h * w + if reverse: + weight = torch.inverse(self.weight) + else: + weight = self.weight + else: + self.l_mask = self.l_mask.to(input.device) + self.eye = self.eye.to(input.device) + + lower = self.lower * self.l_mask + self.eye + + u = self.upper * self.l_mask.transpose(0, 1).contiguous() + u += torch.diag(self.sign_s * torch.exp(self.log_s)) + + dlogdet = 
torch.sum(self.log_s) * h * w + + if reverse: + u_inv = torch.inverse(u) + l_inv = torch.inverse(lower) + p_inv = torch.inverse(self.p) + + weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv)) + else: + weight = torch.matmul(self.p, torch.matmul(lower, u)) + + return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet + + def forward(self, input, logdet=None, reverse=False): + """ + log-det = log|abs(|W|)| * pixels + """ + weight, dlogdet = self.get_weight(input, reverse) + + if not reverse: + z = F.conv2d(input, weight) + if logdet is not None: + logdet = logdet + dlogdet + return z, logdet + else: + z = F.conv2d(input, weight) + if logdet is not None: + logdet = logdet - dlogdet + return z, logdet diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/utils.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1bef31afd7d61d4c942ffd895c818b90571b4b7 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/model/utils.py @@ -0,0 +1,52 @@ +import math +import torch + + +def compute_same_pad(kernel_size, stride): + if isinstance(kernel_size, int): + kernel_size = [kernel_size] + + if isinstance(stride, int): + stride = [stride] + + assert len(stride) == len( + kernel_size + ), "Pass kernel size and stride both as int, or both as equal length iterable" + + return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)] + + +def uniform_binning_correction(x, n_bits=8): + """Replaces x^i with q^i(x) = U(x, x + 1.0 / 256.0). + + Args: + x: 4-D Tensor of shape (NCHW) + n_bits: optional. + Returns: + x: x ~ U(x, x + 1.0 / 256) + objective: Equivalent to -q(x)*log(q(x)). + """ + b, c, h, w = x.size() + n_bins = 2 ** n_bits + chw = c * h * w + x += torch.zeros_like(x).uniform_(0, 1.0 / n_bins) + + objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device) + return x, objective + + +def split_feature(tensor, type="split"): + """ + type = ["split", "cross"] + """ + C = tensor.size(1) + if type == "split": + # return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...] + return tensor[:, :1, ...], tensor[:,1:, ...] + elif type == "cross": + # return tensor[:, 0::2, ...], tensor[:, 1::2, ...] + return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 
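The three helpers above are mostly about shapes and conventions. The sketch below is illustrative only and not part of the patch; it assumes it is run from the InvISP directory so that `model.utils` resolves, which is the same convention the test scripts use. Note that the "split" branch, as written, returns channel 0 versus the remaining channels rather than the even half-split shown in the commented-out line.

```python
import torch

from model.utils import compute_same_pad, split_feature, uniform_binning_correction

# "same" padding for a 3x3 kernel with stride 1 -> [1, 1]
print(compute_same_pad((3, 3), (1, 1)))

x = torch.zeros(2, 4, 8, 8)
a, b = split_feature(x, "split")   # channel 0 vs channels 1..3, as written
print(a.shape, b.shape)            # (2, 1, 8, 8) and (2, 3, 8, 8)
c, d = split_feature(x, "cross")   # even vs odd channels
print(c.shape, d.shape)            # (2, 2, 8, 8) and (2, 2, 8, 8)

img = torch.rand(2, 3, 8, 8)
img_q, obj = uniform_binning_correction(img)  # adds U(0, 1/256) noise; objective = -log(256) * C*H*W per sample
print(obj.shape)                   # (2,)
```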
+ + + + diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/test_raw.py b/imcui/third_party/DarkFeat/datasets/InvISP/test_raw.py new file mode 100644 index 0000000000000000000000000000000000000000..37610f8268e4586864e0275236c5bb1932f894df --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/test_raw.py @@ -0,0 +1,118 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import torch +import numpy as np +import os, time, random +import argparse +from torch.utils.data import Dataset, DataLoader +from PIL import Image as PILImage +from glob import glob +from tqdm import tqdm + +from model.model import InvISPNet +from dataset.FiveK_dataset import FiveKDatasetTest +from config.config import get_arguments + +from utils.JPEG import DiffJPEG +from utils.commons import denorm, preprocess_test_patch + + +os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') +os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()])) +# os.environ['CUDA_VISIBLE_DEVICES'] = '7' +os.system('rm tmp') + +DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda() + +parser = get_arguments() +parser.add_argument("--ckpt", type=str, help="Checkpoint path.") +parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save checkpoint. ") +parser.add_argument("--split_to_patch", dest='split_to_patch', action='store_true', help="Test on patch. ") +args = parser.parse_args() +print("Parsed arguments: {}".format(args)) + + +ckpt_name = args.ckpt.split("/")[-1].split(".")[0] +if args.split_to_patch: + os.makedirs(args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name), exist_ok=True) + out_path = args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name) +else: + os.makedirs(args.out_path+"%s/results_%s/"%(args.task, ckpt_name), exist_ok=True) + out_path = args.out_path+"%s/results_%s/"%(args.task, ckpt_name) + + +def main(args): + # ======================================define the model============================================ + net = InvISPNet(channel_in=3, channel_out=3, block_num=8) + device = torch.device("cuda:0") + + net.to(device) + net.eval() + # load the pretrained weight if there exists one + if os.path.isfile(args.ckpt): + net.load_state_dict(torch.load(args.ckpt), strict=False) + print("[INFO] Loaded checkpoint: {}".format(args.ckpt)) + + print("[INFO] Start data load and preprocessing") + RAWDataset = FiveKDatasetTest(opt=args) + dataloader = DataLoader(RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True) + + input_RGBs = sorted(glob(out_path+"pred*jpg")) + input_RGBs_names = [path.split("/")[-1].split(".")[0][5:] for path in input_RGBs] + + print("[INFO] Start test...") + for i_batch, sample_batched in enumerate(tqdm(dataloader)): + step_time = time.time() + + input, target_rgb, target_raw = sample_batched['input_raw'].to(device), sample_batched['target_rgb'].to(device), \ + sample_batched['target_raw'].to(device) + file_name = sample_batched['file_name'][0] + + if args.split_to_patch: + input_list, target_rgb_list, target_raw_list = preprocess_test_patch(input, target_rgb, target_raw) + else: + # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution + input_list, target_rgb_list, target_raw_list = [input[:,:,::2,::2]], [target_rgb[:,:,::2,::2]], [target_raw[:,:,::2,::2]] + + for i_patch in range(len(input_list)): + file_name_patch = file_name + "_%05d"%i_patch + idx = input_RGBs_names.index(file_name_patch) + input_RGB_path 
= input_RGBs[idx] + input_RGB = torch.from_numpy(np.array(PILImage.open(input_RGB_path))/255.0).unsqueeze(0).permute(0,3,1,2).float().to(device) + + target_raw_patch = target_raw_list[i_patch] + + with torch.no_grad(): + reconstruct_raw = net(input_RGB, rev=True) + + pred_raw = reconstruct_raw.detach().permute(0,2,3,1) + pred_raw = torch.clamp(pred_raw, 0, 1) + + target_raw_patch = target_raw_patch.permute(0,2,3,1) + pred_raw = denorm(pred_raw, 255) + target_raw_patch = denorm(target_raw_patch, 255) + + pred_raw = pred_raw.cpu().numpy() + target_raw_patch = target_raw_patch.cpu().numpy().astype(np.float32) + + raw_pred = PILImage.fromarray(np.uint8(pred_raw[0,:,:,0])) + raw_tar_pred = PILImage.fromarray(np.hstack((np.uint8(target_raw_patch[0,:,:,0]), np.uint8(pred_raw[0,:,:,0])))) + + raw_tar = PILImage.fromarray(np.uint8(target_raw_patch[0,:,:,0])) + + raw_pred.save(out_path+"raw_pred_%s_%05d.jpg"%(file_name, i_patch)) + raw_tar.save(out_path+"raw_tar_%s_%05d.jpg"%(file_name, i_patch)) + raw_tar_pred.save(out_path+"raw_gt_pred_%s_%05d.jpg"%(file_name, i_patch)) + + np.save(out_path+"raw_pred_%s_%05d.npy"%(file_name, i_patch), pred_raw[0,:,:,:]/255.0) + np.save(out_path+"raw_tar_%s_%05d.npy"%(file_name, i_patch), target_raw_patch[0,:,:,:]/255.0) + + del reconstruct_raw + + +if __name__ == '__main__': + + torch.set_num_threads(4) + main(args) + diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/test_rgb.py b/imcui/third_party/DarkFeat/datasets/InvISP/test_rgb.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e054b899d9142609e3f90f4a12d367a45aeac0 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/test_rgb.py @@ -0,0 +1,105 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import torch +import numpy as np +import os, time, random +import argparse +from torch.utils.data import Dataset, DataLoader +from PIL import Image as PILImage + +from model.model import InvISPNet +from dataset.FiveK_dataset import FiveKDatasetTest +from config.config import get_arguments + +from utils.JPEG import DiffJPEG +from utils.commons import denorm, preprocess_test_patch +from tqdm import tqdm + +os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') +os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()])) +# os.environ['CUDA_VISIBLE_DEVICES'] = '7' +os.system('rm tmp') + +DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda() + +parser = get_arguments() +parser.add_argument("--ckpt", type=str, help="Checkpoint path.") +parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save results. ") +parser.add_argument("--split_to_patch", dest='split_to_patch', action='store_true', help="Test on patch. 
") +args = parser.parse_args() +print("Parsed arguments: {}".format(args)) + + +ckpt_name = args.ckpt.split("/")[-1].split(".")[0] +if args.split_to_patch: + os.makedirs(args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name), exist_ok=True) + out_path = args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name) +else: + os.makedirs(args.out_path+"%s/results_%s/"%(args.task, ckpt_name), exist_ok=True) + out_path = args.out_path+"%s/results_%s/"%(args.task, ckpt_name) + + +def main(args): + # ======================================define the model============================================ + net = InvISPNet(channel_in=3, channel_out=3, block_num=8) + device = torch.device("cuda:0") + + net.to(device) + net.eval() + # load the pretrained weight if there exists one + if os.path.isfile(args.ckpt): + net.load_state_dict(torch.load(args.ckpt), strict=False) + print("[INFO] Loaded checkpoint: {}".format(args.ckpt)) + + print("[INFO] Start data load and preprocessing") + RAWDataset = FiveKDatasetTest(opt=args) + dataloader = DataLoader(RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True) + + print("[INFO] Start test...") + for i_batch, sample_batched in enumerate(tqdm(dataloader)): + step_time = time.time() + + input, target_rgb, target_raw = sample_batched['input_raw'].to(device), sample_batched['target_rgb'].to(device), \ + sample_batched['target_raw'].to(device) + file_name = sample_batched['file_name'][0] + + if args.split_to_patch: + input_list, target_rgb_list, target_raw_list = preprocess_test_patch(input, target_rgb, target_raw) + else: + # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution + input_list, target_rgb_list, target_raw_list = [input[:,:,::2,::2]], [target_rgb[:,:,::2,::2]], [target_raw[:,:,::2,::2]] + + for i_patch in range(len(input_list)): + input_patch = input_list[i_patch] + target_rgb_patch = target_rgb_list[i_patch] + target_raw_patch = target_raw_list[i_patch] + + with torch.no_grad(): + reconstruct_rgb = net(input_patch) + reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1) + + pred_rgb = reconstruct_rgb.detach().permute(0,2,3,1) + target_rgb_patch = target_rgb_patch.permute(0,2,3,1) + + pred_rgb = denorm(pred_rgb, 255) + target_rgb_patch = denorm(target_rgb_patch, 255) + pred_rgb = pred_rgb.cpu().numpy() + target_rgb_patch = target_rgb_patch.cpu().numpy().astype(np.float32) + + # print(type(pred_rgb)) + pred = PILImage.fromarray(np.uint8(pred_rgb[0,:,:,:])) + tar_pred = PILImage.fromarray(np.hstack((np.uint8(target_rgb_patch[0,:,:,:]), np.uint8(pred_rgb[0,:,:,:])))) + + tar = PILImage.fromarray(np.uint8(target_rgb_patch[0,:,:,:])) + + pred.save(out_path+"pred_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1) + tar.save(out_path+"tar_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1) + tar_pred.save(out_path+"gt_pred_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1) + + del reconstruct_rgb + +if __name__ == '__main__': + torch.set_num_threads(4) + main(args) + diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/train.py b/imcui/third_party/DarkFeat/datasets/InvISP/train.py new file mode 100644 index 0000000000000000000000000000000000000000..16186cb38d825ac1299e5c4164799d35bfa79907 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/train.py @@ -0,0 +1,98 @@ +import numpy as np +import os, time, random +import argparse +import json + +import torch.nn.functional as F +import torch +from torch.utils.data import Dataset, DataLoader +from torch.optim import lr_scheduler 
+ +from model.model import InvISPNet +from dataset.FiveK_dataset import FiveKDatasetTrain +from config.config import get_arguments + +from utils.JPEG import DiffJPEG + +os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') +os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()])) +# os.environ['CUDA_VISIBLE_DEVICES'] = "1" +os.system('rm tmp') + +DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda() + +parser = get_arguments() +parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save checkpoint. ") +parser.add_argument("--resume", dest='resume', action='store_true', help="Resume training. ") +parser.add_argument("--loss", type=str, default="L1", choices=["L1", "L2"], help="Choose which loss function to use. ") +parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate") +parser.add_argument("--aug", dest='aug', action='store_true', help="Use data augmentation.") +args = parser.parse_args() +print("Parsed arguments: {}".format(args)) + +os.makedirs(args.out_path, exist_ok=True) +os.makedirs(args.out_path+"%s"%args.task, exist_ok=True) +os.makedirs(args.out_path+"%s/checkpoint"%args.task, exist_ok=True) + +with open(args.out_path+"%s/commandline_args.yaml"%args.task , 'w') as f: + json.dump(args.__dict__, f, indent=2) + +def main(args): + # ======================================define the model====================================== + net = InvISPNet(channel_in=3, channel_out=3, block_num=8) + net.cuda() + # load the pretrained weight if there exists one + if args.resume: + net.load_state_dict(torch.load(args.out_path+"%s/checkpoint/latest.pth"%args.task)) + print("[INFO] loaded " + args.out_path+"%s/checkpoint/latest.pth"%args.task) + + optimizer = torch.optim.Adam(net.parameters(), lr=args.lr) + scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.5) + + print("[INFO] Start data loading and preprocessing") + RAWDataset = FiveKDatasetTrain(opt=args) + dataloader = DataLoader(RAWDataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True) + + print("[INFO] Start to train") + step = 0 + for epoch in range(0, 300): + epoch_time = time.time() + + for i_batch, sample_batched in enumerate(dataloader): + step_time = time.time() + + input, target_rgb, target_raw = sample_batched['input_raw'].cuda(), sample_batched['target_rgb'].cuda(), \ + sample_batched['target_raw'].cuda() + + reconstruct_rgb = net(input) + reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1) + rgb_loss = F.l1_loss(reconstruct_rgb, target_rgb) + reconstruct_rgb = DiffJPEG(reconstruct_rgb) + reconstruct_raw = net(reconstruct_rgb, rev=True) + raw_loss = F.l1_loss(reconstruct_raw, target_raw) + + loss = args.rgb_weight * rgb_loss + raw_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + print("task: %s Epoch: %d Step: %d || loss: %.5f raw_loss: %.5f rgb_loss: %.5f || lr: %f time: %f"%( + args.task, epoch, step, loss.detach().cpu().numpy(), raw_loss.detach().cpu().numpy(), + rgb_loss.detach().cpu().numpy(), optimizer.param_groups[0]['lr'], time.time()-step_time + )) + step += 1 + + torch.save(net.state_dict(), args.out_path+"%s/checkpoint/latest.pth"%args.task) + if (epoch+1) % 10 == 0: + # os.makedirs(args.out_path+"%s/checkpoint/%04d"%(args.task,epoch), exist_ok=True) + torch.save(net.state_dict(), args.out_path+"%s/checkpoint/%04d.pth"%(args.task,epoch)) + print("[INFO] Successfully saved "+args.out_path+"%s/checkpoint/%04d.pth"%(args.task,epoch)) + 
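+        # MultiStepLR halves the learning rate after epochs 50 and 80; stepped once per epoch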
scheduler.step() + + print("[INFO] Epoch time: ", time.time()-epoch_time, "task: ", args.task) + +if __name__ == '__main__': + + torch.set_num_threads(4) + main(args) diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py new file mode 100644 index 0000000000000000000000000000000000000000..8997ee98a41668b4737a9b2acc2341032f173bd3 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py @@ -0,0 +1,43 @@ + + +import torch +import torch.nn as nn + +from .JPEG_utils import diff_round, quality_to_factor, Quantization +from .compression import compress_jpeg +from .decompression import decompress_jpeg + + +class DiffJPEG(nn.Module): + def __init__(self, differentiable=True, quality=75): + ''' Initialize the DiffJPEG layer + Inputs: + height(int): Original image height + width(int): Original image width + differentiable(bool): If true uses custom differentiable + rounding function, if false uses standrard torch.round + quality(float): Quality factor for jpeg compression scheme. + ''' + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + # rounding = Quantization() + else: + rounding = torch.round + factor = quality_to_factor(quality) + self.compress = compress_jpeg(rounding=rounding, factor=factor) + # self.decompress = decompress_jpeg(height, width, rounding=rounding, + # factor=factor) + self.decompress = decompress_jpeg(rounding=rounding, factor=factor) + + def forward(self, x): + ''' + ''' + org_height = x.shape[2] + org_width = x.shape[3] + y, cb, cr = self.compress(x) + + recovered = self.decompress(y, cb, cr, org_height, org_width) + return recovered + + diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ebd9bdc184e869ade58eea1c6763baa1d9fc91 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py @@ -0,0 +1,75 @@ +# Standard libraries +import numpy as np +# PyTorch +import torch +import torch.nn as nn +import math + +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, + 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, + 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T + +y_table = nn.Parameter(torch.from_numpy(y_table)) +# +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], + [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round_back(x): + """ Differentiable rounding function + Input: + x(tensor) + Output: + x(tensor) + """ + return torch.round(x) + (x - torch.round(x))**3 + + + +def diff_round(input_tensor): + test = 0 + for n in range(1, 10): + test += math.pow(-1, n+1) / n * torch.sin(2 * math.pi * n * input_tensor) + final_tensor = input_tensor - 1 / math.pi * test + return final_tensor + + +class Quant(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + input = torch.clamp(input, 0, 1) + output = (input * 255.).round() / 255. 
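+        # gradients are passed through the rounding unchanged in backward() below (straight-through estimator)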
+ return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +class Quantization(nn.Module): + def __init__(self): + super(Quantization, self).__init__() + + def forward(self, input): + return Quant.apply(input) + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + Input: + quality(float): Quality for jpeg compression + Output: + factor(float): Compression factor + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality*2 + return quality / 100. \ No newline at end of file diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/__init__.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/commons.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..e594e0597bac601edc2015d9cae670799f981495 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/commons.py @@ -0,0 +1,23 @@ +import numpy as np + + +def denorm(img, max_value): + img = img * float(max_value) + return img + +def preprocess_test_patch(input_image, target_image, gt_image): + input_patch_list = [] + target_patch_list = [] + gt_patch_list = [] + H = input_image.shape[2] + W = input_image.shape[3] + for i in range(3): + for j in range(3): + input_patch = input_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)] + target_patch = target_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)] + gt_patch = gt_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)] + input_patch_list.append(input_patch) + target_patch_list.append(target_patch) + gt_patch_list.append(gt_patch) + + return input_patch_list, target_patch_list, gt_patch_list diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/compression.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/compression.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae22f8839517bfd7e3c774528943e8fff59dce7 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/compression.py @@ -0,0 +1,185 @@ +# Standard libraries +import itertools +import numpy as np +# PyTorch +import torch +import torch.nn as nn +# Local +from . 
import JPEG_utils + + +class rgb_to_ycbcr_jpeg(nn.Module): + """ Converts RGB image to YCbCr + Input: + image(tensor): batch x 3 x height x width + Outpput: + result(tensor): batch x height x width x 3 + """ + def __init__(self): + super(rgb_to_ycbcr_jpeg, self).__init__() + matrix = np.array( + [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], + [0.5, -0.418688, -0.081312]], dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + # + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + image = image.permute(0, 2, 3, 1) + result = torch.tensordot(image, self.matrix, dims=1) + self.shift + # result = torch.from_numpy(result) + result.view(image.shape) + return result + + + +class chroma_subsampling(nn.Module): + """ Chroma subsampling on CbCv channels + Input: + image(tensor): batch x height x width x 3 + Output: + y(tensor): batch x height x width + cb(tensor): batch x height/2 x width/2 + cr(tensor): batch x height/2 x width/2 + """ + def __init__(self): + super(chroma_subsampling, self).__init__() + + def forward(self, image): + image_2 = image.permute(0, 3, 1, 2).clone() + avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2), + count_include_pad=False) + cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1)) + cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1)) + cb = cb.permute(0, 2, 3, 1) + cr = cr.permute(0, 2, 3, 1) + return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) + + +class block_splitting(nn.Module): + """ Splitting image into patches + Input: + image(tensor): batch x height x width + Output: + patch(tensor): batch x h*w/64 x h x w + """ + def __init__(self): + super(block_splitting, self).__init__() + self.k = 8 + + def forward(self, image): + height, width = image.shape[1:3] + # print(height, width) + batch_size = image.shape[0] + # print(image.shape) + image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) + + +class dct_8x8(nn.Module): + """ Discrete Cosine Transformation + Input: + image(tensor): batch x height x width + Output: + dcp(tensor): batch x height x width + """ + def __init__(self): + super(dct_8x8, self).__init__() + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos( + (2 * y + 1) * v * np.pi / 16) + alpha = np.array([1. 
/ np.sqrt(2)] + [1] * 7) + # + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() ) + + def forward(self, image): + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class y_quantize(nn.Module): + """ JPEG Quantization for Y channel + Input: + image(tensor): batch x height x width + rounding(function): rounding function to use + factor(float): Degree of compression + Output: + image(tensor): batch x height x width + """ + def __init__(self, rounding, factor=1): + super(y_quantize, self).__init__() + self.rounding = rounding + self.factor = factor + self.y_table = JPEG_utils.y_table + + def forward(self, image): + image = image.float() / (self.y_table * self.factor) + image = self.rounding(image) + return image + + +class c_quantize(nn.Module): + """ JPEG Quantization for CrCb channels + Input: + image(tensor): batch x height x width + rounding(function): rounding function to use + factor(float): Degree of compression + Output: + image(tensor): batch x height x width + """ + def __init__(self, rounding, factor=1): + super(c_quantize, self).__init__() + self.rounding = rounding + self.factor = factor + self.c_table = JPEG_utils.c_table + + def forward(self, image): + image = image.float() / (self.c_table * self.factor) + image = self.rounding(image) + return image + + +class compress_jpeg(nn.Module): + """ Full JPEG compression algortihm + Input: + imgs(tensor): batch x 3 x height x width + rounding(function): rounding function to use + factor(float): Compression factor + Ouput: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + """ + def __init__(self, rounding=torch.round, factor=1): + super(compress_jpeg, self).__init__() + self.l1 = nn.Sequential( + rgb_to_ycbcr_jpeg(), + # comment this line if no subsampling + chroma_subsampling() + ) + self.l2 = nn.Sequential( + block_splitting(), + dct_8x8() + ) + self.c_quantize = c_quantize(rounding=rounding, factor=factor) + self.y_quantize = y_quantize(rounding=rounding, factor=factor) + + def forward(self, image): + y, cb, cr = self.l1(image*255) # modify + + # y, cb, cr = result[:,:,:,0], result[:,:,:,1], result[:,:,:,2] + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + # print(comp.shape) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp) + else: + comp = self.y_quantize(comp) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] \ No newline at end of file diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/decompression.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/decompression.py new file mode 100644 index 0000000000000000000000000000000000000000..b73ff96d5f6818e1d0464b9c4133f559a3b23fba --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/decompression.py @@ -0,0 +1,190 @@ +# Standard libraries +import itertools +import numpy as np +# PyTorch +import torch +import torch.nn as nn +# Local +from . 
import JPEG_utils as utils + + +class y_dequantize(nn.Module): + """ Dequantize Y channel + Inputs: + image(tensor): batch x height x width + factor(float): compression factor + Outputs: + image(tensor): batch x height x width + """ + def __init__(self, factor=1): + super(y_dequantize, self).__init__() + self.y_table = utils.y_table + self.factor = factor + + def forward(self, image): + return image * (self.y_table * self.factor) + + +class c_dequantize(nn.Module): + """ Dequantize CbCr channel + Inputs: + image(tensor): batch x height x width + factor(float): compression factor + Outputs: + image(tensor): batch x height x width + """ + def __init__(self, factor=1): + super(c_dequantize, self).__init__() + self.factor = factor + self.c_table = utils.c_table + + def forward(self, image): + return image * (self.c_table * self.factor) + + +class idct_8x8(nn.Module): + """ Inverse discrete Cosine Transformation + Input: + dcp(tensor): batch x height x width + Output: + image(tensor): batch x height x width + """ + def __init__(self): + super(idct_8x8, self).__init__() + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos( + (2 * v + 1) * y * np.pi / 16) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + + def forward(self, image): + + image = image * self.alpha + result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 + result.view(image.shape) + return result + + +class block_merging(nn.Module): + """ Merge pathces into image + Inputs: + patches(tensor) batch x height*width/64, height x width + height(int) + width(int) + Output: + image(tensor): batch x height x width + """ + def __init__(self): + super(block_merging, self).__init__() + + def forward(self, patches, height, width): + k = 8 + batch_size = patches.shape[0] + # print(patches.shape) # (1,1024,8,8) + image_reshaped = patches.view(batch_size, height//k, width//k, k, k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, height, width) + + +class chroma_upsampling(nn.Module): + """ Upsample chroma layers + Input: + y(tensor): y channel image + cb(tensor): cb channel + cr(tensor): cr channel + Ouput: + image(tensor): batch x height x width x 3 + """ + def __init__(self): + super(chroma_upsampling, self).__init__() + + def forward(self, y, cb, cr): + def repeat(x, k=2): + height, width = x.shape[1:3] + x = x.unsqueeze(-1) + x = x.repeat(1, 1, k, k) + x = x.view(-1, height * k, width * k) + return x + + cb = repeat(cb) + cr = repeat(cr) + + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) + + +class ycbcr_to_rgb_jpeg(nn.Module): + """ Converts YCbCr image to RGB JPEG + Input: + image(tensor): batch x height x width x 3 + Outpput: + result(tensor): batch x 3 x height x width + """ + def __init__(self): + super(ycbcr_to_rgb_jpeg, self).__init__() + + matrix = np.array( + [[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], + dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + result = torch.tensordot(image + self.shift, self.matrix, dims=1) + #result = torch.from_numpy(result) + result.view(image.shape) + return result.permute(0, 3, 1, 2) + + +class 
decompress_jpeg(nn.Module): + """ Full JPEG decompression algortihm + Input: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + rounding(function): rounding function to use + factor(float): Compression factor + Ouput: + image(tensor): batch x 3 x height x width + """ + # def __init__(self, height, width, rounding=torch.round, factor=1): + def __init__(self, rounding=torch.round, factor=1): + super(decompress_jpeg, self).__init__() + self.c_dequantize = c_dequantize(factor=factor) + self.y_dequantize = y_dequantize(factor=factor) + self.idct = idct_8x8() + self.merging = block_merging() + # comment this line if no subsampling + self.chroma = chroma_upsampling() + self.colors = ycbcr_to_rgb_jpeg() + + # self.height, self.width = height, width + + def forward(self, y, cb, cr, height, width): + components = {'y': y, 'cb': cb, 'cr': cr} + # height = y.shape[0] + # width = y.shape[1] + self.height = height + self.width = width + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k]) + # comment this line if no subsampling + height, width = int(self.height/2), int(self.width/2) + # height, width = int(self.height), int(self.width) + + else: + comp = self.y_dequantize(components[k]) + # comment this line if no subsampling + height, width = self.height, self.width + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + # comment this line if no subsampling + image = self.chroma(components['y'], components['cb'], components['cr']) + # image = torch.cat([components['y'].unsqueeze(3), components['cb'].unsqueeze(3), components['cr'].unsqueeze(3)], dim=3) + image = self.colors(image) + + image = torch.min(255*torch.ones_like(image), + torch.max(torch.zeros_like(image), image)) + return image/255 + diff --git a/imcui/third_party/DarkFeat/datasets/__init__.py b/imcui/third_party/DarkFeat/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DarkFeat/datasets/gl3d/io.py b/imcui/third_party/DarkFeat/datasets/gl3d/io.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5b4b0459d6814ef6af17a0a322b59202037d4f --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/gl3d/io.py @@ -0,0 +1,76 @@ +import os +import re +import cv2 +import numpy as np + +from ..utils.common import Notify + +def read_list(list_path): + """Read list.""" + if list_path is None or not os.path.exists(list_path): + print(Notify.FAIL, 'Not exist', list_path, Notify.ENDC) + exit(-1) + content = open(list_path).read().splitlines() + return content + + +def load_pfm(pfm_path): + with open(pfm_path, 'rb') as fin: + color = None + width = None + height = None + scale = None + data_type = None + header = str(fin.readline().decode('UTF-8')).rstrip() + + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', + fin.readline().decode('UTF-8')) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + scale = float((fin.readline().decode('UTF-8')).rstrip()) + if scale < 0: # little-endian + data_type = ' 0: + img = cv2.resize( + img, (config['resize'], config['resize'])) + return img + + +def _parse_depth(depth_paths, idx, config): + depth = load_pfm(depth_paths[idx]) + + if config['resize'] > 0: + target_size = config['resize'] + if config['input_type'] == 'raw': + depth = cv2.resize(depth, 
(int(target_size/2), int(target_size/2))) + else: + depth = cv2.resize(depth, (target_size, target_size)) + return depth + + +def _parse_kpts(kpts_paths, idx, config): + kpts = np.load(kpts_paths[idx])['pts'] + # output: [N, 2] (W first H last) + return kpts diff --git a/imcui/third_party/DarkFeat/datasets/gl3d_dataset.py b/imcui/third_party/DarkFeat/datasets/gl3d_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..db3d2db646ae7fce81424f5f72cdff7e6e34ba60 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/gl3d_dataset.py @@ -0,0 +1,127 @@ +import os +import numpy as np +import torch +from torch.utils.data import Dataset +from random import shuffle, seed + +from .gl3d.io import read_list, _parse_img, _parse_depth, _parse_kpts +from .utils.common import Notify +from .utils.photaug import photaug + + +class GL3DDataset(Dataset): + def __init__(self, dataset_dir, config, data_split, is_training): + self.dataset_dir = dataset_dir + self.config = config + self.is_training = is_training + self.data_split = data_split + + self.match_set_list, self.global_img_list, \ + self.global_depth_list = self.prepare_match_sets() + + pass + + + def __len__(self): + return len(self.match_set_list) + + + def __getitem__(self, idx): + match_set_path = self.match_set_list[idx] + decoded = np.fromfile(match_set_path, dtype=np.float32) + + idx0, idx1 = int(decoded[0]), int(decoded[1]) + inlier_num = int(decoded[2]) + ori_img_size0 = np.reshape(decoded[3:5], (2,)) + ori_img_size1 = np.reshape(decoded[5:7], (2,)) + K0 = np.reshape(decoded[7:16], (3, 3)) + K1 = np.reshape(decoded[16:25], (3, 3)) + rel_pose = np.reshape(decoded[34:46], (3, 4)) + + # parse images. + img0 = _parse_img(self.global_img_list, idx0, self.config) + img1 = _parse_img(self.global_img_list, idx1, self.config) + # parse depths + depth0 = _parse_depth(self.global_depth_list, idx0, self.config) + depth1 = _parse_depth(self.global_depth_list, idx1, self.config) + + # photometric augmentation + img0 = photaug(img0) + img1 = photaug(img1) + + return { + 'img0': img0 / 255., + 'img1': img1 / 255., + 'depth0': depth0, + 'depth1': depth1, + 'ori_img_size0': ori_img_size0, + 'ori_img_size1': ori_img_size1, + 'K0': K0, + 'K1': K1, + 'rel_pose': rel_pose, + 'inlier_num': inlier_num + } + + + def points_to_2D(self, pnts, H, W): + labels = np.zeros((H, W)) + pnts = pnts.astype(int) + labels[pnts[:, 1], pnts[:, 0]] = 1 + return labels + + + def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60): + """Get match sets. + Args: + is_training: Use training imageset or testing imageset. + data_split: Data split name. + Returns: + match_set_list: List of match sets path. + global_img_list: List of global image path. + global_context_feat_list: + """ + # get necessary lists. 
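+        # list/<data_split>/ provides image_index_offset.txt, image_list.txt and depth_list.txt; entries are turned into absolute paths below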
+ gl3d_list_folder = os.path.join(self.dataset_dir, 'list', self.data_split) + global_info = read_list(os.path.join( + gl3d_list_folder, 'image_index_offset.txt')) + global_img_list = [os.path.join(self.dataset_dir, i) for i in read_list( + os.path.join(gl3d_list_folder, 'image_list.txt'))] + global_depth_list = [os.path.join(self.dataset_dir, i) for i in read_list( + os.path.join(gl3d_list_folder, 'depth_list.txt'))] + + imageset_list_name = 'imageset_train.txt' if self.is_training else 'imageset_test.txt' + match_set_list = self.get_match_set_list(os.path.join( + gl3d_list_folder, imageset_list_name), q_diff_thld, rot_diff_thld) + return match_set_list, global_img_list, global_depth_list + + + def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld): + """Get the path list of match sets. + Args: + imageset_list_path: Path to imageset list. + q_diff_thld: Threshold of image pair sampling regarding camera orientation. + Returns: + match_set_list: List of match set path. + """ + imageset_list = [os.path.join(self.dataset_dir, 'data', i) + for i in read_list(imageset_list_path)] + print(Notify.INFO, 'Use # imageset', len(imageset_list), Notify.ENDC) + match_set_list = [] + # discard image pairs whose image simiarity is beyond the threshold. + for i in imageset_list: + match_set_folder = os.path.join(i, 'match_sets') + if os.path.exists(match_set_folder): + match_set_files = os.listdir(match_set_folder) + for val in match_set_files: + name, ext = os.path.splitext(val) + if ext == '.match_set': + splits = name.split('_') + q_diff = int(splits[2]) + rot_diff = int(splits[3]) + if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld: + match_set_list.append( + os.path.join(match_set_folder, val)) + + print(Notify.INFO, 'Get # match sets', len(match_set_list), Notify.ENDC) + return match_set_list + diff --git a/imcui/third_party/DarkFeat/datasets/noise.py b/imcui/third_party/DarkFeat/datasets/noise.py new file mode 100644 index 0000000000000000000000000000000000000000..aa68c98183186e9e9185e78e1a3e7335ac8d5bb1 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/noise.py @@ -0,0 +1,82 @@ +import numpy as np +import random +from scipy.stats import tukeylambda + +camera_params = { + 'Kmin': 0.2181895124454343, + 'Kmax': 3.0, + 'G_shape': np.array([0.15714286, 0.14285714, 0.08571429, 0.08571429, 0.2 , + 0.2 , 0.1 , 0.08571429, 0.05714286, 0.07142857, + 0.02857143, 0.02857143, 0.01428571, 0.02857143, 0.08571429, + 0.07142857, 0.11428571, 0.11428571]), + 'Profile-1': { + 'R_scale': { + 'slope': 0.4712797750747537, + 'bias': -0.8078958947116487, + 'sigma': 0.2436176299944695 + }, + 'g_scale': { + 'slope': 0.6771267783987617, + 'bias': 1.5121876510805845, + 'sigma': 0.24641096601611254 + }, + 'G_scale': { + 'slope': 0.6558756156508007, + 'bias': 1.09268679594838, + 'sigma': 0.28604721742277756 + } + }, + 'black_level': 2048, + 'max_value': 16383 +} + + +# photon shot noise +def addPStarNoise(img, K): + return np.random.poisson(img / K).astype(np.float32) * K + + +# read noise +# tukey lambda distribution +def addGStarNoise(img, K, G_shape, G_scale_param): + # sample a shape parameter [lambda] from histogram of samples + a, b = np.histogram(G_shape, bins=10, range=(-0.25, 0.25)) + a, b = np.array(a), np.array(b) + a = a / a.sum() + + rand_num = random.uniform(0, 1) + idx = np.sum(np.cumsum(a) < rand_num) + lam = random.uniform(b[idx], b[idx+1]) + + # calculate scale parameter [G_scale] + log_K = np.log(K) + log_G_scale = np.random.standard_normal() * G_scale_param['sigma'] * 1 +\ + 
G_scale_param['slope'] * log_K + G_scale_param['bias'] + G_scale = np.exp(log_G_scale) + # print(f'G_scale: {G_scale}') + + return img + tukeylambda.rvs(lam, scale=G_scale, size=img.shape).astype(np.float32) + + +# row noise +# uniform distribution for each row +def addRowNoise(img, K, R_scale_param): + # calculate scale parameter [R_scale] + log_K = np.log(K) + log_R_scale = np.random.standard_normal() * R_scale_param['sigma'] * 1 +\ + R_scale_param['slope'] * log_K + R_scale_param['bias'] + R_scale = np.exp(log_R_scale) + # print(f'R_scale: {R_scale}') + + row_noise = np.random.randn(img.shape[0], 1).astype(np.float32) * R_scale + return img + np.tile(row_noise, (1, img.shape[1])) + + +# quantization noise +# uniform distribution +def addQuantNoise(img, q): + return img + np.random.uniform(low=-0.5*q, high=0.5*q, size=img.shape) + + +def sampleK(Kmin, Kmax): + return np.exp(np.random.uniform(low=np.log(Kmin), high=np.log(Kmax))) diff --git a/imcui/third_party/DarkFeat/datasets/noise_simulator.py b/imcui/third_party/DarkFeat/datasets/noise_simulator.py new file mode 100644 index 0000000000000000000000000000000000000000..17e21d3b3443aaa3585ae8460709f60b05835a84 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/noise_simulator.py @@ -0,0 +1,244 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import torch +import numpy as np +import os, time, random +import argparse +from torch.utils.data import Dataset, DataLoader +from PIL import Image as PILImage +from glob import glob +from tqdm import tqdm +import rawpy +import colour_demosaicing + +from .InvISP.model.model import InvISPNet +from .utils.common import Notify +from datasets.noise import camera_params, addGStarNoise, addPStarNoise, addQuantNoise, addRowNoise, sampleK + + +class NoiseSimulator: + def __init__(self, device, ckpt_path='./datasets/InvISP/pretrained/canon.pth'): + self.device = device + + # load Invertible ISP Network + self.net = InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval() + self.net.load_state_dict(torch.load(ckpt_path), strict=False) + print(Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC) + + # white balance parameters + self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0]) + + # use Canon EOS 5D4 noise parameters provided by ELD + self.camera_params = camera_params + + # random specify exposure time ratio from 50 to 150 + self.ratio_min = 50 + self.ratio_max = 150 + pass + + # inverse demosaic + # input: [H, W, 3] + # output: [H, W] + def invDemosaic(self, img): + img_R = img[::2, ::2, 0] + img_G1 = img[::2, 1::2, 1] + img_G2 = img[1::2, ::2, 1] + img_B = img[1::2, 1::2, 2] + raw_img = np.ones(img.shape[:2]) + raw_img[::2, ::2] = img_R + raw_img[::2, 1::2] = img_G1 + raw_img[1::2, ::2] = img_G2 + raw_img[1::2, 1::2] = img_B + return raw_img + + # demosaic - nearest ver + # input: [H, W] + # output: [H, W, 3] + def demosaicNearest(self, img): + raw = np.ones((img.shape[0], img.shape[1], 3)) + raw[::2, ::2, 0] = img[::2, ::2] + raw[::2, 1::2, 0] = img[::2, ::2] + raw[1::2, ::2, 0] = img[::2, ::2] + raw[1::2, 1::2, 0] = img[::2, ::2] + raw[::2, ::2, 2] = img[1::2, 1::2] + raw[::2, 1::2, 2] = img[1::2, 1::2] + raw[1::2, ::2, 2] = img[1::2, 1::2] + raw[1::2, 1::2, 2] = img[1::2, 1::2] + raw[::2, ::2, 1] = img[::2, 1::2] + raw[::2, 1::2, 1] = img[::2, 1::2] + raw[1::2, ::2, 1] = img[1::2, ::2] + raw[1::2, 1::2, 1] = img[1::2, ::2] + return raw + + # demosaic + # input: [H, W] + # output: [H, W, 3] + def demosaic(self, 
img): + return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, 'RGGB') + + # load rgb image + def path2rgb(self, path): + return torch.from_numpy(np.array(PILImage.open(path))/255.0) + + # InvISP + # input: rgb image [H, W, 3] + # output: raw image [H, W] + def rgb2raw(self, rgb, batched=False): + # 1. rgb -> invnet + if not batched: + rgb = rgb.unsqueeze(0) + + rgb = rgb.permute(0,3,1,2).float().to(self.device) + with torch.no_grad(): + reconstruct_raw = self.net(rgb, rev=True) + + pred_raw = reconstruct_raw.detach().permute(0,2,3,1) + pred_raw = torch.clamp(pred_raw, 0, 1) + + if not batched: + pred_raw = pred_raw[0, ...] + + pred_raw = pred_raw.cpu().numpy() + + # 2. -> inv gamma + norm_value = np.power(16383, 1/2.2) + pred_raw *= norm_value + pred_raw = np.power(pred_raw, 2.2) + + # 3. -> inv white balance + wb = self.wb / self.wb.max() + pred_raw = pred_raw / wb[:-1] + + # 4. -> add black level + pred_raw += self.camera_params['black_level'] + + # 5. -> inv demosaic + if not batched: + pred_raw = self.invDemosaic(pred_raw) + else: + preds = [] + for i in range(pred_raw.shape[0]): + preds.append(self.invDemosaic(pred_raw[i])) + pred_raw = np.stack(preds, axis=0) + + return pred_raw + + + def raw2noisyRaw(self, raw, ratio_dec=1, batched=False): + if not batched: + ratio = (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1 + raw = raw.copy() / ratio + + K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax']) + q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level']) + + raw = addPStarNoise(raw, K) + raw = addGStarNoise(raw, K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale']) + raw = addRowNoise(raw, K, self.camera_params['Profile-1']['R_scale']) + raw = addQuantNoise(raw, q) + raw *= ratio + return raw + + else: + raw = raw.copy() + for i in range(raw.shape[0]): + ratio = random.uniform(self.ratio_min, self.ratio_max) + raw[i] /= ratio + + K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax']) + q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level']) + + raw[i] = addPStarNoise(raw[i], K) + raw[i] = addGStarNoise(raw[i], K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale']) + raw[i] = addRowNoise(raw[i], K, self.camera_params['Profile-1']['R_scale']) + raw[i] = addQuantNoise(raw[i], q) + raw[i] *= ratio + return raw + + def raw2rgb(self, raw, batched=False): + # 1. -> demosaic + if not batched: + raw = self.demosaic(raw) + else: + raws = [] + for i in range(raw.shape[0]): + raws.append(self.demosaic(raw[i])) + raw = np.stack(raws, axis=0) + + # 2. -> substract black level + raw -= self.camera_params['black_level'] + raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level']) + + # 3. -> white balance + wb = self.wb / self.wb.max() + raw = raw * wb[:-1] + + # 4. -> gamma + norm_value = np.power(16383, 1/2.2) + raw = np.power(raw, 1/2.2) + raw /= norm_value + + # 5. -> ispnet + if not batched: + input_raw_img = torch.Tensor(raw).permute(2,0,1).float().to(self.device)[np.newaxis, ...] + else: + input_raw_img = torch.Tensor(raw).permute(0,3,1,2).float().to(self.device) + + with torch.no_grad(): + reconstruct_rgb = self.net(input_raw_img) + reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1) + + pred_rgb = reconstruct_rgb.detach().permute(0,2,3,1) + + if not batched: + pred_rgb = pred_rgb[0, ...] + pred_rgb = pred_rgb.cpu().numpy() + + return pred_rgb + + + def raw2packedRaw(self, raw, batched=False): + # 1. 
-> substract black level + raw -= self.camera_params['black_level'] + raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level']) + raw /= self.camera_params['max_value'] + + # 2. pack + if not batched: + im = np.expand_dims(raw, axis=2) + img_shape = im.shape + H = img_shape[0] + W = img_shape[1] + + out = np.concatenate((im[0:H:2, 0:W:2, :], + im[0:H:2, 1:W:2, :], + im[1:H:2, 1:W:2, :], + im[1:H:2, 0:W:2, :]), axis=2) + else: + im = np.expand_dims(raw, axis=3) + img_shape = im.shape + H = img_shape[1] + W = img_shape[2] + + out = np.concatenate((im[:, 0:H:2, 0:W:2, :], + im[:, 0:H:2, 1:W:2, :], + im[:, 1:H:2, 1:W:2, :], + im[:, 1:H:2, 0:W:2, :]), axis=3) + return out + + def raw2demosaicRaw(self, raw, batched=False): + # 1. -> demosaic + if not batched: + raw = self.demosaic(raw) + else: + raws = [] + for i in range(raw.shape[0]): + raws.append(self.demosaic(raw[i])) + raw = np.stack(raws, axis=0) + + # 2. -> substract black level + raw -= self.camera_params['black_level'] + raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level']) + raw /= self.camera_params['max_value'] + return raw diff --git a/imcui/third_party/DarkFeat/datasets/utils/common.py b/imcui/third_party/DarkFeat/datasets/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..6433408a39e53fcedb634901268754ed1ba971b3 --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/utils/common.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +""" +Copyright 2017, Zixin Luo, HKUST. +Commonly used functions +""" + +from __future__ import print_function + +import os +from datetime import datetime + + +class ClassProperty(property): + """For dynamically obtaining system time""" + + def __get__(self, cls, owner): + return classmethod(self.fget).__get__(None, owner)() + + +class Notify(object): + """Colorful printing prefix. 
+ A quick example: + print(Notify.INFO, YOUR TEXT, Notify.ENDC) + """ + + def __init__(self): + pass + + @ClassProperty + def HEADER(cls): + return str(datetime.now()) + ': \033[95m' + + @ClassProperty + def INFO(cls): + return str(datetime.now()) + ': \033[92mI' + + @ClassProperty + def OKBLUE(cls): + return str(datetime.now()) + ': \033[94m' + + @ClassProperty + def WARNING(cls): + return str(datetime.now()) + ': \033[93mW' + + @ClassProperty + def FAIL(cls): + return str(datetime.now()) + ': \033[91mF' + + @ClassProperty + def BOLD(cls): + return str(datetime.now()) + ': \033[1mB' + + @ClassProperty + def UNDERLINE(cls): + return str(datetime.now()) + ': \033[4mU' + ENDC = '\033[0m' + + diff --git a/imcui/third_party/DarkFeat/datasets/utils/photaug.py b/imcui/third_party/DarkFeat/datasets/utils/photaug.py new file mode 100644 index 0000000000000000000000000000000000000000..41f2278c720355470f00a881a1516cf1b71d2c4a --- /dev/null +++ b/imcui/third_party/DarkFeat/datasets/utils/photaug.py @@ -0,0 +1,50 @@ +import cv2 +import numpy as np +import random + + +def random_brightness_np(image, max_abs_change=50): + delta = random.uniform(-max_abs_change, max_abs_change) + return np.clip(image + delta, 0, 255) + +def random_contrast_np(image, strength_range=[0.3, 1.5]): + delta = random.uniform(*strength_range) + mean = image.mean() + return np.clip((image - mean) * delta + mean, 0, 255) + +def motion_blur_np(img, max_kernel_size=3): + # Either vertial, hozirontal or diagonal blur + mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up']) + ksize = np.random.randint( + 0, (max_kernel_size+1)/2)*2 + 1 # make sure is odd + center = int((ksize-1)/2) + kernel = np.zeros((ksize, ksize)) + if mode == 'h': + kernel[center, :] = 1. + elif mode == 'v': + kernel[:, center] = 1. + elif mode == 'diag_down': + kernel = np.eye(ksize) + elif mode == 'diag_up': + kernel = np.flip(np.eye(ksize), 0) + var = ksize * ksize / 16. 
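+    # modulate the line kernel with a Gaussian of variance ksize*ksize/16 centred on the kernel, then normalise it to sum to 1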
+ grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1) + gaussian = np.exp(-(np.square(grid-center) + + np.square(grid.T-center))/(2.*var)) + kernel *= gaussian + kernel /= np.sum(kernel) + img = cv2.filter2D(img, -1, kernel) + return np.clip(img, 0, 255) + +def additive_gaussian_noise(image, stddev_range=[5, 95]): + stddev = random.uniform(*stddev_range) + noise = np.random.normal(size=image.shape, scale=stddev) + noisy_image = np.clip(image + noise, 0, 255) + return noisy_image + +def photaug(img): + img = random_brightness_np(img) + img = random_contrast_np(img) + # img = additive_gaussian_noise(img) + img = motion_blur_np(img) + return img diff --git a/imcui/third_party/DarkFeat/demo_darkfeat.py b/imcui/third_party/DarkFeat/demo_darkfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..ca50ae5b892e7a90e75da7197c33bc0c06e699bf --- /dev/null +++ b/imcui/third_party/DarkFeat/demo_darkfeat.py @@ -0,0 +1,124 @@ +from pathlib import Path +import argparse +import cv2 +import matplotlib.cm as cm +import torch +import numpy as np +from utils.nnmatching import NNMatching +from utils.misc import (AverageTimer, VideoStreamer, make_matching_plot_fast, frame2tensor) + +torch.set_grad_enabled(False) + + +def compute_essential(matched_kp1, matched_kp2, K): + pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) + pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) + K_1 = np.eye(3) + # Estimate the homography between the matches using RANSAC + ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000) + if ransac_inliers is None or ransac_model.shape != (3,3): + ransac_inliers = np.array([]) + ransac_model = None + return ransac_model, ransac_inliers, pts1, pts2 + + +sizer = (960, 640) +focallength_x = 4.504986436499113e+03/(6744/sizer[0]) +focallength_y = 4.513311442889859e+03/(4502/sizer[1]) +K = np.eye(3) +K[0,0] = focallength_x +K[1,1] = focallength_y +K[0,2] = 3.363322177533149e+03/(6744/sizer[0])# * 0.5 +K[1,2] = 2.291824660547715e+03/(4502/sizer[1])# * 0.5 + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='DarkFeat demo', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--input', type=str, + help='path to an image directory') + parser.add_argument( + '--output_dir', type=str, default=None, + help='Directory where to write output frames (If None, no output)') + + parser.add_argument( + '--image_glob', type=str, nargs='+', default=['*.ARW'], + help='Glob if a directory of images is specified') + parser.add_argument( + '--resize', type=int, nargs='+', default=[640, 480], + help='Resize the input image before running inference. 
If two numbers, ' + 'resize to the exact dimensions, if one number, resize the max ' + 'dimension, if -1, do not resize') + parser.add_argument( + '--force_cpu', action='store_true', + help='Force pytorch to run in CPU mode.') + parser.add_argument('--model_path', type=str, + help='Path to the pretrained model') + + opt = parser.parse_args() + print(opt) + + assert len(opt.resize) == 2 + print('Will resize to {}x{} (WxH)'.format(opt.resize[0], opt.resize[1])) + + device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu' + print('Running inference on device \"{}\"'.format(device)) + matching = NNMatching(opt.model_path).eval().to(device) + keys = ['keypoints', 'scores', 'descriptors'] + + vs = VideoStreamer(opt.input, opt.resize, opt.image_glob) + frame, ret = vs.next_frame() + assert ret, 'Error when reading the first frame (try different --input?)' + + frame_tensor = frame2tensor(frame, device) + last_data = matching.darkfeat({'image': frame_tensor}) + last_data = {k+'0': [last_data[k]] for k in keys} + last_data['image0'] = frame_tensor + last_frame = frame + last_image_id = 0 + + if opt.output_dir is not None: + print('==> Will write outputs to {}'.format(opt.output_dir)) + Path(opt.output_dir).mkdir(exist_ok=True) + + timer = AverageTimer() + + while True: + frame, ret = vs.next_frame() + if not ret: + print('Finished demo_darkfeat.py') + break + timer.update('data') + stem0, stem1 = last_image_id, vs.i - 1 + + frame_tensor = frame2tensor(frame, device) + pred = matching({**last_data, 'image1': frame_tensor}) + kpts0 = last_data['keypoints0'][0].cpu().numpy() + kpts1 = pred['keypoints1'][0].cpu().numpy() + matches = pred['matches0'][0].cpu().numpy() + confidence = pred['matching_scores0'][0].cpu().numpy() + timer.update('forward') + + valid = matches > -1 + mkpts0 = kpts0[valid] + mkpts1 = kpts1[matches[valid]] + + E, inliers, pts1, pts2 = compute_essential(mkpts0, mkpts1, K) + color = cm.jet(np.clip(confidence[valid][inliers[:, 0].astype('bool')] * 2 - 1, -1, 1)) + + text = [ + 'DarkFeat', + 'Matches: {}'.format(inliers.sum()) + ] + + out = make_matching_plot_fast( + last_frame, frame, mkpts0[inliers[:, 0].astype('bool')], mkpts1[inliers[:, 0].astype('bool')], color, text, + path=None, small_text=' ') + + if opt.output_dir is not None: + stem = 'matches_{:06}_{:06}'.format(stem0, stem1) + out_file = str(Path(opt.output_dir, stem + '.png')) + print('Writing image to {}'.format(out_file)) + cv2.imwrite(out_file, out) diff --git a/imcui/third_party/DarkFeat/export_features.py b/imcui/third_party/DarkFeat/export_features.py new file mode 100644 index 0000000000000000000000000000000000000000..c7caea5e57890948728f84cbb7e68e59d455e171 --- /dev/null +++ b/imcui/third_party/DarkFeat/export_features.py @@ -0,0 +1,128 @@ +import argparse +import glob +import math +import subprocess +import numpy as np +import os +import tqdm +import torch +import torch.nn as nn +import cv2 +from darkfeat import DarkFeat +from utils import matching + +def darkfeat_pre(img, cuda): + H, W = img.shape[0], img.shape[1] + inp = img.copy() + inp = inp.transpose(2, 0, 1) + inp = torch.from_numpy(inp) + inp = torch.autograd.Variable(inp).view(1, 3, H, W) + if cuda: + inp = inp.cuda() + return inp + +if __name__ == '__main__': + # Parse command line arguments. 
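+    # export_features.py extracts and matches DarkFeat keypoints for every ISO/shutter-speed image pair in the MID dataset and stores the matched points as .npy files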
+ parser = argparse.ArgumentParser() + parser.add_argument('--H', type=int, default=int(640)) + parser.add_argument('--W', type=int, default=int(960)) + parser.add_argument('--histeq', action='store_true') + parser.add_argument('--model_path', type=str) + parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/') + opt = parser.parse_args() + + sizer = (opt.W, opt.H) + focallength_x = 4.504986436499113e+03/(6744/sizer[0]) + focallength_y = 4.513311442889859e+03/(4502/sizer[1]) + K = np.eye(3) + K[0,0] = focallength_x + K[1,1] = focallength_y + K[0,2] = 3.363322177533149e+03/(6744/sizer[0])# * 0.5 + K[1,2] = 2.291824660547715e+03/(4502/sizer[1])# * 0.5 + Kinv = np.linalg.inv(K) + Kinvt = np.transpose(Kinv) + + cuda = True + if cuda: + darkfeat = DarkFeat(opt.model_path).cuda().eval() + + for scene in ['Indoor', 'Outdoor']: + base_save = './result/' + scene + '/' + dir_base = opt.dataset_dir + '/' + scene + '/' + pair_list = sorted(os.listdir(dir_base)) + + for pair in tqdm.tqdm(pair_list): + opention = 1 + if scene == 'Outdoor': + pass + else: + if int(pair[4::]) <= 17: + opention = 0 + else: + pass + name=[] + files = sorted(os.listdir(dir_base+pair)) + for file_ in files: + if file_.endswith('.cr2'): + name.append(file_[0:9]) + ISO = ['00100', '00200', '00400', '00800', '01600', '03200', '06400', '12800'] + if opention == 1: + Shutter_speed = ['0.005','0.01','0.025','0.05','0.17','0.5'] + else: + Shutter_speed = ['0.01','0.02','0.05','0.1','0.3','1'] + + E_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'E_estimated.npy') + F_GT = np.dot(np.dot(Kinvt,E_GT),Kinv) + R_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'R_GT.npy') + t_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'T_GT.npy') + + id0, id1 = sorted([ int(i.split('/')[-1]) for i in glob.glob(f'{dir_base+pair}/?????') ]) + + cnt = 0 + + for iso in ISO: + for ex in Shutter_speed: + dark_name1 = name[0] + iso+'_'+ex+'_'+scene+'.npy' + dark_name2 = name[1] + iso+'_'+ex+'_'+scene+'.npy' + + if not opt.histeq: + dst_T1_None = f'{dir_base}{pair}/{id0:05d}-npy-nohisteq/{dark_name1}' + dst_T2_None = f'{dir_base}{pair}/{id1:05d}-npy-nohisteq/{dark_name2}' + + img1_orig_None = np.load(dst_T1_None) + img2_orig_None = np.load(dst_T2_None) + + dir_save = base_save + pair + '/None/' + + img_input1 = darkfeat_pre(img1_orig_None.astype('float32')/255.0, cuda) + img_input2 = darkfeat_pre(img2_orig_None.astype('float32')/255.0, cuda) + + else: + dst_T1_histeq = f'{dir_base}{pair}/{id0:05d}-npy/{dark_name1}' + dst_T2_histeq = f'{dir_base}{pair}/{id1:05d}-npy/{dark_name2}' + + img1_orig_histeq = np.load(dst_T1_histeq) + img2_orig_histeq = np.load(dst_T2_histeq) + + dir_save = base_save + pair + '/HistEQ/' + + img_input1 = darkfeat_pre(img1_orig_histeq.astype('float32')/255.0, cuda) + img_input2 = darkfeat_pre(img2_orig_histeq.astype('float32')/255.0, cuda) + + result1 = darkfeat({'image': img_input1}) + result2 = darkfeat({'image': img_input2}) + + mkpts0, mkpts1, _ = matching.match_descriptors( + cv2.KeyPoint_convert(result1['keypoints'].detach().cpu().float().numpy()), result1['descriptors'].detach().cpu().numpy(), + cv2.KeyPoint_convert(result2['keypoints'].detach().cpu().float().numpy()), result2['descriptors'].detach().cpu().numpy(), + ORB=False + ) + + POINT_1_dir = dir_save+f'DarkFeat/POINT_1/' + POINT_2_dir = dir_save+f'DarkFeat/POINT_2/' + + subprocess.check_output(['mkdir', '-p', POINT_1_dir]) + subprocess.check_output(['mkdir', '-p', POINT_2_dir]) + np.save(POINT_1_dir+dark_name1[0:-3]+'npy',mkpts0) + 
np.save(POINT_2_dir+dark_name2[0:-3]+'npy',mkpts1) + diff --git a/imcui/third_party/DarkFeat/nets/__init__.py b/imcui/third_party/DarkFeat/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DarkFeat/nets/geom.py b/imcui/third_party/DarkFeat/nets/geom.py new file mode 100644 index 0000000000000000000000000000000000000000..043ca6e8f5917c56defd6aa17c1ff236a431f8c0 --- /dev/null +++ b/imcui/third_party/DarkFeat/nets/geom.py @@ -0,0 +1,323 @@ +import time +import numpy as np +import torch +import torch.nn.functional as F + + +def rnd_sample(inputs, n_sample): + cur_size = inputs[0].shape[0] + rnd_idx = torch.randperm(cur_size)[0:n_sample] + outputs = [i[rnd_idx] for i in inputs] + return outputs + + +def _grid_positions(h, w, bs): + x_rng = torch.arange(0, w.int()) + y_rng = torch.arange(0, h.int()) + xv, yv = torch.meshgrid(x_rng, y_rng, indexing='xy') + return torch.reshape( + torch.stack((yv, xv), axis=-1), + (1, -1, 2) + ).repeat(bs, 1, 1).float() + + +def getK(ori_img_size, cur_feat_size, K): + # WARNING: cur_feat_size's order is [h, w] + r = ori_img_size / cur_feat_size[[1, 0]] + r_K0 = torch.stack([K[:, 0] / r[:, 0][..., None], K[:, 1] / + r[:, 1][..., None], K[:, 2]], axis=1) + return r_K0 + + +def gather_nd(params, indices): + """ The same as tf.gather_nd but batched gather is not supported yet. + indices is an k-dimensional integer tensor, best thought of as a (k-1)-dimensional tensor of indices into params, where each element defines a slice of params: + + output[\\(i_0, ..., i_{k-2}\\)] = params[indices[\\(i_0, ..., i_{k-2}\\)]] + + Args: + params (Tensor): "n" dimensions. shape: [x_0, x_1, x_2, ..., x_{n-1}] + indices (Tensor): "k" dimensions. shape: [y_0,y_2,...,y_{k-2}, m]. m <= n. + + Returns: gathered Tensor. + shape [y_0,y_2,...y_{k-2}] + params.shape[m:] + + """ + orig_shape = list(indices.shape) + num_samples = np.prod(orig_shape[:-1]) + m = orig_shape[-1] + n = len(params.shape) + + if m <= n: + out_shape = orig_shape[:-1] + list(params.shape)[m:] + else: + raise ValueError( + f'the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}' + ) + + indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist() + output = params[indices] # (num_samples, ...) 
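+    # Illustrative note (not in the original source): `indices` is now a list of m
+    # per-axis index lists, so the advanced indexing above gathers one slice per row
+    # of the original index tensor. E.g. with params of shape [H, W, C] and indices
+    # of shape [K, 2], each row (i, j) selects params[i, j, :], giving a (K, C)
+    # tensor that is reshaped back to out_shape below.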
+ return output.reshape(out_shape).contiguous() + +# input: pos [kpt_n, 2]; inputs [H, W, 128] / [H, W] +# output: [kpt_n, 128] / [kpt_n] +def interpolate(pos, inputs, nd=True): + h = inputs.shape[0] + w = inputs.shape[1] + + i = pos[:, 0] + j = pos[:, 1] + + i_top_left = torch.clamp(torch.floor(i).int(), 0, h - 1) + j_top_left = torch.clamp(torch.floor(j).int(), 0, w - 1) + + i_top_right = torch.clamp(torch.floor(i).int(), 0, h - 1) + j_top_right = torch.clamp(torch.ceil(j).int(), 0, w - 1) + + i_bottom_left = torch.clamp(torch.ceil(i).int(), 0, h - 1) + j_bottom_left = torch.clamp(torch.floor(j).int(), 0, w - 1) + + i_bottom_right = torch.clamp(torch.ceil(i).int(), 0, h - 1) + j_bottom_right = torch.clamp(torch.ceil(j).int(), 0, w - 1) + + dist_i_top_left = i - i_top_left.float() + dist_j_top_left = j - j_top_left.float() + w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) + w_top_right = (1 - dist_i_top_left) * dist_j_top_left + w_bottom_left = dist_i_top_left * (1 - dist_j_top_left) + w_bottom_right = dist_i_top_left * dist_j_top_left + + if nd: + w_top_left = w_top_left[..., None] + w_top_right = w_top_right[..., None] + w_bottom_left = w_bottom_left[..., None] + w_bottom_right = w_bottom_right[..., None] + + interpolated_val = ( + w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) + + w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) + + w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + + w_bottom_right * + gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) + ) + + return interpolated_val + + +def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=None, nd=False): + if nd: + h, w, c = inputs.shape + else: + h, w = inputs.shape + ids = torch.arange(0, pos.shape[0]) + + i = pos[:, 0] + j = pos[:, 1] + + i_top_left = torch.floor(i).int() + j_top_left = torch.floor(j).int() + + i_top_right = torch.floor(i).int() + j_top_right = torch.ceil(j).int() + + i_bottom_left = torch.ceil(i).int() + j_bottom_left = torch.floor(j).int() + + i_bottom_right = torch.ceil(i).int() + j_bottom_right = torch.ceil(j).int() + + if validate_corner: + # Valid corner + valid_top_left = torch.logical_and(i_top_left >= 0, j_top_left >= 0) + valid_top_right = torch.logical_and(i_top_right >= 0, j_top_right < w) + valid_bottom_left = torch.logical_and(i_bottom_left < h, j_bottom_left >= 0) + valid_bottom_right = torch.logical_and(i_bottom_right < h, j_bottom_right < w) + + valid_corner = torch.logical_and( + torch.logical_and(valid_top_left, valid_top_right), + torch.logical_and(valid_bottom_left, valid_bottom_right) + ) + + i_top_left = i_top_left[valid_corner] + j_top_left = j_top_left[valid_corner] + + i_top_right = i_top_right[valid_corner] + j_top_right = j_top_right[valid_corner] + + i_bottom_left = i_bottom_left[valid_corner] + j_bottom_left = j_bottom_left[valid_corner] + + i_bottom_right = i_bottom_right[valid_corner] + j_bottom_right = j_bottom_right[valid_corner] + + ids = ids[valid_corner] + + if validate_val is not None: + # Valid depth + valid_depth = torch.logical_and( + torch.logical_and( + gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) > 0, + gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) > 0 + ), + torch.logical_and( + gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) > 0, + gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) > 0 + ) + ) + + i_top_left = 
i_top_left[valid_depth] + j_top_left = j_top_left[valid_depth] + + i_top_right = i_top_right[valid_depth] + j_top_right = j_top_right[valid_depth] + + i_bottom_left = i_bottom_left[valid_depth] + j_bottom_left = j_bottom_left[valid_depth] + + i_bottom_right = i_bottom_right[valid_depth] + j_bottom_right = j_bottom_right[valid_depth] + + ids = ids[valid_depth] + + # Interpolation + i = i[ids] + j = j[ids] + dist_i_top_left = i - i_top_left.float() + dist_j_top_left = j - j_top_left.float() + w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) + w_top_right = (1 - dist_i_top_left) * dist_j_top_left + w_bottom_left = dist_i_top_left * (1 - dist_j_top_left) + w_bottom_right = dist_i_top_left * dist_j_top_left + + if nd: + w_top_left = w_top_left[..., None] + w_top_right = w_top_right[..., None] + w_bottom_left = w_bottom_left[..., None] + w_bottom_right = w_bottom_right[..., None] + + interpolated_val = ( + w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) + + w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) + + w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) + + w_bottom_right * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) + ) + + pos = torch.stack([i, j], axis=1) + return [interpolated_val, pos, ids] + + +# pos0: [2, 230400, 2] +# depth0: [2, 480, 480] +def getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs): + def swap_axis(data): + return torch.stack([data[:, 1], data[:, 0]], axis=-1) + + all_pos0 = [] + all_pos1 = [] + all_ids = [] + for i in range(bs): + z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0) + + uv0_homo = torch.cat([swap_axis(new_pos0), torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device)], axis=-1) + xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t()) + xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo, + torch.ones((1, new_pos0.shape[0])).to(z0.device)], axis=0) + + xyz1 = torch.matmul(rel_pose[i], xyz0_homo) + xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0) + uv1 = torch.matmul(K1[i], xy1_homo).t()[:, 0:2] + + new_pos1 = swap_axis(uv1) + annotated_depth, new_pos1, new_ids = validate_and_interpolate( + new_pos1, depth1[i], validate_val=0) + + ids = ids[new_ids] + new_pos0 = new_pos0[new_ids] + estimated_depth = xyz1.t()[new_ids][:, -1] + + inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05 + + all_ids.append(ids[inlier_mask]) + all_pos0.append(new_pos0[inlier_mask]) + all_pos1.append(new_pos1[inlier_mask]) + # all_pos0 & all_pose1: [inlier_num, 2] * batch_size + return all_pos0, all_pos1, all_ids + + +# pos0: [2, 230400, 2] +# depth0: [2, 480, 480] +def getWarpNoValidate(pos0, rel_pose, depth0, K0, depth1, K1, bs): + def swap_axis(data): + return torch.stack([data[:, 1], data[:, 0]], axis=-1) + + all_pos0 = [] + all_pos1 = [] + all_ids = [] + for i in range(bs): + z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0) + + uv0_homo = torch.cat([swap_axis(new_pos0), torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device)], axis=-1) + xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t()) + xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo, + torch.ones((1, new_pos0.shape[0])).to(z0.device)], axis=0) + + xyz1 = torch.matmul(rel_pose[i], xyz0_homo) + xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0) + uv1 = torch.matmul(K1[i], xy1_homo).t()[:, 0:2] + + new_pos1 = swap_axis(uv1) + _, new_pos1, new_ids = 
validate_and_interpolate( + new_pos1, depth1[i], validate_val=0) + + ids = ids[new_ids] + new_pos0 = new_pos0[new_ids] + + all_ids.append(ids) + all_pos0.append(new_pos0) + all_pos1.append(new_pos1) + # all_pos0 & all_pose1: [inlier_num, 2] * batch_size + return all_pos0, all_pos1, all_ids + + +# pos0: [2, 230400, 2] +# depth0: [2, 480, 480] +def getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1): + def swap_axis(data): + return torch.stack([data[:, 1], data[:, 0]], axis=-1) + + z0 = interpolate(pos0, depth0, nd=False) + + uv0_homo = torch.cat([swap_axis(pos0), torch.ones((pos0.shape[0], 1)).to(pos0.device)], axis=-1) + xy0_homo = torch.matmul(torch.linalg.inv(K0), uv0_homo.t()) + xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo, + torch.ones((1, pos0.shape[0])).to(z0.device)], axis=0) + + xyz1 = torch.matmul(rel_pose, xyz0_homo) + xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0) + uv1 = torch.matmul(K1, xy1_homo).t()[:, 0:2] + + new_pos1 = swap_axis(uv1) + + return new_pos1 + + + +def get_dist_mat(feat1, feat2, dist_type): + eps = 1e-6 + cos_dist_mat = torch.matmul(feat1, feat2.t()) + if dist_type == 'cosine_dist': + dist_mat = torch.clamp(cos_dist_mat, -1, 1) + elif dist_type == 'euclidean_dist': + dist_mat = torch.sqrt(torch.clamp(2 - 2 * cos_dist_mat, min=eps)) + elif dist_type == 'euclidean_dist_no_norm': + norm1 = torch.sum(feat1 * feat1, axis=-1, keepdims=True) + norm2 = torch.sum(feat2 * feat2, axis=-1, keepdims=True) + dist_mat = torch.sqrt( + torch.clamp( + norm1 - 2 * cos_dist_mat + norm2.t(), + min=0. + ) + eps + ) + else: + raise NotImplementedError() + return dist_mat diff --git a/imcui/third_party/DarkFeat/nets/l2net.py b/imcui/third_party/DarkFeat/nets/l2net.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ddfe8919bd4d5fe75215d253525123e1402952 --- /dev/null +++ b/imcui/third_party/DarkFeat/nets/l2net.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from .score import peakiness_score + + +class BaseNet(nn.Module): + """ Helper class to construct a fully-convolutional network that + extract a l2-normalized patch descriptor. 
+ """ + def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False): + super(BaseNet, self).__init__() + self.inchan = inchan + self.curchan = inchan + self.dilated = dilated + self.dilation = dilation + self.bn = bn + self.bn_affine = bn_affine + + def _make_bn(self, outd): + return nn.BatchNorm2d(outd, affine=self.bn_affine) + + def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max', bias=False): + # as in the original implementation, dilation is applied at the end of layer, so it will have impact only from next layer + d = self.dilation * dilation + # if self.dilated: + # conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=1) + # self.dilation *= stride + # else: + # conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride) + conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride, bias=bias) + + ops = nn.ModuleList([]) + + ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) ) + if bn and self.bn: ops.append( self._make_bn(outd) ) + if relu: ops.append( nn.ReLU(inplace=True) ) + self.curchan = outd + + if k_pool > 1: + if pool_type == 'avg': + ops.append(torch.nn.AvgPool2d(kernel_size=k_pool)) + elif pool_type == 'max': + ops.append(torch.nn.MaxPool2d(kernel_size=k_pool)) + else: + print(f"Error, unknown pooling type {pool_type}...") + + return nn.Sequential(*ops) + + +class Quad_L2Net(BaseNet): + """ Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs. + """ + def __init__(self, dim=128, mchan=4, relu22=False, **kw): + BaseNet.__init__(self, **kw) + self.conv0 = self._add_conv( 8*mchan) + self.conv1 = self._add_conv( 8*mchan, bn=False) + self.bn1 = self._make_bn(8*mchan) + self.conv2 = self._add_conv( 16*mchan, stride=2) + self.conv3 = self._add_conv( 16*mchan, bn=False) + self.bn3 = self._make_bn(16*mchan) + self.conv4 = self._add_conv( 32*mchan, stride=2) + self.conv5 = self._add_conv( 32*mchan) + # replace last 8x8 convolution with 3 3x3 convolutions + self.conv6_0 = self._add_conv( 32*mchan) + self.conv6_1 = self._add_conv( 32*mchan) + self.conv6_2 = self._add_conv(dim, bn=False, relu=False) + self.out_dim = dim + + self.moving_avg_params = nn.ParameterList([ + Parameter(torch.tensor(1.), requires_grad=False), + Parameter(torch.tensor(1.), requires_grad=False), + Parameter(torch.tensor(1.), requires_grad=False) + ]) + + def forward(self, x): + # x: [N, C, H, W] + x0 = self.conv0(x) + x1 = self.conv1(x0) + x1_bn = self.bn1(x1) + x2 = self.conv2(x1_bn) + x3 = self.conv3(x2) + x3_bn = self.bn3(x3) + x4 = self.conv4(x3_bn) + x5 = self.conv5(x4) + x6_0 = self.conv6_0(x5) + x6_1 = self.conv6_1(x6_0) + x6_2 = self.conv6_2(x6_1) + + # calculate score map + comb_weights = torch.tensor([1., 2., 3.], device=x.device) + comb_weights /= torch.sum(comb_weights) + ksize = [3, 2, 1] + det_score_maps = [] + + for idx, xx in enumerate([x1, x3, x6_2]): + if self.training: + instance_max = torch.max(xx) + self.moving_avg_params[idx].data = self.moving_avg_params[idx] * 0.99 + instance_max.detach() * 0.01 + else: + pass + + alpha, beta = peakiness_score(xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx]) + + score_vol = alpha * beta + det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0] + det_score_map = F.interpolate(det_score_map, size=x.shape[2:], mode='bilinear', align_corners=True) + det_score_map = comb_weights[idx] * det_score_map + det_score_maps.append(det_score_map) + + det_score_map = torch.sum(torch.stack(det_score_maps, 
dim=0), dim=0) + # print([param.data for param in self.moving_avg_params]) + + return x6_2, det_score_map, x1, x3 diff --git a/imcui/third_party/DarkFeat/nets/loss.py b/imcui/third_party/DarkFeat/nets/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0dd42b4214d021137ddfe72771ccad0264d2321f --- /dev/null +++ b/imcui/third_party/DarkFeat/nets/loss.py @@ -0,0 +1,260 @@ +import torch +import torch.nn.functional as F + +from .geom import rnd_sample, interpolate, get_dist_mat + + +def make_detector_loss(pos0, pos1, dense_feat_map0, dense_feat_map1, + score_map0, score_map1, batch_size, num_corr, loss_type, config): + joint_loss = 0. + accuracy = 0. + all_valid_pos0 = [] + all_valid_pos1 = [] + all_valid_match = [] + for i in range(batch_size): + # random sample + valid_pos0, valid_pos1 = rnd_sample([pos0[i], pos1[i]], num_corr) + valid_num = valid_pos0.shape[0] + + valid_feat0 = interpolate(valid_pos0 / 4, dense_feat_map0[i]) + valid_feat1 = interpolate(valid_pos1 / 4, dense_feat_map1[i]) + + valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) + valid_feat1 = F.normalize(valid_feat1, p=2, dim=-1) + + valid_score0 = interpolate(valid_pos0, torch.squeeze(score_map0[i], dim=-1), nd=False) + valid_score1 = interpolate(valid_pos1, torch.squeeze(score_map1[i], dim=-1), nd=False) + + if config['network']['det']['corr_weight']: + corr_weight = valid_score0 * valid_score1 + else: + corr_weight = None + + safe_radius = config['network']['det']['safe_radius'] + if safe_radius > 0: + radius_mask_row = get_dist_mat( + valid_pos1, valid_pos1, "euclidean_dist_no_norm") + radius_mask_row = torch.le(radius_mask_row, safe_radius) + radius_mask_col = get_dist_mat( + valid_pos0, valid_pos0, "euclidean_dist_no_norm") + radius_mask_col = torch.le(radius_mask_col, safe_radius) + radius_mask_row = radius_mask_row.float() - torch.eye(valid_num, device=radius_mask_row.device) + radius_mask_col = radius_mask_col.float() - torch.eye(valid_num, device=radius_mask_col.device) + else: + radius_mask_row = None + radius_mask_col = None + + if valid_num < 32: + si_loss, si_accuracy, matched_mask = 0., 1., torch.zeros((1, valid_num)).bool() + else: + si_loss, si_accuracy, matched_mask = make_structured_loss( + torch.unsqueeze(valid_feat0, 0), torch.unsqueeze(valid_feat1, 0), + loss_type=loss_type, + radius_mask_row=radius_mask_row, radius_mask_col=radius_mask_col, + corr_weight=torch.unsqueeze(corr_weight, 0) if corr_weight is not None else None + ) + + joint_loss += si_loss / batch_size + accuracy += si_accuracy / batch_size + all_valid_match.append(torch.squeeze(matched_mask, dim=0)) + all_valid_pos0.append(valid_pos0) + all_valid_pos1.append(valid_pos1) + + return joint_loss, accuracy + + +def make_structured_loss(feat_anc, feat_pos, + loss_type='RATIO', inlier_mask=None, + radius_mask_row=None, radius_mask_col=None, + corr_weight=None, dist_mat=None): + """ + Structured loss construction. + Args: + feat_anc, feat_pos: Feature matrix. + loss_type: Loss type. 
+ inlier_mask: + Returns: + + """ + batch_size = feat_anc.shape[0] + num_corr = feat_anc.shape[1] + if inlier_mask is None: + inlier_mask = torch.ones((batch_size, num_corr), device=feat_anc.device).bool() + inlier_num = torch.count_nonzero(inlier_mask.float(), dim=-1) + + if loss_type == 'L2NET' or loss_type == 'CIRCLE': + dist_type = 'cosine_dist' + elif loss_type.find('HARD') >= 0: + dist_type = 'euclidean_dist' + else: + raise NotImplementedError() + + if dist_mat is None: + dist_mat = get_dist_mat(feat_anc.squeeze(0), feat_pos.squeeze(0), dist_type).unsqueeze(0) + pos_vec = dist_mat[0].diag().unsqueeze(0) + + if loss_type.find('HARD') >= 0: + neg_margin = 1 + dist_mat_without_min_on_diag = dist_mat + \ + 10 * torch.unsqueeze(torch.eye(num_corr, device=dist_mat.device), dim=0) + mask = torch.le(dist_mat_without_min_on_diag, 0.008).float() + dist_mat_without_min_on_diag += mask*10 + + if radius_mask_row is not None: + hard_neg_dist_row = dist_mat_without_min_on_diag + 10 * radius_mask_row + else: + hard_neg_dist_row = dist_mat_without_min_on_diag + if radius_mask_col is not None: + hard_neg_dist_col = dist_mat_without_min_on_diag + 10 * radius_mask_col + else: + hard_neg_dist_col = dist_mat_without_min_on_diag + + hard_neg_dist_row = torch.min(hard_neg_dist_row, dim=-1)[0] + hard_neg_dist_col = torch.min(hard_neg_dist_col, dim=-2)[0] + + if loss_type == 'HARD_TRIPLET': + loss_row = torch.clamp(neg_margin + pos_vec - hard_neg_dist_row, min=0) + loss_col = torch.clamp(neg_margin + pos_vec - hard_neg_dist_col, min=0) + elif loss_type == 'HARD_CONTRASTIVE': + pos_margin = 0.2 + pos_loss = torch.clamp(pos_vec - pos_margin, min=0) + loss_row = pos_loss + torch.clamp(neg_margin - hard_neg_dist_row, min=0) + loss_col = pos_loss + torch.clamp(neg_margin - hard_neg_dist_col, min=0) + else: + raise NotImplementedError() + + elif loss_type == 'CIRCLE': + log_scale = 512 + m = 0.1 + neg_mask_row = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0) + if radius_mask_row is not None: + neg_mask_row += radius_mask_row + neg_mask_col = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0) + if radius_mask_col is not None: + neg_mask_col += radius_mask_col + + pos_margin = 1 - m + neg_margin = m + pos_optimal = 1 + m + neg_optimal = -m + + neg_mat_row = dist_mat - 128 * neg_mask_row + neg_mat_col = dist_mat - 128 * neg_mask_col + + lse_positive = torch.logsumexp(-log_scale * (pos_vec[..., None] - pos_margin) * \ + torch.clamp(pos_optimal - pos_vec[..., None], min=0).detach(), dim=-1) + + lse_negative_row = torch.logsumexp(log_scale * (neg_mat_row - neg_margin) * \ + torch.clamp(neg_mat_row - neg_optimal, min=0).detach(), dim=-1) + + lse_negative_col = torch.logsumexp(log_scale * (neg_mat_col - neg_margin) * \ + torch.clamp(neg_mat_col - neg_optimal, min=0).detach(), dim=-2) + + loss_row = F.softplus(lse_positive + lse_negative_row) / log_scale + loss_col = F.softplus(lse_positive + lse_negative_col) / log_scale + + else: + raise NotImplementedError() + + if dist_type == 'cosine_dist': + err_row = dist_mat - torch.unsqueeze(pos_vec, -1) + err_col = dist_mat - torch.unsqueeze(pos_vec, -2) + elif dist_type == 'euclidean_dist' or dist_type == 'euclidean_dist_no_norm': + err_row = torch.unsqueeze(pos_vec, -1) - dist_mat + err_col = torch.unsqueeze(pos_vec, -2) - dist_mat + else: + raise NotImplementedError() + if radius_mask_row is not None: + err_row = err_row - 10 * radius_mask_row + if radius_mask_col is not None: + err_col = err_col - 10 * radius_mask_col + err_row = 
torch.sum(torch.clamp(err_row, min=0), dim=-1) + err_col = torch.sum(torch.clamp(err_col, min=0), dim=-2) + + loss = 0 + accuracy = 0 + + tot_loss = (loss_row + loss_col) / 2 + if corr_weight is not None: + tot_loss = tot_loss * corr_weight + + for i in range(batch_size): + if corr_weight is not None: + loss += torch.sum(tot_loss[i][inlier_mask[i]]) / \ + (torch.sum(corr_weight[i][inlier_mask[i]]) + 1e-6) + else: + loss += torch.mean(tot_loss[i][inlier_mask[i]]) + cnt_err_row = torch.count_nonzero(err_row[i][inlier_mask[i]]).float() + cnt_err_col = torch.count_nonzero(err_col[i][inlier_mask[i]]).float() + tot_err = cnt_err_row + cnt_err_col + if inlier_num[i] != 0: + accuracy += 1. - tot_err / inlier_num[i] / batch_size / 2. + else: + accuracy += 1. + + matched_mask = torch.logical_and(torch.eq(err_row, 0), torch.eq(err_col, 0)) + matched_mask = torch.logical_and(matched_mask, inlier_mask) + + loss /= batch_size + accuracy /= batch_size + + return loss, accuracy, matched_mask + + +# for the neighborhood areas of keypoints extracted from normal image, the score from noise_score_map should be close +# for the rest, the noise image's score should less than normal image +# input: score_map [batch_size, H, W, 1]; indices [2, k, 2] +# output: loss [scalar] +def make_noise_score_map_loss(score_map, noise_score_map, indices, batch_size, thld=0.): + H, W = score_map.shape[1:3] + loss = 0 + for i in range(batch_size): + kpts_coords = indices[i].T # (2, num_kpts) + mask = torch.zeros([H, W], device=score_map.device) + mask[kpts_coords.cpu().numpy()] = 1 + + # using 3x3 kernel to put kpts' neightborhood area into the mask + kernel = torch.ones([1, 1, 3, 3], device=score_map.device) + mask = F.conv2d(mask.unsqueeze(0).unsqueeze(0), kernel, padding=1)[0, 0] > 0 + + loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask) + loss2 = torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() * torch.logical_not(mask)) / (H * W - torch.sum(mask)) + + loss += loss1 + loss += loss2 + + if i == 0: + first_mask = mask + + return loss, first_mask + + +def make_noise_score_map_loss_labelmap(score_map, noise_score_map, labelmap, batch_size, thld=0.): + H, W = score_map.shape[1:3] + loss = 0 + for i in range(batch_size): + # using 3x3 kernel to put kpts' neightborhood area into the mask + kernel = torch.ones([1, 1, 3, 3], device=score_map.device) + mask = F.conv2d(labelmap[i].unsqueeze(0).to(score_map.device).float(), kernel, padding=1)[0, 0] > 0 + + loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask) + loss2 = torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() * torch.logical_not(mask)) / (H * W - torch.sum(mask)) + + loss += loss1 + loss += loss2 + + if i == 0: + first_mask = mask + + return loss, first_mask + + +def make_score_map_peakiness_loss(score_map, scores, batch_size): + H, W = score_map.shape[1:3] + loss = 0 + + for i in range(batch_size): + loss += torch.mean(scores[i]) - torch.mean(score_map[i]) + + loss /= batch_size + return 1 - loss diff --git a/imcui/third_party/DarkFeat/nets/multi_sampler.py b/imcui/third_party/DarkFeat/nets/multi_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..dc400fb2afeb50575cd81d3c01b605bea6db1121 --- /dev/null +++ b/imcui/third_party/DarkFeat/nets/multi_sampler.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from .geom import rnd_sample, 
interpolate + +class MultiSampler (nn.Module): + """ Similar to NghSampler, but doesnt warp the 2nd image. + Distance to GT => 0 ... pos_d ... neg_d ... ngh + Pixel label => + + + + + + 0 0 - - - - - - - + + Subsample on query side: if > 0, regular grid + < 0, random points + In both cases, the number of query points is = W*H/subq**2 + """ + def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None, + maxpool_pos=True, subd_neg=0): + nn.Module.__init__(self) + assert 0 <= pos_d < neg_d <= (ngh if ngh else 99) + self.ngh = ngh + self.pos_d = pos_d + self.neg_d = neg_d + assert subd <= ngh or ngh == 0 + assert subq != 0 + self.sub_q = subq + self.sub_d = subd + self.sub_d_neg = subd_neg + if border is None: border = ngh + assert border >= ngh, 'border has to be larger than ngh' + self.border = border + self.maxpool_pos = maxpool_pos + self.precompute_offsets() + + def precompute_offsets(self): + pos_d2 = self.pos_d**2 + neg_d2 = self.neg_d**2 + rad2 = self.ngh**2 + rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple + pos = [] + neg = [] + for j in range(-rad, rad+1, self.sub_d): + for i in range(-rad, rad+1, self.sub_d): + d2 = i*i + j*j + if d2 <= pos_d2: + pos.append( (i,j) ) + elif neg_d2 <= d2 <= rad2: + neg.append( (i,j) ) + + self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t()) + self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t()) + + + def forward(self, feat0, feat1, noise_feat0, noise_feat1, conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=2500): + pscores_ls, nscores_ls, distractors_ls = [], [], [] + valid_feat0_ls = [] + noise_pscores_ls, noise_nscores_ls, noise_distractors_ls = [], [], [] + valid_noise_feat0_ls = [] + valid_pos1_ls, valid_pos2_ls = [], [] + qconf_ls = [] + noise_qconf_ls = [] + mask_ls = [] + + for i in range(B): + tmp_mask = (pos0[i][:, 1] >= self.border) * (pos0[i][:, 1] < W-self.border) \ + * (pos0[i][:, 0] >= self.border) * (pos0[i][:, 0] < H-self.border) + + selected_pos0 = pos0[i][tmp_mask] + selected_pos1 = pos1[i][tmp_mask] + valid_pos0, valid_pos1 = rnd_sample([selected_pos0, selected_pos1], N) + + # sample features from first image + valid_feat0 = interpolate(valid_pos0 / 4, feat0[i]) # [N, 128] + valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128] + qconf = interpolate(valid_pos0 / 4, conf0[i]) + + valid_noise_feat0 = interpolate(valid_pos0 / 4, noise_feat0[i]) # [N, 128] + valid_noise_feat0 = F.normalize(valid_noise_feat0, p=2, dim=-1) # [N, 128] + noise_qconf = interpolate(valid_pos0 / 4, noise_conf0[i]) + + # sample GT from second image + mask = (valid_pos1[:, 1] >= 0) * (valid_pos1[:, 1] < W) \ + * (valid_pos1[:, 0] >= 0) * (valid_pos1[:, 0] < H) + + def clamp(xy): + xy = xy + torch.clamp(xy[0], 0, H-1, out=xy[0]) + torch.clamp(xy[1], 0, W-1, out=xy[1]) + return xy + + # compute positive scores + valid_pos1p = clamp(valid_pos1.t()[:,None,:] + self.pos_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N] + valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2] + valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128] + valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128] + valid_noise_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128] + valid_noise_feat1p = F.normalize(valid_noise_feat1p, p=2, dim=-1) # [29, N, 128] + + pscores = (valid_feat0[None,:,:] * valid_feat1p).sum(dim=-1).t() # [N, 29] + 
pscores, pos = pscores.max(dim=1, keepdim=True) + sel = clamp(valid_pos1.t() + self.pos_offsets[:,pos.view(-1)].to(valid_pos1.device)) + qconf = (qconf + interpolate(sel.t() / 4, conf1[i]))/2 + noise_pscores = (valid_noise_feat0[None,:,:] * valid_noise_feat1p).sum(dim=-1).t() # [N, 29] + noise_pscores, noise_pos = noise_pscores.max(dim=1, keepdim=True) + noise_sel = clamp(valid_pos1.t() + self.pos_offsets[:,noise_pos.view(-1)].to(valid_pos1.device)) + noise_qconf = (noise_qconf + interpolate(noise_sel.t() / 4, noise_conf1[i]))/2 + + # compute negative scores + valid_pos1n = clamp(valid_pos1.t()[:,None,:] + self.neg_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N] + valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2] + valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128] + valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128] + nscores = (valid_feat0[None,:,:] * valid_feat1n).sum(dim=-1).t() # [N, 29] + valid_noise_feat1n = interpolate(valid_pos1n / 4, noise_feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128] + valid_noise_feat1n = F.normalize(valid_noise_feat1n, p=2, dim=-1) # [29, N, 128] + noise_nscores = (valid_noise_feat0[None,:,:] * valid_noise_feat1n).sum(dim=-1).t() # [N, 29] + + if self.sub_d_neg: + valid_pos2 = rnd_sample([selected_pos1], N)[0] + distractors = interpolate(valid_pos2 / 4, feat1[i]) + distractors = F.normalize(distractors, p=2, dim=-1) + noise_distractors = interpolate(valid_pos2 / 4, noise_feat1[i]) + noise_distractors = F.normalize(noise_distractors, p=2, dim=-1) + + pscores_ls.append(pscores) + nscores_ls.append(nscores) + distractors_ls.append(distractors) + valid_feat0_ls.append(valid_feat0) + noise_pscores_ls.append(noise_pscores) + noise_nscores_ls.append(noise_nscores) + noise_distractors_ls.append(noise_distractors) + valid_noise_feat0_ls.append(valid_noise_feat0) + valid_pos1_ls.append(valid_pos1) + valid_pos2_ls.append(valid_pos2) + qconf_ls.append(qconf) + noise_qconf_ls.append(noise_qconf) + mask_ls.append(mask) + + N = np.min([len(i) for i in qconf_ls]) + + # merge batches + qconf = torch.stack([i[:N] for i in qconf_ls], dim=0).squeeze(-1) + mask = torch.stack([i[:N] for i in mask_ls], dim=0) + pscores = torch.cat([i[:N] for i in pscores_ls], dim=0) + nscores = torch.cat([i[:N] for i in nscores_ls], dim=0) + distractors = torch.cat([i[:N] for i in distractors_ls], dim=0) + valid_feat0 = torch.cat([i[:N] for i in valid_feat0_ls], dim=0) + valid_pos1 = torch.cat([i[:N] for i in valid_pos1_ls], dim=0) + valid_pos2 = torch.cat([i[:N] for i in valid_pos2_ls], dim=0) + + noise_qconf = torch.stack([i[:N] for i in noise_qconf_ls], dim=0).squeeze(-1) + noise_pscores = torch.cat([i[:N] for i in noise_pscores_ls], dim=0) + noise_nscores = torch.cat([i[:N] for i in noise_nscores_ls], dim=0) + noise_distractors = torch.cat([i[:N] for i in noise_distractors_ls], dim=0) + valid_noise_feat0 = torch.cat([i[:N] for i in valid_noise_feat0_ls], dim=0) + + # remove scores that corresponds to positives or nulls + dscores = torch.matmul(valid_feat0, distractors.t()) + noise_dscores = torch.matmul(valid_noise_feat0, noise_distractors.t()) + + dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:,None])**2 + (valid_pos2[:, 0] - valid_pos1[:, 0][:,None])**2 + b = torch.arange(B, device=dscores.device)[:,None].expand(B, N).reshape(-1) + dis2 += (b != b[:,None]).long() * self.neg_d**2 + dscores[dis2 < self.neg_d**2] = 0 + noise_dscores[dis2 < self.neg_d**2] = 
0 + scores = torch.cat((pscores, nscores, dscores), dim=1) + noise_scores = torch.cat((noise_pscores, noise_nscores, noise_dscores), dim=1) + + gt = scores.new_zeros(scores.shape, dtype=torch.uint8) + gt[:, :pscores.shape[1]] = 1 + + return scores, noise_scores, gt, mask, qconf, noise_qconf diff --git a/imcui/third_party/DarkFeat/nets/noise_reliability_loss.py b/imcui/third_party/DarkFeat/nets/noise_reliability_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9efddae149653c225ee7f2c1eb5fed5f92cef15c --- /dev/null +++ b/imcui/third_party/DarkFeat/nets/noise_reliability_loss.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn +from .reliability_loss import APLoss + + +class MultiPixelAPLoss (nn.Module): + """ Computes the pixel-wise AP loss: + Given two images and ground-truth optical flow, computes the AP per pixel. + + feat1: (B, C, H, W) pixel-wise features extracted from img1 + feat2: (B, C, H, W) pixel-wise features extracted from img2 + aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2 + """ + def __init__(self, sampler, nq=20): + nn.Module.__init__(self) + self.aploss = APLoss(nq, min=0, max=1, euc=False) + self.sampler = sampler + self.base = 0.25 + self.dec_base = 0.20 + + def loss_from_ap(self, ap, rel, noise_ap, noise_rel): + dec_ap = torch.clamp(ap - noise_ap, min=0, max=1) + return (1 - ap*noise_rel - (1-noise_rel)*self.base), (1. - dec_ap*(1-noise_rel) - noise_rel*self.dec_base) + + def forward(self, feat0, feat1, noise_feat0, noise_feat1, conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=1500): + # subsample things + scores, noise_scores, gt, msk, qconf, noise_qconf = self.sampler(feat0, feat1, noise_feat0, noise_feat1, \ + conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=1500) + + # compute pixel-wise AP + n = qconf.numel() + if n == 0: return 0, 0 + scores, noise_scores, gt = scores.view(n,-1), noise_scores, gt.view(n,-1) + ap = self.aploss(scores, gt).view(msk.shape) + noise_ap = self.aploss(noise_scores, gt).view(msk.shape) + + pixel_loss = self.loss_from_ap(ap, qconf, noise_ap, noise_qconf) + + loss = pixel_loss[0][msk].mean(), pixel_loss[1][msk].mean() + return loss \ No newline at end of file diff --git a/imcui/third_party/DarkFeat/nets/reliability_loss.py b/imcui/third_party/DarkFeat/nets/reliability_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..527f9886a2d4785680bac52ff2fa20033b8d8920 --- /dev/null +++ b/imcui/third_party/DarkFeat/nets/reliability_loss.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import numpy as np + + +class APLoss (nn.Module): + """ differentiable AP loss, through quantization. 
+ + Input: (N, M) values in [min, max] + label: (N, M) values in {0, 1} + + Returns: list of query AP (for each n in {1..N}) + Note: typically, you want to minimize 1 - mean(AP) + """ + def __init__(self, nq=25, min=0, max=1, euc=False): + nn.Module.__init__(self) + assert isinstance(nq, int) and 2 <= nq <= 100 + self.nq = nq + self.min = min + self.max = max + self.euc = euc + gap = max - min + assert gap > 0 + + # init quantizer = non-learnable (fixed) convolution + self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True) + a = (nq-1) / gap + #1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1) + q.weight.data[:nq] = -a + q.bias.data[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1)) # b = 1 + a*(min+x) + #2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1) + q.weight.data[nq:] = a + q.bias.data[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min) # b = 1 - a*(min+x) + # first and last one are special: just horizontal straight line + q.weight.data[0] = q.weight.data[-1] = 0 + q.bias.data[0] = q.bias.data[-1] = 1 + + def compute_AP(self, x, label): + N, M = x.shape + # print(x.shape, label.shape) + if self.euc: # euclidean distance in same range than similarities + x = 1 - torch.sqrt(2.001 - 2*x) + + # quantize all predictions + q = self.quantizer(x.unsqueeze(1)) + q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M [1600, 20, 1681] + + nbs = q.sum(dim=-1) # number of samples N x Q = c + rec = (q * label.view(N,1,M).float()).sum(dim=-1) # nb of correct samples = c+ N x Q + prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision + rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1] + + ap = (prec * rec).sum(dim=-1) # per-image AP + return ap + + def forward(self, x, label): + assert x.shape == label.shape # N x M + return self.compute_AP(x, label) + + +class PixelAPLoss (nn.Module): + """ Computes the pixel-wise AP loss: + Given two images and ground-truth optical flow, computes the AP per pixel. + + feat1: (B, C, H, W) pixel-wise features extracted from img1 + feat2: (B, C, H, W) pixel-wise features extracted from img2 + aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2 + """ + def __init__(self, sampler, nq=20): + nn.Module.__init__(self) + self.aploss = APLoss(nq, min=0, max=1, euc=False) + self.name = 'pixAP' + self.sampler = sampler + + def loss_from_ap(self, ap, rel): + return 1 - ap + + def forward(self, feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200): + # subsample things + scores, gt, msk, qconf = self.sampler(feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200) + + # compute pixel-wise AP + n = qconf.numel() + if n == 0: return 0 + scores, gt = scores.view(n,-1), gt.view(n,-1) + ap = self.aploss(scores, gt).view(msk.shape) + + pixel_loss = self.loss_from_ap(ap, qconf) + + loss = pixel_loss[msk].mean() + return loss + + +class ReliabilityLoss (PixelAPLoss): + """ same than PixelAPLoss, but also train a pixel-wise confidence + that this pixel is going to have a good AP. 
+ """ + def __init__(self, sampler, base=0.5, **kw): + PixelAPLoss.__init__(self, sampler, **kw) + assert 0 <= base < 1 + self.base = base + + def loss_from_ap(self, ap, rel): + return 1 - ap*rel - (1-rel)*self.base + diff --git a/imcui/third_party/DarkFeat/nets/sampler.py b/imcui/third_party/DarkFeat/nets/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..b732a3671872d5675be9826f76b0818d3b99d466 --- /dev/null +++ b/imcui/third_party/DarkFeat/nets/sampler.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from .geom import rnd_sample, interpolate + +class NghSampler2 (nn.Module): + """ Similar to NghSampler, but doesnt warp the 2nd image. + Distance to GT => 0 ... pos_d ... neg_d ... ngh + Pixel label => + + + + + + 0 0 - - - - - - - + + Subsample on query side: if > 0, regular grid + < 0, random points + In both cases, the number of query points is = W*H/subq**2 + """ + def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None, + maxpool_pos=True, subd_neg=0): + nn.Module.__init__(self) + assert 0 <= pos_d < neg_d <= (ngh if ngh else 99) + self.ngh = ngh + self.pos_d = pos_d + self.neg_d = neg_d + assert subd <= ngh or ngh == 0 + assert subq != 0 + self.sub_q = subq + self.sub_d = subd + self.sub_d_neg = subd_neg + if border is None: border = ngh + assert border >= ngh, 'border has to be larger than ngh' + self.border = border + self.maxpool_pos = maxpool_pos + self.precompute_offsets() + + def precompute_offsets(self): + pos_d2 = self.pos_d**2 + neg_d2 = self.neg_d**2 + rad2 = self.ngh**2 + rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple + pos = [] + neg = [] + for j in range(-rad, rad+1, self.sub_d): + for i in range(-rad, rad+1, self.sub_d): + d2 = i*i + j*j + if d2 <= pos_d2: + pos.append( (i,j) ) + elif neg_d2 <= d2 <= rad2: + neg.append( (i,j) ) + + self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t()) + self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t()) + + def gen_grid(self, step, B, H, W, dev): + b1 = torch.arange(B, device=dev) + if step > 0: + # regular grid + x1 = torch.arange(self.border, W-self.border, step, device=dev) + y1 = torch.arange(self.border, H-self.border, step, device=dev) + H1, W1 = len(y1), len(x1) + x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1) + y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1) + b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1) + shape = (B, H1, W1) + else: + # randomly spread + n = (H - 2*self.border) * (W - 2*self.border) // step**2 + x1 = torch.randint(self.border, W-self.border, (n,), device=dev) + y1 = torch.randint(self.border, H-self.border, (n,), device=dev) + x1 = x1[None,:].expand(B,n).reshape(-1) + y1 = y1[None,:].expand(B,n).reshape(-1) + b1 = b1[:,None].expand(B,n).reshape(-1) + shape = (B, n) + return b1, y1, x1, shape + + def forward(self, feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=2500): + pscores_ls, nscores_ls, distractors_ls = [], [], [] + valid_feat0_ls = [] + valid_pos1_ls, valid_pos2_ls = [], [] + qconf_ls = [] + mask_ls = [] + + for i in range(B): + # positions in the first image + tmp_mask = (pos0[i][:, 1] >= self.border) * (pos0[i][:, 1] < W-self.border) \ + * (pos0[i][:, 0] >= self.border) * (pos0[i][:, 0] < H-self.border) + + selected_pos0 = pos0[i][tmp_mask] + selected_pos1 = pos1[i][tmp_mask] + valid_pos0, valid_pos1 = rnd_sample([selected_pos0, selected_pos1], N) + + # sample features from first image + valid_feat0 = interpolate(valid_pos0 
/ 4, feat0[i]) # [N, 128] + valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128] + qconf = interpolate(valid_pos0 / 4, conf0[i]) + + # sample GT from second image + mask = (valid_pos1[:, 1] >= 0) * (valid_pos1[:, 1] < W) \ + * (valid_pos1[:, 0] >= 0) * (valid_pos1[:, 0] < H) + + def clamp(xy): + xy = xy + torch.clamp(xy[0], 0, H-1, out=xy[0]) + torch.clamp(xy[1], 0, W-1, out=xy[1]) + return xy + + # compute positive scores + valid_pos1p = clamp(valid_pos1.t()[:,None,:] + self.pos_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N] + valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2] + valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128] + valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128] + + pscores = (valid_feat0[None,:,:] * valid_feat1p).sum(dim=-1).t() # [N, 29] + pscores, pos = pscores.max(dim=1, keepdim=True) + sel = clamp(valid_pos1.t() + self.pos_offsets[:,pos.view(-1)].to(valid_pos1.device)) + qconf = (qconf + interpolate(sel.t() / 4, conf1[i]))/2 + + # compute negative scores + valid_pos1n = clamp(valid_pos1.t()[:,None,:] + self.neg_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N] + valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2] + valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128] + valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128] + nscores = (valid_feat0[None,:,:] * valid_feat1n).sum(dim=-1).t() # [N, 29] + + if self.sub_d_neg: + valid_pos2 = rnd_sample([selected_pos1], N)[0] + distractors = interpolate(valid_pos2 / 4, feat1[i]) + distractors = F.normalize(distractors, p=2, dim=-1) + + pscores_ls.append(pscores) + nscores_ls.append(nscores) + distractors_ls.append(distractors) + valid_feat0_ls.append(valid_feat0) + valid_pos1_ls.append(valid_pos1) + valid_pos2_ls.append(valid_pos2) + qconf_ls.append(qconf) + mask_ls.append(mask) + + N = np.min([len(i) for i in qconf_ls]) + + # merge batches + qconf = torch.stack([i[:N] for i in qconf_ls], dim=0).squeeze(-1) + mask = torch.stack([i[:N] for i in mask_ls], dim=0) + pscores = torch.cat([i[:N] for i in pscores_ls], dim=0) + nscores = torch.cat([i[:N] for i in nscores_ls], dim=0) + distractors = torch.cat([i[:N] for i in distractors_ls], dim=0) + valid_feat0 = torch.cat([i[:N] for i in valid_feat0_ls], dim=0) + valid_pos1 = torch.cat([i[:N] for i in valid_pos1_ls], dim=0) + valid_pos2 = torch.cat([i[:N] for i in valid_pos2_ls], dim=0) + + dscores = torch.matmul(valid_feat0, distractors.t()) + dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:,None])**2 + (valid_pos2[:, 0] - valid_pos1[:, 0][:,None])**2 + b = torch.arange(B, device=dscores.device)[:,None].expand(B, N).reshape(-1) + dis2 += (b != b[:,None]).long() * self.neg_d**2 + dscores[dis2 < self.neg_d**2] = 0 + scores = torch.cat((pscores, nscores, dscores), dim=1) + + gt = scores.new_zeros(scores.shape, dtype=torch.uint8) + gt[:, :pscores.shape[1]] = 1 + + return scores, gt, mask, qconf diff --git a/imcui/third_party/DarkFeat/nets/score.py b/imcui/third_party/DarkFeat/nets/score.py new file mode 100644 index 0000000000000000000000000000000000000000..a78cf1c893bc338c12803697d55e121a75171f2c --- /dev/null +++ b/imcui/third_party/DarkFeat/nets/score.py @@ -0,0 +1,116 @@ +import torch +import torch.nn.functional as F +import numpy as np + +from .geom import gather_nd + +# input: [batch_size, C, H, W] +# output: [batch_size, C, H, W], 
[batch_size, C, H, W] +def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1): + inputs = inputs / moving_instance_max + + batch_size, C, H, W = inputs.shape + + pad_size = ksize // 2 + (dilation - 1) + kernel = torch.ones([C, 1, ksize, ksize], device=inputs.device) / (ksize * ksize) + + pad_inputs = F.pad(inputs, [pad_size] * 4, mode='reflect') + + avg_spatial_inputs = F.conv2d( + pad_inputs, + kernel, + stride=1, + dilation=dilation, + padding=0, + groups=C + ) + avg_channel_inputs = torch.mean(inputs, axis=1, keepdim=True) # channel dimension is 1 + + alpha = F.softplus(inputs - avg_spatial_inputs) + beta = F.softplus(inputs - avg_channel_inputs) + + return alpha, beta + + +# input: score_map [batch_size, 1, H, W] +# output: indices [2, k, 2], scores [2, k] +def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_size=5): + h = score_map.shape[2] + w = score_map.shape[3] + + mask = score_map > score_thld + if nms_size > 0: + nms_mask = F.max_pool2d(score_map, kernel_size=nms_size, stride=1, padding=nms_size//2) + nms_mask = torch.eq(score_map, nms_mask) + mask = torch.logical_and(nms_mask, mask) + if eof_size > 0: + eof_mask = torch.ones((1, 1, h - 2 * eof_size, w - 2 * eof_size), dtype=torch.float32, device=score_map.device) + eof_mask = F.pad(eof_mask, [eof_size] * 4, value=0) + eof_mask = eof_mask.bool() + mask = torch.logical_and(eof_mask, mask) + if edge_thld > 0: + non_edge_mask = edge_mask(score_map, 1, dilation=3, edge_thld=edge_thld) + mask = torch.logical_and(non_edge_mask, mask) + + bs = score_map.shape[0] + if bs is None: + indices = torch.nonzero(mask)[0] + scores = gather_nd(score_map, indices)[0] + sample = torch.sort(scores, descending=True)[1][0:k] + indices = indices[sample].unsqueeze(0) + scores = scores[sample].unsqueeze(0) + else: + indices = [] + scores = [] + for i in range(bs): + tmp_mask = mask[i][0] + tmp_score_map = score_map[i][0] + tmp_indices = torch.nonzero(tmp_mask) + tmp_scores = gather_nd(tmp_score_map, tmp_indices) + tmp_sample = torch.sort(tmp_scores, descending=True)[1][0:k] + tmp_indices = tmp_indices[tmp_sample] + tmp_scores = tmp_scores[tmp_sample] + indices.append(tmp_indices) + scores.append(tmp_scores) + try: + indices = torch.stack(indices, dim=0) + scores = torch.stack(scores, dim=0) + except: + min_num = np.min([len(i) for i in indices]) + indices = torch.stack([i[:min_num] for i in indices], dim=0) + scores = torch.stack([i[:min_num] for i in scores], dim=0) + return indices, scores + + +def edge_mask(inputs, n_channel, dilation=1, edge_thld=5): + b, c, h, w = inputs.size() + device = inputs.device + + dii_filter = torch.tensor( + [[0, 1., 0], [0, -2., 0], [0, 1., 0]] + ).view(1, 1, 3, 3) + dij_filter = 0.25 * torch.tensor( + [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]] + ).view(1, 1, 3, 3) + djj_filter = torch.tensor( + [[0, 0, 0], [1., -2., 1.], [0, 0, 0]] + ).view(1, 1, 3, 3) + + dii = F.conv2d( + inputs.view(-1, 1, h, w), dii_filter.to(device), padding=dilation, dilation=dilation + ).view(b, c, h, w) + dij = F.conv2d( + inputs.view(-1, 1, h, w), dij_filter.to(device), padding=dilation, dilation=dilation + ).view(b, c, h, w) + djj = F.conv2d( + inputs.view(-1, 1, h, w), djj_filter.to(device), padding=dilation, dilation=dilation + ).view(b, c, h, w) + + det = dii * djj - dij * dij + tr = dii + djj + del dii, dij, djj + + threshold = (edge_thld + 1) ** 2 / edge_thld + is_not_edge = torch.min(tr * tr / det <= threshold, det > 0) + + return is_not_edge diff --git 
a/imcui/third_party/DarkFeat/pose_estimation.py b/imcui/third_party/DarkFeat/pose_estimation.py new file mode 100644 index 0000000000000000000000000000000000000000..c87877191e7e31c3bc0a362d7d481dfd5d4b5757 --- /dev/null +++ b/imcui/third_party/DarkFeat/pose_estimation.py @@ -0,0 +1,137 @@ +import argparse +import cv2 +import numpy as np +import os +import math +import subprocess +from tqdm import tqdm + + +def compute_essential(matched_kp1, matched_kp2, K): + pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) + pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) + K_1 = np.eye(3) + # Estimate the homography between the matches using RANSAC + ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000) + if ransac_inliers is None or ransac_model.shape != (3,3): + ransac_inliers = np.array([]) + ransac_model = None + return ransac_model, ransac_inliers, pts1, pts2 + + +def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers): + """Compute the angular error between two rotation matrices and two translation vectors. + Keyword arguments: + R -- 2D numpy array containing an estimated rotation + gt_R -- 2D numpy array containing the corresponding ground truth rotation + t -- 2D numpy array containing an estimated translation as column + gt_t -- 2D numpy array containing the corresponding ground truth translation + """ + + inliers = inliers.ravel() + R = np.eye(3) + t = np.zeros((3,1)) + sst = True + try: + _, R, t, _ = cv2.recoverPose(E, pts1_norm, pts2_norm, np.eye(3), inliers) + except: + sst = False + # calculate angle between provided rotations + # + if sst: + dR = np.matmul(R, np.transpose(R_GT)) + dR = cv2.Rodrigues(dR)[0] + dR = np.linalg.norm(dR) * 180 / math.pi + + # calculate angle between provided translations + dT = float(np.dot(t_GT.T, t)) + dT /= float(np.linalg.norm(t_GT)) + + if dT > 1 or dT < -1: + print("Domain warning! 
dT:",dT) + dT = max(-1,min(1,dT)) + dT = math.acos(dT) * 180 / math.pi + dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation + else: + dR, dT = 180.0, 180.0 + return dR, dT + + +def pose_evaluation(result_base_dir, dark_name1, dark_name2, enhancer, K, R_GT, t_GT): + try: + m_kp1 = np.load(result_base_dir+enhancer+'/DarkFeat/POINT_1/'+dark_name1) + m_kp2 = np.load(result_base_dir+enhancer+'/DarkFeat/POINT_2/'+dark_name2) + except: + return 180.0, 180.0 + try: + E, inliers, pts1, pts2 = compute_essential(m_kp1, m_kp2, K) + except: + E, inliers, pts1, pts2 = np.zeros((3, 3)), np.array([]), None, None + dR, dT = compute_error(R_GT, t_GT, E, pts1, pts2, inliers) + return dR, dT + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--histeq', action='store_true') + parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/') + opt = parser.parse_args() + + sizer = (960, 640) + focallength_x = 4.504986436499113e+03/(6744/sizer[0]) + focallength_y = 4.513311442889859e+03/(4502/sizer[1]) + K = np.eye(3) + K[0,0] = focallength_x + K[1,1] = focallength_y + K[0,2] = 3.363322177533149e+03/(6744/sizer[0]) + K[1,2] = 2.291824660547715e+03/(4502/sizer[1]) + Kinv = np.linalg.inv(K) + Kinvt = np.transpose(Kinv) + + PE_MT = np.zeros((6, 8)) + + enhancer = 'None' if not opt.histeq else 'HistEQ' + + for scene in ['Indoor', 'Outdoor']: + dir_base = opt.dataset_dir + '/' + scene + '/' + base_save = 'result_errors/' + scene + '/' + pair_list = sorted(os.listdir(dir_base)) + + os.makedirs(base_save, exist_ok=True) + + for pair in tqdm(pair_list): + opention = 1 + if scene == 'Outdoor': + pass + else: + if int(pair[4::]) <= 17: + opention = 0 + else: + pass + name = [] + files = sorted(os.listdir(dir_base+pair)) + for file_ in files: + if file_.endswith('.cr2'): + name.append(file_[0:9]) + ISO = ['00100', '00200', '00400', '00800', '01600', '03200', '06400', '12800'] + if opention == 1: + Shutter_speed = ['0.005','0.01','0.025','0.05','0.17','0.5'] + else: + Shutter_speed = ['0.01','0.02','0.05','0.1','0.3','1'] + + E_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'E_estimated.npy') + F_GT = np.dot(np.dot(Kinvt,E_GT),Kinv) + R_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'R_GT.npy') + t_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'T_GT.npy') + result_base_dir ='result/' +scene+'/'+pair+'/' + for iso in ISO: + for ex in Shutter_speed: + dark_name1 = name[0]+iso+'_'+ex+'_'+scene+'.npy' + dark_name2 = name[1]+iso+'_'+ex+'_'+scene+'.npy' + + dr, dt = pose_evaluation(result_base_dir,dark_name1,dark_name2,enhancer,K,R_GT,t_GT) + PE_MT[Shutter_speed.index(ex),ISO.index(iso)] = max(dr, dt) + + subprocess.check_output(['mkdir', '-p', base_save + pair + f'/{enhancer}/']) + np.save(base_save + pair + f'/{enhancer}/Pose_error_DarkFeat.npy', PE_MT) + \ No newline at end of file diff --git a/imcui/third_party/DarkFeat/raw_preprocess.py b/imcui/third_party/DarkFeat/raw_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..226155a84e97f15782d3650f4ef6b3fa1880e07b --- /dev/null +++ b/imcui/third_party/DarkFeat/raw_preprocess.py @@ -0,0 +1,62 @@ +import glob +import rawpy +import cv2 +import os +import numpy as np +import colour_demosaicing +from tqdm import tqdm + + +def process_raw(args, path, w_new, h_new): + raw = rawpy.imread(str(path)).raw_image_visible + if '_00200_' in str(path) or '_00100_' in str(path): + raw = np.clip(raw.astype('float32') - 512, 0, 65535) + else: + raw = np.clip(raw.astype('float32') - 2048, 0, 65535) + img = 
colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, 'RGGB').astype('float32') + img = np.clip(img, 0, 16383) + + # HistEQ start + if args.histeq: + img2 = np.zeros_like(img) + for i in range(3): + hist,bins = np.histogram(img[..., i].flatten(),16384,[0,16384]) + cdf = hist.cumsum() + cdf_normalized = cdf * float(hist.max()) / cdf.max() + cdf_m = np.ma.masked_equal(cdf,0) + cdf_m = (cdf_m - cdf_m.min())*16383/(cdf_m.max()-cdf_m.min()) + cdf = np.ma.filled(cdf_m,0).astype('uint16') + img2[..., i] = cdf[img[..., i].astype('int16')] + img[..., i] = img2[..., i].astype('float32') + # HistEQ end + + m = img.mean() + d = np.abs(img - img.mean()).mean() + img = (img - m + 2*d) / 4/d * 255 + image = np.clip(img, 0, 255) + + image = cv2.resize(image.astype('float32'), (w_new, h_new), interpolation=cv2.INTER_AREA) + + if args.histeq: + path=str(path) + os.makedirs('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy']), exist_ok=True) + np.save('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy']+[path.split('/')[-1].replace('cr2','npy')]), image) + else: + path=str(path) + os.makedirs('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy-nohisteq']), exist_ok=True) + np.save('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy-nohisteq']+[path.split('/')[-1].replace('cr2','npy')]), image) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--H', type=int, default=int(640)) + parser.add_argument('--W', type=int, default=int(960)) + parser.add_argument('--histeq', action='store_true') + parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/') + args = parser.parse_args() + + path_ls = glob.glob(args.dataset_dir + '/*/pair*/?????/*') + for path in tqdm(path_ls): + process_raw(args, path, args.W, args.H) + diff --git a/imcui/third_party/DarkFeat/read_error.py b/imcui/third_party/DarkFeat/read_error.py new file mode 100644 index 0000000000000000000000000000000000000000..406b92dbd3877a11e51aebc3a705cd8d8d17e173 --- /dev/null +++ b/imcui/third_party/DarkFeat/read_error.py @@ -0,0 +1,56 @@ +import os +import numpy as np +import subprocess + +# def ratio(losses, thresholds=[1,2,3,4,5,6,7,8,9,10]): +def ratio(losses, thresholds=[5,10]): + return [ + '{:.3f}'.format(np.mean(losses < threshold)) + for threshold in thresholds + ] + +if __name__ == '__main__': + scene = 'Indoor' + dir_base = 'result_errors/Indoor/' + save_pt = 'resultfinal_errors/Indoor/' + + subprocess.check_output(['mkdir', '-p', save_pt]) + + with open(save_pt +'ratio_methods_'+scene+'.txt','w') as f: + f.write('5deg 10deg'+'\n') + pair_list = os.listdir(dir_base) + enhancer = os.listdir(dir_base+'/pair9/') + for method in enhancer: + pose_error_list = sorted(os.listdir(dir_base+'/pair9/'+method)) + for pose_error in pose_error_list: + error_array = np.expand_dims(np.zeros((6, 8)),axis=2) + for pair in pair_list: + try: + error = np.expand_dims(np.load(dir_base+'/'+pair+'/'+method+'/'+pose_error),axis=2) + except: + print('error in', dir_base+'/'+pair+'/'+method+'/'+pose_error) + continue + error_array = np.concatenate((error_array,error),axis=2) + ratio_result = ratio(error_array[:,:,1::].flatten()) + f.write(method + '_' + pose_error[11:-4] +' '+' '.join([str(i) for i in ratio_result])+"\n") + + + scene = 'Outdoor' + dir_base = 'result_errors/Outdoor/' + save_pt = 'resultfinal_errors/Outdoor/' + + subprocess.check_output(['mkdir', '-p', save_pt]) + + with open(save_pt +'ratio_methods_'+scene+'.txt','w') as f: + f.write('5deg 10deg'+'\n') + pair_list 
= os.listdir(dir_base) + enhancer = os.listdir(dir_base+'/pair9/') + for method in enhancer: + pose_error_list = sorted(os.listdir(dir_base+'/pair9/'+method)) + for pose_error in pose_error_list: + error_array = np.expand_dims(np.zeros((6, 8)),axis=2) + for pair in pair_list: + error = np.expand_dims(np.load(dir_base+'/'+pair+'/'+method+'/'+pose_error),axis=2) + error_array = np.concatenate((error_array,error),axis=2) + ratio_result = ratio(error_array[:,:,1::].flatten()) + f.write(method + '_' + pose_error[11:-4] +' '+' '.join([str(i) for i in ratio_result])+"\n") diff --git a/imcui/third_party/DarkFeat/run.py b/imcui/third_party/DarkFeat/run.py new file mode 100644 index 0000000000000000000000000000000000000000..0e4c87053d2970fc927d8991aa0dab208f3c4917 --- /dev/null +++ b/imcui/third_party/DarkFeat/run.py @@ -0,0 +1,48 @@ +import cv2 +import yaml +import argparse +import os +from torch.utils.data import DataLoader + +from datasets.gl3d_dataset import GL3DDataset +from trainer import Trainer +from trainer_single_norel import SingleTrainerNoRel +from trainer_single import SingleTrainer + + +if __name__ == '__main__': + # add argument parser + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, default='./configs/config.yaml') + parser.add_argument('--dataset_dir', type=str, default='/mnt/nvme2n1/hyz/data/GL3D') + parser.add_argument('--data_split', type=str, default='comb') + parser.add_argument('--is_training', type=bool, default=True) + parser.add_argument('--job_name', type=str, default='') + parser.add_argument('--gpu', type=str, default='0') + parser.add_argument('--start_cnt', type=int, default=0) + parser.add_argument('--stage', type=int, default=1) + args = parser.parse_args() + + # load global config + with open(args.config, 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # setup dataloader + dataset = GL3DDataset(args.dataset_dir, config['network'], args.data_split, is_training=args.is_training) + data_loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4) + + os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + + if args.stage == 1: + trainer = SingleTrainerNoRel(config, f'cuda:0', data_loader, args.job_name, args.start_cnt) + elif args.stage == 2: + trainer = SingleTrainer(config, f'cuda:0', data_loader, args.job_name, args.start_cnt) + elif args.stage == 3: + trainer = Trainer(config, f'cuda:0', data_loader, args.job_name, args.start_cnt) + else: + raise NotImplementedError() + + trainer.train() + + \ No newline at end of file diff --git a/imcui/third_party/DarkFeat/trainer.py b/imcui/third_party/DarkFeat/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ff2af9608e934b6899058d756bb2ab7d0fee2d --- /dev/null +++ b/imcui/third_party/DarkFeat/trainer.py @@ -0,0 +1,348 @@ +import os +import cv2 +import time +import yaml +import torch +import datetime +from tensorboardX import SummaryWriter +import torchvision.transforms as tvf +import torch.nn as nn +import torch.nn.functional as F + +from nets.geom import getK, getWarp, _grid_positions, getWarpNoValidate +from nets.loss import make_detector_loss, make_noise_score_map_loss +from nets.score import extract_kpts +from nets.multi_sampler import MultiSampler +from nets.noise_reliability_loss import MultiPixelAPLoss +from datasets.noise_simulator import NoiseSimulator +from nets.l2net import Quad_L2Net + + +class Trainer: + def __init__(self, config, device, loader, job_name, start_cnt): + self.config = config + self.device = device + self.loader = 
loader + + # tensorboard writer construction + os.makedirs('./runs/', exist_ok=True) + if job_name != '': + self.log_dir = f'runs/{job_name}' + else: + self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}' + + self.writer = SummaryWriter(self.log_dir) + with open(f'{self.log_dir}/config.yaml', 'w') as f: + yaml.dump(config, f) + + if config['network']['input_type'] == 'gray': + self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device) + elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic': + self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device) + elif config['network']['input_type'] == 'raw': + self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device) + else: + raise NotImplementedError() + + # noise maker + self.noise_maker = NoiseSimulator(device) + + # reliability map conv + self.model.clf = nn.Conv2d(128, 2, kernel_size=1).cuda() + + # load model + self.cnt = 0 + if start_cnt != 0: + self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth', map_location=device)) + self.cnt = start_cnt + 1 + + # sampler + sampler = MultiSampler(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16, + subd_neg=-8,maxpool_pos=True).to(device) + self.reliability_relitive_loss = MultiPixelAPLoss(sampler, nq=20).to(device) + + + # optimizer and scheduler + if self.config['training']['optimizer'] == 'SGD': + self.optimizer = torch.optim.SGD( + [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], + lr=self.config['training']['lr'], + momentum=self.config['training']['momentum'], + weight_decay=self.config['training']['weight_decay'], + ) + elif self.config['training']['optimizer'] == 'Adam': + self.optimizer = torch.optim.Adam( + [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], + lr=self.config['training']['lr'], + weight_decay=self.config['training']['weight_decay'] + ) + else: + raise NotImplementedError() + + self.lr_scheduler = torch.optim.lr_scheduler.StepLR( + self.optimizer, + step_size=self.config['training']['lr_step'], + gamma=self.config['training']['lr_gamma'], + last_epoch=start_cnt + ) + for param_tensor in self.model.state_dict(): + print(param_tensor, "\t", self.model.state_dict()[param_tensor].size()) + + + def save(self, iter_num): + torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth') + + def load(self, path): + self.model.load_state_dict(torch.load(path)) + + def train(self): + self.model.train() + + for epoch in range(2): + for batch_idx, inputs in enumerate(self.loader): + self.optimizer.zero_grad() + t = time.time() + + # preprocess and add noise + img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt) + img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt) + + img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device) + img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device) + noise_img0 = noise_img0_ori.permute(0, 3, 1, 2).float().to(self.device) + noise_img1 = noise_img1_ori.permute(0, 3, 1, 2).float().to(self.device) + + if self.config['network']['input_type'] == 'rgb': + # 3-channel rgb + RGB_mean = [0.485, 0.456, 0.406] + RGB_std = [0.229, 0.224, 0.225] + norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std) + img0 = norm_RGB(img0) + img1 = norm_RGB(img1) + noise_img0 = norm_RGB(noise_img0) + noise_img1 = norm_RGB(noise_img1) + + elif self.config['network']['input_type'] == 'gray': + # 1-channel + img0 = 
torch.mean(img0, dim=1, keepdim=True) + img1 = torch.mean(img1, dim=1, keepdim=True) + noise_img0 = torch.mean(noise_img0, dim=1, keepdim=True) + noise_img1 = torch.mean(noise_img1, dim=1, keepdim=True) + norm_gray0 = tvf.Normalize(mean=img0.mean(), std=img0.std()) + norm_gray1 = tvf.Normalize(mean=img1.mean(), std=img1.std()) + img0 = norm_gray0(img0) + img1 = norm_gray1(img1) + noise_img0 = norm_gray0(noise_img0) + noise_img1 = norm_gray1(noise_img1) + + elif self.config['network']['input_type'] == 'raw': + # 4-channel + pass + + elif self.config['network']['input_type'] == 'raw-demosaic': + # 3-channel + pass + + else: + raise NotImplementedError() + + desc0, score_map0, _, _ = self.model(img0) + desc1, score_map1, _, _ = self.model(img1) + + conf0 = F.softmax(self.model.clf(torch.abs(desc0)**2.0), dim=1)[:,1:2] + conf1 = F.softmax(self.model.clf(torch.abs(desc1)**2.0), dim=1)[:,1:2] + + noise_desc0, noise_score_map0, noise_at0, noise_att0 = self.model(noise_img0) + noise_desc1, noise_score_map1, noise_at1, noise_att1 = self.model(noise_img1) + + noise_conf0 = F.softmax(self.model.clf(torch.abs(noise_desc0)**2.0), dim=1)[:,1:2] + noise_conf1 = F.softmax(self.model.clf(torch.abs(noise_desc1)**2.0), dim=1)[:,1:2] + + cur_feat_size0 = torch.tensor(score_map0.shape[2:]) + cur_feat_size1 = torch.tensor(score_map1.shape[2:]) + + desc0 = desc0.permute(0, 2, 3, 1) + desc1 = desc1.permute(0, 2, 3, 1) + score_map0 = score_map0.permute(0, 2, 3, 1) + score_map1 = score_map1.permute(0, 2, 3, 1) + noise_desc0 = noise_desc0.permute(0, 2, 3, 1) + noise_desc1 = noise_desc1.permute(0, 2, 3, 1) + noise_score_map0 = noise_score_map0.permute(0, 2, 3, 1) + noise_score_map1 = noise_score_map1.permute(0, 2, 3, 1) + conf0 = conf0.permute(0, 2, 3, 1) + conf1 = conf1.permute(0, 2, 3, 1) + noise_conf0 = noise_conf0.permute(0, 2, 3, 1) + noise_conf1 = noise_conf1.permute(0, 2, 3, 1) + + r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device) + r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device) + + pos0 = _grid_positions( + cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device) + + pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate( + pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), + r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + + pos0, pos1, _ = getWarp( + pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), + r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + + reliab_loss_relative = self.reliability_relitive_loss(desc0, desc1, noise_desc0, noise_desc1, conf0, conf1, noise_conf0, noise_conf1, pos0_for_rel, pos1_for_rel, img0.shape[0], img0.shape[2], img0.shape[3]) + + det_structured_loss, det_accuracy = make_detector_loss( + pos0, pos1, desc0, desc1, + score_map0, score_map1, img0.shape[0], + self.config['network']['use_corr_n'], + self.config['network']['loss_type'], + self.config + ) + + det_structured_loss_noise, det_accuracy_noise = make_detector_loss( + pos0, pos1, noise_desc0, noise_desc1, + noise_score_map0, noise_score_map1, img0.shape[0], + self.config['network']['use_corr_n'], + self.config['network']['loss_type'], + self.config + ) + + indices0, scores0 = extract_kpts( + score_map0.permute(0, 3, 1, 2), + k=self.config['network']['det']['kpt_n'], + score_thld=self.config['network']['det']['score_thld'], + nms_size=self.config['network']['det']['nms_size'], + eof_size=self.config['network']['det']['eof_size'], + 
edge_thld=self.config['network']['det']['edge_thld'] + ) + indices1, scores1 = extract_kpts( + score_map1.permute(0, 3, 1, 2), + k=self.config['network']['det']['kpt_n'], + score_thld=self.config['network']['det']['score_thld'], + nms_size=self.config['network']['det']['nms_size'], + eof_size=self.config['network']['det']['eof_size'], + edge_thld=self.config['network']['det']['edge_thld'] + ) + + noise_score_loss0, mask0 = make_noise_score_map_loss(score_map0, noise_score_map0, indices0, img0.shape[0], thld=0.1) + noise_score_loss1, mask1 = make_noise_score_map_loss(score_map1, noise_score_map1, indices1, img1.shape[0], thld=0.1) + + total_loss = det_structured_loss + det_structured_loss_noise + total_loss += noise_score_loss0 / 2. * 1. + total_loss += noise_score_loss1 / 2. * 1. + total_loss += reliab_loss_relative[0] / 2. * 0.5 + total_loss += reliab_loss_relative[1] / 2. * 0.5 + + self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt) + self.writer.add_scalar("acc/noise_acc", det_accuracy_noise, self.cnt) + self.writer.add_scalar("loss/total_loss", total_loss, self.cnt) + self.writer.add_scalar("loss/noise_score_loss", (noise_score_loss0 + noise_score_loss1) / 2., self.cnt) + self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt) + self.writer.add_scalar("loss/det_loss_noise", det_structured_loss_noise, self.cnt) + print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t)) + # print(f'normal_loss: {det_structured_loss}, noise_loss: {det_structured_loss_noise}, reliab_loss: {reliab_loss_relative[0]}, {reliab_loss_relative[1]}') + + if det_structured_loss != 0: + total_loss.backward() + self.optimizer.step() + self.lr_scheduler.step() + + if self.cnt % 100 == 0: + noise_indices0, noise_scores0 = extract_kpts( + noise_score_map0.permute(0, 3, 1, 2), + k=self.config['network']['det']['kpt_n'], + score_thld=self.config['network']['det']['score_thld'], + nms_size=self.config['network']['det']['nms_size'], + eof_size=self.config['network']['det']['eof_size'], + edge_thld=self.config['network']['det']['edge_thld'] + ) + noise_indices1, noise_scores1 = extract_kpts( + noise_score_map1.permute(0, 3, 1, 2), + k=self.config['network']['det']['kpt_n'], + score_thld=self.config['network']['det']['score_thld'], + nms_size=self.config['network']['det']['nms_size'], + eof_size=self.config['network']['det']['eof_size'], + edge_thld=self.config['network']['det']['edge_thld'] + ) + if self.config['network']['input_type'] == 'raw': + kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0]) + kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0]) + noise_kpt_img0 = self.showKeyPoints(noise_img0_ori[0][..., :3] * 255., noise_indices0[0]) + noise_kpt_img1 = self.showKeyPoints(noise_img1_ori[0][..., :3] * 255., noise_indices1[0]) + else: + kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0]) + kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0]) + noise_kpt_img0 = self.showKeyPoints(noise_img0_ori[0] * 255., noise_indices0[0]) + noise_kpt_img1 = self.showKeyPoints(noise_img1_ori[0] * 255., noise_indices1[0]) + + self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC') + self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC') + self.writer.add_image('img0/noise_kpts', noise_kpt_img0, self.cnt, dataformats='HWC') + self.writer.add_image('img1/noise_kpts', noise_kpt_img1, self.cnt, dataformats='HWC') + 
self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC') + self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC') + self.writer.add_image('img0/noise_score_map', noise_score_map0[0], self.cnt, dataformats='HWC') + self.writer.add_image('img1/noise_score_map', noise_score_map1[0], self.cnt, dataformats='HWC') + self.writer.add_image('img0/kpt_mask', mask0.unsqueeze(2), self.cnt, dataformats='HWC') + self.writer.add_image('img1/kpt_mask', mask1.unsqueeze(2), self.cnt, dataformats='HWC') + self.writer.add_image('img0/conf', conf0[0], self.cnt, dataformats='HWC') + self.writer.add_image('img1/conf', conf1[0], self.cnt, dataformats='HWC') + self.writer.add_image('img0/noise_conf', noise_conf0[0], self.cnt, dataformats='HWC') + self.writer.add_image('img1/noise_conf', noise_conf1[0], self.cnt, dataformats='HWC') + + if self.cnt % 5000 == 0: + self.save(self.cnt) + + self.cnt += 1 + + + def showKeyPoints(self, img, indices): + key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1]) + img = img.numpy().astype('uint8') + img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0)) + return img + + + def preprocess(self, img, iter_idx): + if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']: + return img + + raw = self.noise_maker.rgb2raw(img, batched=True) + + if self.config['network']['noise']: + ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] + raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) + + if self.config['network']['input_type'] == 'raw': + return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)) + + if self.config['network']['input_type'] == 'raw-demosaic': + return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)) + + rgb = self.noise_maker.raw2rgb(raw, batched=True) + if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + return torch.tensor(rgb) + + raise NotImplementedError() + + + def preprocess_noise_pair(self, img, iter_idx): + assert self.config['network']['noise'] + + raw = self.noise_maker.rgb2raw(img, batched=True) + + ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] + noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) + + if self.config['network']['input_type'] == 'raw': + return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \ + torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True)) + + if self.config['network']['input_type'] == 'raw-demosaic': + return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \ + torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)) + + noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True) + if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + return img, torch.tensor(noise_rgb) + + raise NotImplementedError() diff --git a/imcui/third_party/DarkFeat/trainer_single.py b/imcui/third_party/DarkFeat/trainer_single.py new file mode 100644 index 0000000000000000000000000000000000000000..65566e7e27cfd605eba000d308b6d3610f29e746 --- /dev/null +++ b/imcui/third_party/DarkFeat/trainer_single.py @@ -0,0 +1,294 @@ +import os +import cv2 +import time +import yaml +import torch +import datetime +from tensorboardX import SummaryWriter +import torchvision.transforms as tvf 
+import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from nets.geom import getK, getWarp, _grid_positions, getWarpNoValidate +from nets.loss import make_detector_loss +from nets.score import extract_kpts +from nets.sampler import NghSampler2 +from nets.reliability_loss import ReliabilityLoss +from datasets.noise_simulator import NoiseSimulator +from nets.l2net import Quad_L2Net + + +class SingleTrainer: + def __init__(self, config, device, loader, job_name, start_cnt): + self.config = config + self.device = device + self.loader = loader + + # tensorboard writer construction + os.makedirs('./runs/', exist_ok=True) + if job_name != '': + self.log_dir = f'runs/{job_name}' + else: + self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}' + + self.writer = SummaryWriter(self.log_dir) + with open(f'{self.log_dir}/config.yaml', 'w') as f: + yaml.dump(config, f) + + if config['network']['input_type'] == 'gray' or config['network']['input_type'] == 'raw-gray': + self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device) + elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic': + self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device) + elif config['network']['input_type'] == 'raw': + self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device) + else: + raise NotImplementedError() + + # noise maker + self.noise_maker = NoiseSimulator(device) + + # load model + self.cnt = 0 + if start_cnt != 0: + self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth')) + self.cnt = start_cnt + 1 + + # sampler + sampler = NghSampler2(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16, + subd_neg=-8,maxpool_pos=True).to(device) + self.reliability_loss = ReliabilityLoss(sampler, base=0.3, nq=20).to(device) + # reliability map conv + self.model.clf = nn.Conv2d(128, 2, kernel_size=1).cuda() + + # optimizer and scheduler + if self.config['training']['optimizer'] == 'SGD': + self.optimizer = torch.optim.SGD( + [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], + lr=self.config['training']['lr'], + momentum=self.config['training']['momentum'], + weight_decay=self.config['training']['weight_decay'], + ) + elif self.config['training']['optimizer'] == 'Adam': + self.optimizer = torch.optim.Adam( + [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], + lr=self.config['training']['lr'], + weight_decay=self.config['training']['weight_decay'] + ) + else: + raise NotImplementedError() + + self.lr_scheduler = torch.optim.lr_scheduler.StepLR( + self.optimizer, + step_size=self.config['training']['lr_step'], + gamma=self.config['training']['lr_gamma'], + last_epoch=start_cnt + ) + for param_tensor in self.model.state_dict(): + print(param_tensor, "\t", self.model.state_dict()[param_tensor].size()) + + + def save(self, iter_num): + torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth') + + def load(self, path): + self.model.load_state_dict(torch.load(path)) + + def train(self): + self.model.train() + + for epoch in range(2): + for batch_idx, inputs in enumerate(self.loader): + self.optimizer.zero_grad() + t = time.time() + + # preprocess and add noise + img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt) + img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt) + + img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device) + img1 = 
img1_ori.permute(0, 3, 1, 2).float().to(self.device) + + if self.config['network']['input_type'] == 'rgb': + # 3-channel rgb + RGB_mean = [0.485, 0.456, 0.406] + RGB_std = [0.229, 0.224, 0.225] + norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std) + img0 = norm_RGB(img0) + img1 = norm_RGB(img1) + noise_img0 = norm_RGB(noise_img0) + noise_img1 = norm_RGB(noise_img1) + + elif self.config['network']['input_type'] == 'gray': + # 1-channel + img0 = torch.mean(img0, dim=1, keepdim=True) + img1 = torch.mean(img1, dim=1, keepdim=True) + noise_img0 = torch.mean(noise_img0, dim=1, keepdim=True) + noise_img1 = torch.mean(noise_img1, dim=1, keepdim=True) + norm_gray0 = tvf.Normalize(mean=img0.mean(), std=img0.std()) + norm_gray1 = tvf.Normalize(mean=img1.mean(), std=img1.std()) + img0 = norm_gray0(img0) + img1 = norm_gray1(img1) + noise_img0 = norm_gray0(noise_img0) + noise_img1 = norm_gray1(noise_img1) + + elif self.config['network']['input_type'] == 'raw': + # 4-channel + pass + + elif self.config['network']['input_type'] == 'raw-demosaic': + # 3-channel + pass + + else: + raise NotImplementedError() + + desc0, score_map0, _, _ = self.model(img0) + desc1, score_map1, _, _ = self.model(img1) + + cur_feat_size0 = torch.tensor(score_map0.shape[2:]) + cur_feat_size1 = torch.tensor(score_map1.shape[2:]) + + conf0 = F.softmax(self.model.clf(torch.abs(desc0)**2.0), dim=1)[:,1:2] + conf1 = F.softmax(self.model.clf(torch.abs(desc1)**2.0), dim=1)[:,1:2] + + desc0 = desc0.permute(0, 2, 3, 1) + desc1 = desc1.permute(0, 2, 3, 1) + score_map0 = score_map0.permute(0, 2, 3, 1) + score_map1 = score_map1.permute(0, 2, 3, 1) + conf0 = conf0.permute(0, 2, 3, 1) + conf1 = conf1.permute(0, 2, 3, 1) + + r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device) + r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device) + + pos0 = _grid_positions( + cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device) + + pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate( + pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), + r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + + pos0, pos1, _ = getWarp( + pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), + r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + + reliab_loss = self.reliability_loss(desc0, desc1, conf0, conf1, pos0_for_rel, pos1_for_rel, img0.shape[0], img0.shape[2], img0.shape[3]) + + det_structured_loss, det_accuracy = make_detector_loss( + pos0, pos1, desc0, desc1, + score_map0, score_map1, img0.shape[0], + self.config['network']['use_corr_n'], + self.config['network']['loss_type'], + self.config + ) + + total_loss = det_structured_loss + self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt) + + total_loss += reliab_loss + + self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt) + self.writer.add_scalar("loss/total_loss", total_loss, self.cnt) + self.writer.add_scalar("loss/reliab_loss", reliab_loss, self.cnt) + print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t)) + + if det_structured_loss != 0: + total_loss.backward() + self.optimizer.step() + self.lr_scheduler.step() + + if self.cnt % 100 == 0: + indices0, scores0 = extract_kpts( + score_map0.permute(0, 3, 1, 2), + k=self.config['network']['det']['kpt_n'], + score_thld=self.config['network']['det']['score_thld'], + nms_size=self.config['network']['det']['nms_size'], + 
eof_size=self.config['network']['det']['eof_size'], + edge_thld=self.config['network']['det']['edge_thld'] + ) + indices1, scores1 = extract_kpts( + score_map1.permute(0, 3, 1, 2), + k=self.config['network']['det']['kpt_n'], + score_thld=self.config['network']['det']['score_thld'], + nms_size=self.config['network']['det']['nms_size'], + eof_size=self.config['network']['det']['eof_size'], + edge_thld=self.config['network']['det']['edge_thld'] + ) + + if self.config['network']['input_type'] == 'raw': + kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0]) + kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0]) + else: + kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0]) + kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0]) + + self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC') + self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC') + self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC') + self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC') + self.writer.add_image('img0/conf', conf0[0], self.cnt, dataformats='HWC') + self.writer.add_image('img1/conf', conf1[0], self.cnt, dataformats='HWC') + + if self.cnt % 10000 == 0: + self.save(self.cnt) + + self.cnt += 1 + + + def showKeyPoints(self, img, indices): + key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1]) + img = img.numpy().astype('uint8') + img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0)) + return img + + + def preprocess(self, img, iter_idx): + if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']: + return img + + raw = self.noise_maker.rgb2raw(img, batched=True) + + if self.config['network']['noise']: + ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] + raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) + + if self.config['network']['input_type'] == 'raw': + return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)) + + if self.config['network']['input_type'] == 'raw-demosaic': + return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)) + + rgb = self.noise_maker.raw2rgb(raw, batched=True) + if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + return torch.tensor(rgb) + + raise NotImplementedError() + + + def preprocess_noise_pair(self, img, iter_idx): + assert self.config['network']['noise'] + + raw = self.noise_maker.rgb2raw(img, batched=True) + + ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] + noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) + + if self.config['network']['input_type'] == 'raw': + return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \ + torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True)) + + if self.config['network']['input_type'] == 'raw-demosaic': + return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \ + torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)) + + if self.config['network']['input_type'] == 'raw-gray': + factor = torch.tensor([0.299, 0.587, 0.114]).double() + return torch.matmul(torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), factor).unsqueeze(-1), \ + torch.matmul(torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, 
batched=True)), factor).unsqueeze(-1) + + noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True) + if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + return img, torch.tensor(noise_rgb) + + raise NotImplementedError() diff --git a/imcui/third_party/DarkFeat/trainer_single_norel.py b/imcui/third_party/DarkFeat/trainer_single_norel.py new file mode 100644 index 0000000000000000000000000000000000000000..a572e9c599adc30e5753e11e668d121cd378672a --- /dev/null +++ b/imcui/third_party/DarkFeat/trainer_single_norel.py @@ -0,0 +1,265 @@ +import os +import cv2 +import time +import yaml +import torch +import datetime +from tensorboardX import SummaryWriter +import torchvision.transforms as tvf +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from nets.l2net import Quad_L2Net +from nets.geom import getK, getWarp, _grid_positions +from nets.loss import make_detector_loss +from nets.score import extract_kpts +from datasets.noise_simulator import NoiseSimulator +from nets.l2net import Quad_L2Net + + +class SingleTrainerNoRel: + def __init__(self, config, device, loader, job_name, start_cnt): + self.config = config + self.device = device + self.loader = loader + + # tensorboard writer construction + os.makedirs('./runs/', exist_ok=True) + if job_name != '': + self.log_dir = f'runs/{job_name}' + else: + self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}' + + self.writer = SummaryWriter(self.log_dir) + with open(f'{self.log_dir}/config.yaml', 'w') as f: + yaml.dump(config, f) + + if config['network']['input_type'] == 'gray' or config['network']['input_type'] == 'raw-gray': + self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device) + elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic': + self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device) + elif config['network']['input_type'] == 'raw': + self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device) + else: + raise NotImplementedError() + + # noise maker + self.noise_maker = NoiseSimulator(device) + + # load model + self.cnt = 0 + if start_cnt != 0: + self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth')) + self.cnt = start_cnt + 1 + + # optimizer and scheduler + if self.config['training']['optimizer'] == 'SGD': + self.optimizer = torch.optim.SGD( + [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], + lr=self.config['training']['lr'], + momentum=self.config['training']['momentum'], + weight_decay=self.config['training']['weight_decay'], + ) + elif self.config['training']['optimizer'] == 'Adam': + self.optimizer = torch.optim.Adam( + [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}], + lr=self.config['training']['lr'], + weight_decay=self.config['training']['weight_decay'] + ) + else: + raise NotImplementedError() + + self.lr_scheduler = torch.optim.lr_scheduler.StepLR( + self.optimizer, + step_size=self.config['training']['lr_step'], + gamma=self.config['training']['lr_gamma'], + last_epoch=start_cnt + ) + for param_tensor in self.model.state_dict(): + print(param_tensor, "\t", self.model.state_dict()[param_tensor].size()) + + + def save(self, iter_num): + torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth') + + def load(self, path): + self.model.load_state_dict(torch.load(path)) + + def train(self): + self.model.train() + + for epoch in 
range(2): + for batch_idx, inputs in enumerate(self.loader): + self.optimizer.zero_grad() + t = time.time() + + # preprocess and add noise + img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt) + img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt) + + img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device) + img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device) + + if self.config['network']['input_type'] == 'rgb': + # 3-channel rgb + RGB_mean = [0.485, 0.456, 0.406] + RGB_std = [0.229, 0.224, 0.225] + norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std) + img0 = norm_RGB(img0) + img1 = norm_RGB(img1) + noise_img0 = norm_RGB(noise_img0) + noise_img1 = norm_RGB(noise_img1) + + elif self.config['network']['input_type'] == 'gray': + # 1-channel + img0 = torch.mean(img0, dim=1, keepdim=True) + img1 = torch.mean(img1, dim=1, keepdim=True) + noise_img0 = torch.mean(noise_img0, dim=1, keepdim=True) + noise_img1 = torch.mean(noise_img1, dim=1, keepdim=True) + norm_gray0 = tvf.Normalize(mean=img0.mean(), std=img0.std()) + norm_gray1 = tvf.Normalize(mean=img1.mean(), std=img1.std()) + img0 = norm_gray0(img0) + img1 = norm_gray1(img1) + noise_img0 = norm_gray0(noise_img0) + noise_img1 = norm_gray1(noise_img1) + + elif self.config['network']['input_type'] == 'raw': + # 4-channel + pass + + elif self.config['network']['input_type'] == 'raw-demosaic': + # 3-channel + pass + + else: + raise NotImplementedError() + + desc0, score_map0, _, _ = self.model(img0) + desc1, score_map1, _, _ = self.model(img1) + + cur_feat_size0 = torch.tensor(score_map0.shape[2:]) + cur_feat_size1 = torch.tensor(score_map1.shape[2:]) + + desc0 = desc0.permute(0, 2, 3, 1) + desc1 = desc1.permute(0, 2, 3, 1) + score_map0 = score_map0.permute(0, 2, 3, 1) + score_map1 = score_map1.permute(0, 2, 3, 1) + + r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device) + r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device) + + pos0 = _grid_positions( + cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device) + + pos0, pos1, _ = getWarp( + pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device), + r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0]) + + det_structured_loss, det_accuracy = make_detector_loss( + pos0, pos1, desc0, desc1, + score_map0, score_map1, img0.shape[0], + self.config['network']['use_corr_n'], + self.config['network']['loss_type'], + self.config + ) + + total_loss = det_structured_loss + + self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt) + self.writer.add_scalar("loss/total_loss", total_loss, self.cnt) + self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt) + print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t)) + + if det_structured_loss != 0: + total_loss.backward() + self.optimizer.step() + self.lr_scheduler.step() + + if self.cnt % 100 == 0: + indices0, scores0 = extract_kpts( + score_map0.permute(0, 3, 1, 2), + k=self.config['network']['det']['kpt_n'], + score_thld=self.config['network']['det']['score_thld'], + nms_size=self.config['network']['det']['nms_size'], + eof_size=self.config['network']['det']['eof_size'], + edge_thld=self.config['network']['det']['edge_thld'] + ) + indices1, scores1 = extract_kpts( + score_map1.permute(0, 3, 1, 2), + k=self.config['network']['det']['kpt_n'], + score_thld=self.config['network']['det']['score_thld'], + 
nms_size=self.config['network']['det']['nms_size'], + eof_size=self.config['network']['det']['eof_size'], + edge_thld=self.config['network']['det']['edge_thld'] + ) + + if self.config['network']['input_type'] == 'raw': + kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0]) + kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0]) + else: + kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0]) + kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0]) + + self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC') + self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC') + self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC') + self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC') + + if self.cnt % 10000 == 0: + self.save(self.cnt) + + self.cnt += 1 + + + def showKeyPoints(self, img, indices): + key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1]) + img = img.numpy().astype('uint8') + img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0)) + return img + + + def preprocess(self, img, iter_idx): + if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']: + return img + + raw = self.noise_maker.rgb2raw(img, batched=True) + + if self.config['network']['noise']: + ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] + raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) + + if self.config['network']['input_type'] == 'raw': + return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)) + + if self.config['network']['input_type'] == 'raw-demosaic': + return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)) + + rgb = self.noise_maker.raw2rgb(raw, batched=True) + if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + return torch.tensor(rgb) + + raise NotImplementedError() + + + def preprocess_noise_pair(self, img, iter_idx): + assert self.config['network']['noise'] + + raw = self.noise_maker.rgb2raw(img, batched=True) + + ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep'] + noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True) + + if self.config['network']['input_type'] == 'raw': + return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \ + torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True)) + + if self.config['network']['input_type'] == 'raw-demosaic': + return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \ + torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)) + + noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True) + if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray': + return img, torch.tensor(noise_rgb) + + raise NotImplementedError() diff --git a/imcui/third_party/DarkFeat/utils/__init__.py b/imcui/third_party/DarkFeat/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DarkFeat/utils/matching.py b/imcui/third_party/DarkFeat/utils/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..ca091f418bb4dc4d278611e5126a930aa51e7f3f --- /dev/null +++ b/imcui/third_party/DarkFeat/utils/matching.py @@ 
-0,0 +1,128 @@ +import math +import numpy as np +import cv2 + +def extract_ORB_keypoints_and_descriptors(img): + # gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + detector = cv2.ORB_create(nfeatures=1000) + kp, desc = detector.detectAndCompute(img, None) + return kp, desc + +def match_descriptors_NG(kp1, desc1, kp2, desc2): + bf = cv2.BFMatcher() + try: + matches = bf.knnMatch(desc1, desc2,k=2) + except: + matches = [] + good_matches=[] + image1_kp = [] + image2_kp = [] + ratios = [] + try: + for (m1,m2) in matches: + if m1.distance < 0.8 * m2.distance: + good_matches.append(m1) + image2_kp.append(kp2[m1.trainIdx].pt) + image1_kp.append(kp1[m1.queryIdx].pt) + ratios.append(m1.distance / m2.distance) + except: + pass + image1_kp = np.array([image1_kp]) + image2_kp = np.array([image2_kp]) + ratios = np.array([ratios]) + ratios = np.expand_dims(ratios, 2) + return image1_kp, image2_kp, good_matches, ratios + +def match_descriptors(kp1, desc1, kp2, desc2, ORB): + if ORB: + bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) + try: + matches = bf.match(desc1,desc2) + matches = sorted(matches, key = lambda x:x.distance) + except: + matches = [] + good_matches=[] + image1_kp = [] + image2_kp = [] + count = 0 + try: + for m in matches: + count+=1 + if count < 1000: + good_matches.append(m) + image2_kp.append(kp2[m.trainIdx].pt) + image1_kp.append(kp1[m.queryIdx].pt) + except: + pass + else: + # Match the keypoints with the warped_keypoints with nearest neighbor search + bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) + try: + matches = bf.match(desc1.transpose(1,0), desc2.transpose(1,0)) + matches = sorted(matches, key = lambda x:x.distance) + except: + matches = [] + good_matches=[] + image1_kp = [] + image2_kp = [] + try: + for m in matches: + good_matches.append(m) + image2_kp.append(kp2[m.trainIdx].pt) + image1_kp.append(kp1[m.queryIdx].pt) + except: + pass + + image1_kp = np.array([image1_kp]) + image2_kp = np.array([image2_kp]) + return image1_kp, image2_kp, good_matches + + +def compute_essential(matched_kp1, matched_kp2, K): + pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) + pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0)) + K_1 = np.eye(3) + # Estimate the homography between the matches using RANSAC + ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.FM_RANSAC, prob=0.999, threshold=0.001) + if ransac_inliers is None or ransac_model.shape != (3,3): + ransac_inliers = np.array([]) + ransac_model = None + return ransac_model, ransac_inliers, pts1, pts2 + + +def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers): + """Compute the angular error between two rotation matrices and two translation vectors. 
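+ Angles are returned in degrees; (180.0, 180.0) is returned when cv2.recoverPose fails.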
+ Keyword arguments: + R -- 2D numpy array containing an estimated rotation + gt_R -- 2D numpy array containing the corresponding ground truth rotation + t -- 2D numpy array containing an estimated translation as column + gt_t -- 2D numpy array containing the corresponding ground truth translation + """ + + inliers = inliers.ravel() + R = np.eye(3) + t = np.zeros((3,1)) + sst = True + try: + cv2.recoverPose(E, pts1_norm, pts2_norm, np.eye(3), R, t, inliers) + except: + sst = False + # calculate angle between provided rotations + # + if sst: + dR = np.matmul(R, np.transpose(R_GT)) + dR = cv2.Rodrigues(dR)[0] + dR = np.linalg.norm(dR) * 180 / math.pi + + # calculate angle between provided translations + dT = float(np.dot(t_GT.T, t)) + dT /= float(np.linalg.norm(t_GT)) + + if dT > 1 or dT < -1: + print("Domain warning! dT:",dT) + dT = max(-1,min(1,dT)) + dT = math.acos(dT) * 180 / math.pi + dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation + else: + dR,dT = 180.0, 180.0 + return dR, dT diff --git a/imcui/third_party/DarkFeat/utils/misc.py b/imcui/third_party/DarkFeat/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..1df6fdec97121486dbb94e0b32a2f66c85c48f7d --- /dev/null +++ b/imcui/third_party/DarkFeat/utils/misc.py @@ -0,0 +1,158 @@ +from pathlib import Path +import time +from collections import OrderedDict +import numpy as np +import cv2 +import rawpy +import torch +import colour_demosaicing + + +class AverageTimer: + """ Class to help manage printing simple timing of code execution. """ + + def __init__(self, smoothing=0.3, newline=False): + self.smoothing = smoothing + self.newline = newline + self.times = OrderedDict() + self.will_print = OrderedDict() + self.reset() + + def reset(self): + now = time.time() + self.start = now + self.last_time = now + for name in self.will_print: + self.will_print[name] = False + + def update(self, name='default'): + now = time.time() + dt = now - self.last_time + if name in self.times: + dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name] + self.times[name] = dt + self.will_print[name] = True + self.last_time = now + + def print(self, text='Timer'): + total = 0. 
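+ # Report each smoothed stage time updated since the last reset, then the summed total with an FPS estimate.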
+ print('[{}]'.format(text), end=' ') + for key in self.times: + val = self.times[key] + if self.will_print[key]: + print('%s=%.3f' % (key, val), end=' ') + total += val + print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ') + if self.newline: + print(flush=True) + else: + print(end='\r', flush=True) + self.reset() + + +class VideoStreamer: + def __init__(self, basedir, resize, image_glob): + self.listing = [] + self.resize = resize + self.i = 0 + if Path(basedir).is_dir(): + print('==> Processing image directory input: {}'.format(basedir)) + self.listing = list(Path(basedir).glob(image_glob[0])) + for j in range(1, len(image_glob)): + image_path = list(Path(basedir).glob(image_glob[j])) + self.listing = self.listing + image_path + self.listing.sort() + if len(self.listing) == 0: + raise IOError('No images found (maybe bad \'image_glob\' ?)') + self.max_length = len(self.listing) + else: + raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir)) + + def load_image(self, impath): + raw = rawpy.imread(str(impath)).raw_image_visible + raw = np.clip(raw.astype('float32') - 512, 0, 65535) + img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, 'RGGB').astype('float32') + img = np.clip(img, 0, 16383) + + m = img.mean() + d = np.abs(img - img.mean()).mean() + img = (img - m + 2*d) / 4/d * 255 + image = np.clip(img, 0, 255) + + w_new, h_new = self.resize[0], self.resize[1] + + im = cv2.resize(image.astype('float32'), (w_new, h_new), interpolation=cv2.INTER_AREA) + return im + + def next_frame(self): + if self.i == self.max_length: + return (None, False) + image_file = str(self.listing[self.i]) + image = self.load_image(image_file) + self.i = self.i + 1 + return (image, True) + + +def frame2tensor(frame, device): + if len(frame.shape) == 2: + return torch.from_numpy(frame/255.).float()[None, None].to(device) + else: + return torch.from_numpy(frame/255.).float().permute(2, 0, 1)[None].to(device) + + +def make_matching_plot_fast(image0, image1, mkpts0, mkpts1, + color, text, path=None, margin=10, + opencv_display=False, opencv_title='', + small_text=[]): + H0, W0 = image0.shape[:2] + H1, W1 = image1.shape[:2] + H, W = max(H0, H1), W0 + W1 + margin + + out = 255*np.ones((H, W, 3), np.uint8) + out[:H0, :W0, :] = image0 + out[:H1, W0+margin:, :] = image1 + + # Scale factor for consistent visualization across scales. + sc = min(H / 640., 2.0) + + # Big text. + Ht = int(30 * sc) # text height + txt_color_fg = (255, 255, 255) + txt_color_bg = (0, 0, 0) + + for i, t in enumerate(text): + cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, + 1.0*sc, txt_color_bg, 2, cv2.LINE_AA) + cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, + 1.0*sc, txt_color_fg, 1, cv2.LINE_AA) + + out_backup = out.copy() + + mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int) + color = (np.array(color[:, :3])*255).astype(int)[:, ::-1] + for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color): + c = c.tolist() + cv2.line(out, (x0, y0), (x1 + margin + W0, y1), + color=c, thickness=1, lineType=cv2.LINE_AA) + # display line end-points as circles + cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1, + lineType=cv2.LINE_AA) + + # Small text. 
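+ # Stack the small_text lines just above the bottom edge, outlined for contrast.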
+ Ht = int(18 * sc) # text height + for i, t in enumerate(reversed(small_text)): + cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, + 0.5*sc, txt_color_bg, 2, cv2.LINE_AA) + cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, + 0.5*sc, txt_color_fg, 1, cv2.LINE_AA) + + if path is not None: + cv2.imwrite(str(path), out) + + if opencv_display: + cv2.imshow(opencv_title, out) + cv2.waitKey(1) + + return out / 2 + out_backup / 2 + diff --git a/imcui/third_party/DarkFeat/utils/nn.py b/imcui/third_party/DarkFeat/utils/nn.py new file mode 100644 index 0000000000000000000000000000000000000000..8a80631d6e12d848cceee3b636baf49deaa7647a --- /dev/null +++ b/imcui/third_party/DarkFeat/utils/nn.py @@ -0,0 +1,50 @@ +import torch +from torch import nn + + +class NN2(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, data): + desc1, desc2 = data['descriptors0'].cuda(), data['descriptors1'].cuda() + kpts1, kpts2 = data['keypoints0'].cuda(), data['keypoints1'].cuda() + + # torch.cuda.synchronize() + # t = time.time() + + if kpts1.shape[1] <= 1 or kpts2.shape[1] <= 1: # no keypoints + shape0, shape1 = kpts1.shape[:-1], kpts2.shape[:-1] + return { + 'matches0': kpts1.new_full(shape0, -1, dtype=torch.int), + 'matches1': kpts2.new_full(shape1, -1, dtype=torch.int), + 'matching_scores0': kpts1.new_zeros(shape0), + 'matching_scores1': kpts2.new_zeros(shape1), + } + + sim = torch.matmul(desc1.squeeze().T, desc2.squeeze()) + ids1 = torch.arange(0, sim.shape[0], device=desc1.device) + nn12 = torch.argmax(sim, dim=1) + + nn21 = torch.argmax(sim, dim=0) + mask = torch.eq(ids1, nn21[nn12]) + matches = torch.stack([torch.masked_select(ids1, mask), torch.masked_select(nn12, mask)]) + # matches = torch.stack([ids1, nn12]) + indices0 = torch.ones((1, desc1.shape[-1]), dtype=int) * -1 + mscores0 = torch.ones((1, desc1.shape[-1]), dtype=float) * -1 + + # torch.cuda.synchronize() + # print(time.time() - t) + + matches_0 = matches[0].cpu().int().numpy() + matches_1 = matches[1].cpu().int() + for i in range(matches.shape[-1]): + indices0[0, matches_0[i]] = matches_1[i].int() + mscores0[0, matches_0[i]] = sim[matches_0[i], matches_1[i]] + + return { + 'matches0': indices0, # use -1 for invalid match + 'matches1': indices0, # use -1 for invalid match + 'matching_scores0': mscores0, + 'matching_scores1': mscores0, + } diff --git a/imcui/third_party/DarkFeat/utils/nnmatching.py b/imcui/third_party/DarkFeat/utils/nnmatching.py new file mode 100644 index 0000000000000000000000000000000000000000..7be6f98c050fc2e416ef48e25ca0f293106c1082 --- /dev/null +++ b/imcui/third_party/DarkFeat/utils/nnmatching.py @@ -0,0 +1,41 @@ +import torch + +from .nn import NN2 +from darkfeat import DarkFeat + +class NNMatching(torch.nn.Module): + def __init__(self, model_path=''): + super().__init__() + self.nn = NN2().eval() + self.darkfeat = DarkFeat(model_path).eval() + + def forward(self, data): + """ Run DarkFeat and nearest neighborhood matching + Args: + data: dictionary with minimal keys: ['image0', 'image1'] + """ + pred = {} + + # Extract DarkFeat (keypoints, scores, descriptors) + if 'keypoints0' not in data: + pred0 = self.darkfeat({'image': data['image0']}) + # print({k+'0': v[0].shape for k, v in pred0.items()}) + pred = {**pred, **{k+'0': [v] for k, v in pred0.items()}} + if 'keypoints1' not in data: + pred1 = self.darkfeat({'image': data['image1']}) + pred = {**pred, **{k+'1': [v] for k, v in pred1.items()}} + + + # Batch all features + # We should either have i) 
one image per batch, or + # ii) the same number of local features for all images in the batch. + data = {**data, **pred} + + for k in data: + if isinstance(data[k], (list, tuple)): + data[k] = torch.stack(data[k]) + + # Perform the matching + pred = {**pred, **self.nn(data)} + + return pred diff --git a/imcui/third_party/DeDoDe/DeDoDe/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9716b62f0672cfc604ca95280d8aa51a04944d4f --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/__init__.py @@ -0,0 +1,2 @@ +from .model_zoo import dedode_detector_B, dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G +DEBUG_MODE = False diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06d86ba8d4e509dae88e7f5297407a542d9a8774 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py @@ -0,0 +1,4 @@ +from .num_inliers import NumInliersBenchmark +from .mega_pose_est import MegaDepthPoseEstimationBenchmark +from .mega_pose_est_mnn import MegaDepthPoseMNNBenchmark +from .nll_benchmark import MegadepthNLLBenchmark \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py new file mode 100644 index 0000000000000000000000000000000000000000..2104284b54d5fe339d6f12d9ae14dcdd3c0fb564 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py @@ -0,0 +1,114 @@ +import numpy as np +import torch +from DeDoDe.utils import * +from PIL import Image +from tqdm import tqdm +import torch.nn.functional as F + +class MegaDepthPoseEstimationBenchmark: + def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + if scene_names is None: + self.scene_names = [ + "0015_0.1_0.3.npz", + "0015_0.3_0.5.npz", + "0022_0.1_0.3.npz", + "0022_0.3_0.5.npz", + "0022_0.5_0.7.npz", + ] + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + + def benchmark(self, keypoint_model, matching_model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True): + H,W = matching_model.get_output_resolution() + with torch.no_grad(): + data_root = self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + thresholds = [5, 10, 20] + for scene_ind in range(len(self.scenes)): + import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] + scene = self.scenes[scene_ind] + pairs = scene["pair_infos"] + intrinsics = scene["intrinsics"] + poses = scene["poses"] + im_paths = scene["image_paths"] + pair_inds = range(len(pairs)) + for pairind in tqdm(pair_inds): + idx1, idx2 = pairs[pairind][0] + K1 = intrinsics[idx1].copy() + T1 = poses[idx1].copy() + R1, t1 = T1[:3, :3], T1[:3, 3] + K2 = intrinsics[idx2].copy() + T2 = poses[idx2].copy() + R2, t2 = T2[:3, :3], T2[:3, 3] + R, t = compute_relative_pose(R1, t1, R2, t2) + T1_to_2 = np.concatenate((R,t[:,None]), axis=-1) + im_A_path = f"{data_root}/{im_paths[idx1]}" + im_B_path = f"{data_root}/{im_paths[idx2]}" + + keypoints_A = keypoint_model.detect_from_path(im_A_path, num_keypoints = 20_000)["keypoints"][0] + keypoints_B = keypoint_model.detect_from_path(im_B_path, num_keypoints = 20_000)["keypoints"][0] + warp, certainty = matching_model.match(im_A_path, im_B_path) + matches = 
matching_model.match_keypoints(keypoints_A, keypoints_B, warp, certainty, return_tuple = False) + im_A = Image.open(im_A_path) + w1, h1 = im_A.size + im_B = Image.open(im_B_path) + w2, h2 = im_B.size + if scale_intrinsics: + scale1 = 1200 / max(w1, h1) + scale2 = 1200 / max(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1, K2 = K1.copy(), K2.copy() + K1[:2] = K1[:2] * scale1 + K2[:2] = K2[:2] * scale2 + kpts1, kpts2 = matching_model.to_pixel_coordinates(matches, h1, w1, h2, w2) + for _ in range(1): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + threshold = 0.5 + if calibrated: + norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1.cpu().numpy(), + kpts2.cpu().numpy(), + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + print(f"{model_name} auc: {auc}") + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py new file mode 100644 index 0000000000000000000000000000000000000000..d717a09701889fdae42eb7aba7050025ad7c6c52 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py @@ -0,0 +1,119 @@ +import numpy as np +import torch +from DeDoDe.utils import * +from PIL import Image +from tqdm import tqdm +import torch.nn.functional as F + +class MegaDepthPoseMNNBenchmark: + def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + if scene_names is None: + self.scene_names = [ + "0015_0.1_0.3.npz", + "0015_0.3_0.5.npz", + "0022_0.1_0.3.npz", + "0022_0.3_0.5.npz", + "0022_0.5_0.7.npz", + ] + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + + def benchmark(self, detector_model, descriptor_model, matcher_model, model_name = None, resolution = None, scale_intrinsics = False, calibrated = True): + with torch.no_grad(): + data_root = self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + thresholds = [5, 10, 20] + for scene_ind in range(len(self.scenes)): + import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] + scene = self.scenes[scene_ind] + pairs = scene["pair_infos"] + intrinsics = scene["intrinsics"] + poses = scene["poses"] + im_paths = scene["image_paths"] + pair_inds = range(len(pairs)) + for pairind in tqdm(pair_inds): + idx1, idx2 = pairs[pairind][0] + K1 = intrinsics[idx1].copy() + T1 = poses[idx1].copy() + R1, t1 = T1[:3, :3], T1[:3, 3] + K2 = intrinsics[idx2].copy() + T2 = poses[idx2].copy() + R2, t2 = T2[:3, :3], T2[:3, 3] + R, t = compute_relative_pose(R1, 
t1, R2, t2) + T1_to_2 = np.concatenate((R,t[:,None]), axis=-1) + im_A_path = f"{data_root}/{im_paths[idx1]}" + im_B_path = f"{data_root}/{im_paths[idx2]}" + detections_A = detector_model.detect_from_path(im_A_path) + keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"] + detections_B = detector_model.detect_from_path(im_B_path) + keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"] + description_A = descriptor_model.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"] + description_B = descriptor_model.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"] + matches_A, matches_B, batch_ids = matcher_model.match(keypoints_A, description_A, + keypoints_B, description_B, + P_A = P_A, P_B = P_B, + normalize = True, inv_temp=20, threshold = 0.01) + + im_A = Image.open(im_A_path) + w1, h1 = im_A.size + im_B = Image.open(im_B_path) + w2, h2 = im_B.size + if scale_intrinsics: + scale1 = 840 / max(w1, h1) + scale2 = 840 / max(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1, K2 = K1.copy(), K2.copy() + K1[:2] = K1[:2] * scale1 + K2[:2] = K2[:2] * scale2 + kpts1, kpts2 = matcher_model.to_pixel_coords(matches_A, matches_B, h1, w1, h2, w2) + for _ in range(1): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + threshold = 0.5 + if calibrated: + norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1.cpu().numpy(), + kpts2.cpu().numpy(), + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + print(f"{model_name} auc: {auc}") + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/nll_benchmark.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/nll_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..d64103708919594bf8d297d92a908afb79f48002 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/nll_benchmark.py @@ -0,0 +1,57 @@ +import torch +import torch.nn as nn +from DeDoDe.utils import * +import DeDoDe + +class MegadepthNLLBenchmark(nn.Module): + + def __init__(self, dataset, num_samples = 1000, batch_size = 8, device = "cuda") -> None: + super().__init__() + sampler = torch.utils.data.WeightedRandomSampler( + torch.ones(len(dataset)), replacement=False, num_samples=num_samples + ) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler + ) + self.dataloader = dataloader + self.tracked_metrics = {} + self.batch_size = batch_size + self.N = len(dataloader) + + def compute_batch_metrics(self, detector, descriptor, batch, device = "cuda"): + kpts = detector.detect(batch)["keypoints"] + 
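The NLL benchmark that follows pairs keypoints across the two views by depth-warping one set into the other image and keeping only mutual nearest neighbours within a small distance threshold before scoring the descriptors. The sketch below illustrates that pairing criterion only; it is a simplified, one-directional variant with a hypothetical helper (`mutual_nn_pairs`) and random stand-in tensors, not the repository's own symmetric `D_A` / `D_B` computation.

```python
# Minimal sketch of mutual nearest-neighbour pairing (illustrative, not repo code).
# kpts_A_to_B: (B, N, 2) keypoints of image A warped into B's normalized coords.
# kpts_B:      (B, M, 2) keypoints detected in image B, same coordinate frame.
import torch

def mutual_nn_pairs(kpts_A_to_B, kpts_B, thresh=0.01):
    D = torch.cdist(kpts_A_to_B, kpts_B)                # (B, N, M) pairwise distances
    nn_for_A = D.min(dim=-1, keepdim=True).values       # closest B point per warped A point
    nn_for_B = D.min(dim=-2, keepdim=True).values       # closest warped A point per B point
    mutual = (D == nn_for_A) & (D == nn_for_B) & (D < thresh)
    return torch.nonzero(mutual)                         # rows of (batch, idx_A, idx_B)

# Hypothetical usage with random normalized coordinates in [-1, 1].
pairs = mutual_nn_pairs(torch.rand(2, 100, 2) * 2 - 1, torch.rand(2, 120, 2) * 2 - 1)
```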
descriptions_A, descriptions_B = descriptor.describe_keypoints(batch, kpts)["descriptions"].chunk(2) + kpts_A, kpts_B = kpts.chunk(2) + mask_A_to_B, kpts_A_to_B = warp_kpts(kpts_A, + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"],) + mask_B_to_A, kpts_B_to_A = warp_kpts(kpts_B, + batch["im_B_depth"], + batch["im_A_depth"], + batch["T_1to2"].inverse(), + batch["K2"], + batch["K1"],) + with torch.no_grad(): + D_B = torch.cdist(kpts_A_to_B, kpts_B) + D_A = torch.cdist(kpts_A, kpts_B_to_A) + inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values) + * (D_A == D_A.min(dim=-2, keepdim = True).values) + * (D_B < 0.01) + * (D_A < 0.01)) + logP_A_B = dual_log_softmax_matcher(descriptions_A, descriptions_B, + normalize = True, + inv_temperature = 20) + neg_log_likelihood = -logP_A_B[inds[:,0], inds[:,1], inds[:,2]].mean() + self.tracked_metrics["neg_log_likelihood"] = self.tracked_metrics.get("neg_log_likelihood", 0) + 1/self.N * neg_log_likelihood + + def benchmark(self, detector, descriptor): + self.tracked_metrics = {} + from tqdm import tqdm + print("Evaluating percent inliers...") + for idx, batch in tqdm(enumerate(self.dataloader), mininterval = 10.): + batch = to_cuda(batch) + self.compute_batch_metrics(detector, descriptor, batch) + [print(name, metric.item() * self.N / (idx+1)) for name, metric in self.tracked_metrics.items()] \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3a5869d8ff15ff4d0b300da8259a99e38c5cf2 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +from DeDoDe.utils import * +import DeDoDe + +class NumInliersBenchmark(nn.Module): + + def __init__(self, dataset, num_samples = 1000, batch_size = 8, num_keypoints = 10_000, device = get_best_device()) -> None: + super().__init__() + sampler = torch.utils.data.WeightedRandomSampler( + torch.ones(len(dataset)), replacement=False, num_samples=num_samples + ) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler + ) + self.dataloader = dataloader + self.tracked_metrics = {} + self.batch_size = batch_size + self.N = len(dataloader) + self.num_keypoints = num_keypoints + + def compute_batch_metrics(self, outputs, batch, device = get_best_device()): + kpts_A, kpts_B = outputs["keypoints_A"], outputs["keypoints_B"] + B, K, H, W = batch["im_A"].shape + gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp( + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + H=H, + W=W, + ) + kpts_A_to_B = F.grid_sample(gt_warp_A_to_B[...,2:].float().permute(0,3,1,2), kpts_A[...,None,:], + align_corners=False, mode = 'bilinear')[...,0].mT + legit_A_to_B = F.grid_sample(valid_mask_A_to_B.reshape(B,1,H,W), kpts_A[...,None,:], + align_corners=False, mode = 'bilinear')[...,0,:,0] + dists = (torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0.]).float() + if legit_A_to_B.sum() == 0: + return + percent_inliers_at_1 = (dists < 0.02).float().mean() + percent_inliers_at_05 = (dists < 0.01).float().mean() + percent_inliers_at_025 = (dists < 0.005).float().mean() + percent_inliers_at_01 = (dists < 0.002).float().mean() + percent_inliers_at_005 = (dists < 0.001).float().mean() + + inlier_bins = torch.linspace(0, 0.002, steps = 100, 
device = device)[None] + inlier_counts = (dists[...,None] < inlier_bins).float().mean(dim=0) + self.tracked_metrics["inlier_counts"] = self.tracked_metrics.get("inlier_counts", 0) + 1/self.N * inlier_counts + self.tracked_metrics["percent_inliers_at_1"] = self.tracked_metrics.get("percent_inliers_at_1", 0) + 1/self.N * percent_inliers_at_1 + self.tracked_metrics["percent_inliers_at_05"] = self.tracked_metrics.get("percent_inliers_at_05", 0) + 1/self.N * percent_inliers_at_05 + self.tracked_metrics["percent_inliers_at_025"] = self.tracked_metrics.get("percent_inliers_at_025", 0) + 1/self.N * percent_inliers_at_025 + self.tracked_metrics["percent_inliers_at_01"] = self.tracked_metrics.get("percent_inliers_at_01", 0) + 1/self.N * percent_inliers_at_01 + self.tracked_metrics["percent_inliers_at_005"] = self.tracked_metrics.get("percent_inliers_at_005", 0) + 1/self.N * percent_inliers_at_005 + + def benchmark(self, detector): + self.tracked_metrics = {} + from tqdm import tqdm + print("Evaluating percent inliers...") + for idx, batch in tqdm(enumerate(self.dataloader), mininterval = 10.): + batch = to_best_device(batch) + outputs = detector.detect(batch, num_keypoints = self.num_keypoints) + keypoints_A, keypoints_B = outputs["keypoints"][:self.batch_size], outputs["keypoints"][self.batch_size:] + if isinstance(outputs["keypoints"], (tuple, list)): + keypoints_A, keypoints_B = torch.stack(keypoints_A), torch.stack(keypoints_B) + outputs = {"keypoints_A": keypoints_A, "keypoints_B": keypoints_B} + self.compute_batch_metrics(outputs, batch) + import matplotlib.pyplot as plt + plt.plot(torch.linspace(0, 0.002, steps = 100), self.tracked_metrics["inlier_counts"].cpu()) + import numpy as np + x = np.linspace(0,0.002, 100) + sigma = 0.52 * 2 / 512 + F = 1 - np.exp(-x**2 / (2*sigma**2)) + plt.plot(x, F) + plt.savefig("vis/inlier_counts") + [print(name, metric.item() * self.N / (idx+1)) for name, metric in self.tracked_metrics.items() if "percent" in name] \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/checkpoint.py b/imcui/third_party/DeDoDe/DeDoDe/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..07d6f80ae09acf5702475504a8e8d61f40c21cd3 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/checkpoint.py @@ -0,0 +1,59 @@ +import os +import torch +from torch.nn.parallel.data_parallel import DataParallel +from torch.nn.parallel.distributed import DistributedDataParallel +import gc + +import DeDoDe + +class CheckPoint: + def __init__(self, dir=None, name="tmp"): + self.name = name + self.dir = dir + os.makedirs(self.dir, exist_ok=True) + + def save( + self, + model, + optimizer, + lr_scheduler, + n, + ): + if DeDoDe.RANK == 0: + assert model is not None + if isinstance(model, (DataParallel, DistributedDataParallel)): + model = model.module + states = { + "model": model.state_dict(), + "n": n, + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + } + torch.save(states, self.dir + self.name + f"_latest.pth") + print(f"Saved states {list(states.keys())}, at step {n}") + + def load( + self, + model, + optimizer, + lr_scheduler, + n, + ): + if os.path.exists(self.dir + self.name + f"_latest.pth") and DeDoDe.RANK == 0: + states = torch.load(self.dir + self.name + f"_latest.pth") + if "model" in states: + model.load_state_dict(states["model"]) + if "n" in states: + n = states["n"] if states["n"] else n + if "optimizer" in states: + try: + optimizer.load_state_dict(states["optimizer"]) + except Exception as e: + 
print(f"Failed to load states for optimizer, with error {e}") + if "lr_scheduler" in states: + lr_scheduler.load_state_dict(states["lr_scheduler"]) + print(f"Loaded states {list(states.keys())}, at step {n}") + del states + gc.collect() + torch.cuda.empty_cache() + return model, optimizer, lr_scheduler, n \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/datasets/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DeDoDe/DeDoDe/datasets/megadepth.py b/imcui/third_party/DeDoDe/DeDoDe/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..7de9d9a8e270fb74a6591944878c0e5e70ddf650 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/datasets/megadepth.py @@ -0,0 +1,269 @@ +import os +from PIL import Image +import h5py +import numpy as np +import torch +import torchvision.transforms.functional as tvf +from tqdm import tqdm + +from DeDoDe.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops +import DeDoDe +from DeDoDe.utils import * + +class MegadepthScene: + def __init__( + self, + data_root, + scene_info, + ht=512, + wt=512, + min_overlap=0.0, + max_overlap=1.0, + shake_t=0, + scene_info_detections=None, + scene_info_detections3D=None, + normalize=True, + max_num_pairs = 100_000, + scene_name = None, + use_horizontal_flip_aug = False, + grayscale = False, + clahe = False, + ) -> None: + self.data_root = data_root + self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}" + self.image_paths = scene_info["image_paths"] + self.depth_paths = scene_info["depth_paths"] + self.intrinsics = scene_info["intrinsics"] + self.poses = scene_info["poses"] + self.pairs = scene_info["pairs"] + self.overlaps = scene_info["overlaps"] + threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap) + self.pairs = self.pairs[threshold] + self.overlaps = self.overlaps[threshold] + self.detections = scene_info_detections + self.tracks3D = scene_info_detections3D + if len(self.pairs) > max_num_pairs: + pairinds = np.random.choice( + np.arange(0, len(self.pairs)), max_num_pairs, replace=False + ) + self.pairs = self.pairs[pairinds] + self.overlaps = self.overlaps[pairinds] + self.im_transform_ops = get_tuple_transform_ops( + resize=(ht, wt), normalize=normalize, clahe = clahe, + ) + self.depth_transform_ops = get_depth_tuple_transform_ops( + resize=(ht, wt), normalize=False + ) + self.wt, self.ht = wt, ht + self.shake_t = shake_t + self.use_horizontal_flip_aug = use_horizontal_flip_aug + self.grayscale = grayscale + + def load_im(self, im_B, crop=None): + im = Image.open(im_B) + return im + + def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): + im_A = im_A.flip(-1) + im_B = im_B.flip(-1) + depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) + flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device) + K_A = flip_mat@K_A + K_B = flip_mat@K_B + + return im_A, im_B, depth_A, depth_B, K_A, K_B + + def load_depth(self, depth_ref, crop=None): + depth = np.array(h5py.File(depth_ref, "r")["depth"]) + return torch.from_numpy(depth) + + def __len__(self): + return len(self.pairs) + + def scale_intrinsic(self, K, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]]) + return sK @ K + + def scale_detections(self, detections, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + return 
detections * torch.tensor([[sx,sy]]) + + def rand_shake(self, *things): + t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=(2)) + return [ + tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0]) + for thing in things + ], t + + def tracks_to_detections(self, tracks3D, pose, intrinsics, H, W): + tracks3D = tracks3D.double() + intrinsics = intrinsics.double() + bearing_vectors = pose[...,:3,:3] @ tracks3D.mT + pose[...,:3,3:] + hom_pixel_coords = (intrinsics @ bearing_vectors).mT + pixel_coords = hom_pixel_coords[...,:2] / (hom_pixel_coords[...,2:]+1e-12) + legit_detections = (pixel_coords > 0).prod(dim = -1) * (pixel_coords[...,0] < W - 1) * (pixel_coords[...,1] < H - 1) * (tracks3D != 0).prod(dim=-1) + return pixel_coords.float(), legit_detections.bool() + + def __getitem__(self, pair_idx): + try: + # read intrinsics of original size + idx1, idx2 = self.pairs[pair_idx] + K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3) + K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T1 = self.poses[idx1] + T2 = self.poses[idx2] + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[ + :4, :4 + ] # (4, 4) + + # Load positive pair data + im_A, im_B = self.image_paths[idx1], self.image_paths[idx2] + depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2] + im_A_ref = os.path.join(self.data_root, im_A) + im_B_ref = os.path.join(self.data_root, im_B) + depth_A_ref = os.path.join(self.data_root, depth1) + depth_B_ref = os.path.join(self.data_root, depth2) + # return torch.randn((1000,1000)) + im_A = self.load_im(im_A_ref) + im_B = self.load_im(im_B_ref) + depth_A = self.load_depth(depth_A_ref) + depth_B = self.load_depth(depth_B_ref) + + # Recompute camera intrinsic matrix due to the resize + W_A, H_A = im_A.width, im_A.height + W_B, H_B = im_B.width, im_B.height + + detections2D_A = self.detections[idx1] + detections2D_B = self.detections[idx2] + + K = 10000 + tracks3D_A = torch.zeros(K,3) + tracks3D_B = torch.zeros(K,3) + tracks3D_A[:len(detections2D_A)] = torch.tensor(self.tracks3D[detections2D_A[:K,-1].astype(np.int32)]) + tracks3D_B[:len(detections2D_B)] = torch.tensor(self.tracks3D[detections2D_B[:K,-1].astype(np.int32)]) + + #projs_A, _ = self.tracks_to_detections(tracks3D_A, T1, K1, W_A, H_A) + #tracks3D_B = torch.zeros(K,2) + + K1 = self.scale_intrinsic(K1, W_A, H_A) + K2 = self.scale_intrinsic(K2, W_B, H_B) + + # Process images + im_A, im_B = self.im_transform_ops((im_A, im_B)) + depth_A, depth_B = self.depth_transform_ops( + (depth_A[None, None], depth_B[None, None]) + ) + [im_A, depth_A], t_A = self.rand_shake(im_A, depth_A) + [im_B, depth_B], t_B = self.rand_shake(im_B, depth_B) + + detections_A = -torch.ones(K,2) + detections_B = -torch.ones(K,2) + detections_A[:len(self.detections[idx1])] = self.scale_detections(torch.tensor(detections2D_A[:K,:2]), W_A, H_A) + t_A + detections_B[:len(self.detections[idx2])] = self.scale_detections(torch.tensor(detections2D_B[:K,:2]), W_B, H_B) + t_B + + + K1[:2, 2] += t_A + K2[:2, 2] += t_B + + if self.use_horizontal_flip_aug: + if np.random.rand() > 0.5: + im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2) + detections_A[:,0] = W-detections_A + detections_B[:,0] = W-detections_B + + if DeDoDe.DEBUG_MODE: + tensor_to_pil(im_A[0], unnormalize=True).save( + f"vis/im_A.jpg") + tensor_to_pil(im_B[0], unnormalize=True).save( + f"vis/im_B.jpg") + if 
self.grayscale: + im_A = im_A.mean(dim=-3,keepdim=True) + im_B = im_B.mean(dim=-3,keepdim=True) + data_dict = { + "im_A": im_A, + "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0], + "im_B": im_B, + "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0], + "im_A_depth": depth_A[0, 0], + "im_B_depth": depth_B[0, 0], + "pose_A": T1, + "pose_B": T2, + "detections_A": detections_A, + "detections_B": detections_B, + "tracks3D_A": tracks3D_A, + "tracks3D_B": tracks3D_B, + "K1": K1, + "K2": K2, + "T_1to2": T_1to2, + "im_A_path": im_A_ref, + "im_B_path": im_B_ref, + } + except Exception as e: + print(e) + print(f"Failed to load image pair {self.pairs[pair_idx]}") + print("Loading a random pair in scene instead") + rand_ind = np.random.choice(range(len(self))) + return self[rand_ind] + return data_dict + + +class MegadepthBuilder: + def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None: + self.data_root = data_root + self.scene_info_root = os.path.join(data_root, "prep_scene_info") + self.all_scenes = os.listdir(self.scene_info_root) + self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"] + # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those + self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy']) + self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy']) + self.test_scenes_loftr = ["0015.npy", "0022.npy"] + self.loftr_ignore = loftr_ignore + self.imc21_ignore = imc21_ignore + + def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs): + if split == "train": + scene_names = set(self.all_scenes) - set(self.test_scenes) + elif split == "train_loftr": + scene_names = set(self.all_scenes) - set(self.test_scenes_loftr) + elif split == "test": + scene_names = self.test_scenes + elif split == "test_loftr": + scene_names = self.test_scenes_loftr + elif split == "custom": + scene_names = scene_names + else: + raise ValueError(f"Split {split} not available") + scenes = [] + for scene_name in tqdm(scene_names): + if self.loftr_ignore and scene_name in self.loftr_ignore_scenes: + continue + if self.imc21_ignore and scene_name in self.imc21_scenes: + continue + if ".npy" not in scene_name: + continue + scene_info = np.load( + os.path.join(self.scene_info_root, scene_name), allow_pickle=True + ).item() + scene_info_detections = np.load( + os.path.join(self.scene_info_root, "detections", f"detections_{scene_name}"), allow_pickle=True + ).item() + scene_info_detections3D = np.load( + os.path.join(self.scene_info_root, "detections3D", f"detections3D_{scene_name}"), allow_pickle=True + ) + + scenes.append( + MegadepthScene( + self.data_root, scene_info, scene_info_detections = scene_info_detections, scene_info_detections3D = scene_info_detections3D, min_overlap=min_overlap,scene_name = scene_name, **kwargs + ) + ) + return scenes + + def weight_scenes(self, concat_dataset, alpha=0.5): + ns = [] + for d in concat_dataset.datasets: + ns.append(len(d)) + ws = torch.cat([torch.ones(n) / n**alpha for n in ns]) + return ws diff --git a/imcui/third_party/DeDoDe/DeDoDe/decoder.py b/imcui/third_party/DeDoDe/DeDoDe/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1b58fcc588e6ee12c591b5f446829a914bc611 --- 
/dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/decoder.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import torchvision.models as tvm + + +class Decoder(nn.Module): + def __init__(self, layers, *args, super_resolution = False, num_prototypes = 1, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.layers = layers + self.scales = self.layers.keys() + self.super_resolution = super_resolution + self.num_prototypes = num_prototypes + def forward(self, features, context = None, scale = None): + if context is not None: + features = torch.cat((features, context), dim = 1) + stuff = self.layers[scale](features) + logits, context = stuff[:,:self.num_prototypes], stuff[:,self.num_prototypes:] + return logits, context + +class ConvRefiner(nn.Module): + def __init__( + self, + in_dim=6, + hidden_dim=16, + out_dim=2, + dw=True, + kernel_size=5, + hidden_blocks=5, + amp = True, + residual = False, + amp_dtype = torch.float16, + ): + super().__init__() + self.block1 = self.create_block( + in_dim, hidden_dim, dw=False, kernel_size=1, + ) + self.hidden_blocks = nn.Sequential( + *[ + self.create_block( + hidden_dim, + hidden_dim, + dw=dw, + kernel_size=kernel_size, + ) + for hb in range(hidden_blocks) + ] + ) + self.hidden_blocks = self.hidden_blocks + self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0) + self.amp = amp + self.amp_dtype = amp_dtype + self.residual = residual + + def create_block( + self, + in_dim, + out_dim, + dw=True, + kernel_size=5, + bias = True, + norm_type = nn.BatchNorm2d, + ): + num_groups = 1 if not dw else in_dim + if dw: + assert ( + out_dim % in_dim == 0 + ), "outdim must be divisible by indim for depthwise" + conv1 = nn.Conv2d( + in_dim, + out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=num_groups, + bias=bias, + ) + norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim) + relu = nn.ReLU(inplace=True) + conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0) + return nn.Sequential(conv1, norm, relu, conv2) + + def forward(self, feats): + b,c,hs,ws = feats.shape + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + x0 = self.block1(feats) + x = self.hidden_blocks(x0) + if self.residual: + x = (x + x0)/1.4 + x = self.out_conv(x) + return x diff --git a/imcui/third_party/DeDoDe/DeDoDe/descriptors/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/descriptors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py b/imcui/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py new file mode 100644 index 0000000000000000000000000000000000000000..47629729f36b96aef4604e05bb99bd59b6ee070c --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py @@ -0,0 +1,50 @@ +import torch +from PIL import Image +import torch.nn as nn +import torchvision.models as tvm +import torch.nn.functional as F +import numpy as np +from DeDoDe.utils import get_best_device + +class DeDoDeDescriptor(nn.Module): + def __init__(self, encoder, decoder, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.encoder = encoder + self.decoder = decoder + import torchvision.transforms as transforms + self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def forward( + self, + batch, + ): + if "im_A" in batch: + images = torch.cat((batch["im_A"], batch["im_B"])) + else: + images = 
batch["image"] + features, sizes = self.encoder(images) + descriptor = 0 + context = None + scales = self.decoder.scales + for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)): + delta_descriptor, context = self.decoder(feature_map, scale = scale, context = context) + descriptor = descriptor + delta_descriptor + if idx < len(scales) - 1: + size = sizes[-(idx+2)] + descriptor = F.interpolate(descriptor, size = size, mode = "bilinear", align_corners = False) + context = F.interpolate(context, size = size, mode = "bilinear", align_corners = False) + return {"description_grid" : descriptor} + + @torch.inference_mode() + def describe_keypoints(self, batch, keypoints): + self.train(False) + description_grid = self.forward(batch)["description_grid"] + described_keypoints = F.grid_sample(description_grid.float(), keypoints[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT + return {"descriptions": described_keypoints} + + def read_image(self, im_path, H = 784, W = 784, device=get_best_device()): + return self.normalizer(torch.from_numpy(np.array(Image.open(im_path).resize((W,H)))/255.).permute(2,0,1)).float().to(device)[None] + + def describe_keypoints_from_path(self, im_path, keypoints, H = 784, W = 784): + batch = {"image": self.read_image(im_path, H = H, W = W)} + return self.describe_keypoints(batch, keypoints) \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py b/imcui/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7ece7fed2db02ea8ea51b4b5f49391cdcaef0903 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +import math +import torch.nn.functional as F + +from DeDoDe.utils import * +import DeDoDe + +class DescriptorLoss(nn.Module): + + def __init__(self, detector, num_keypoints = 5000, normalize_descriptions = False, inv_temp = 1, device = get_best_device()) -> None: + super().__init__() + self.detector = detector + self.tracked_metrics = {} + self.num_keypoints = num_keypoints + self.normalize_descriptions = normalize_descriptions + self.inv_temp = inv_temp + + def warp_from_depth(self, batch, kpts_A, kpts_B): + mask_A_to_B, kpts_A_to_B = warp_kpts(kpts_A, + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"],) + mask_B_to_A, kpts_B_to_A = warp_kpts(kpts_B, + batch["im_B_depth"], + batch["im_A_depth"], + batch["T_1to2"].inverse(), + batch["K2"], + batch["K1"],) + return (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A) + + def warp_from_homog(self, batch, kpts_A, kpts_B): + kpts_A_to_B = homog_transform(batch["Homog_A_to_B"], kpts_A) + kpts_B_to_A = homog_transform(batch["Homog_A_to_B"].inverse(), kpts_B) + return (None, kpts_A_to_B), (None, kpts_B_to_A) + + def supervised_loss(self, outputs, batch): + kpts_A, kpts_B = self.detector.detect(batch, num_keypoints = self.num_keypoints)['keypoints'].clone().chunk(2) + desc_grid_A, desc_grid_B = outputs["description_grid"].chunk(2) + desc_A = F.grid_sample(desc_grid_A.float(), kpts_A[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT + desc_B = F.grid_sample(desc_grid_B.float(), kpts_B[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT + if "im_A_depth" in batch: + (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A) = self.warp_from_depth(batch, kpts_A, kpts_B) + elif "Homog_A_to_B" in batch: + (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, 
kpts_B_to_A) = self.warp_from_homog(batch, kpts_A, kpts_B) + + with torch.no_grad(): + D_B = torch.cdist(kpts_A_to_B, kpts_B) + D_A = torch.cdist(kpts_A, kpts_B_to_A) + inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values) + * (D_A == D_A.min(dim=-2, keepdim = True).values) + * (D_B < 0.01) + * (D_A < 0.01)) + + logP_A_B = dual_log_softmax_matcher(desc_A, desc_B, + normalize = self.normalize_descriptions, + inv_temperature = self.inv_temp) + neg_log_likelihood = -logP_A_B[inds[:,0], inds[:,1], inds[:,2]].mean() + self.tracked_metrics["neg_log_likelihood"] = (0.99 * self.tracked_metrics.get("neg_log_likelihood", neg_log_likelihood.detach().item()) + 0.01 * neg_log_likelihood.detach().item()) + if np.random.rand() > 0.99: + print(self.tracked_metrics["neg_log_likelihood"]) + return neg_log_likelihood + + def forward(self, outputs, batch): + losses = self.supervised_loss(outputs, batch) + return losses \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/detectors/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/detectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py b/imcui/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..4a02f621a4a93a30df94c2fe5f6fd0297ce53f95 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py @@ -0,0 +1,76 @@ +import torch +from PIL import Image +import torch.nn as nn +import torchvision.models as tvm +import torch.nn.functional as F +import numpy as np + +from DeDoDe.utils import sample_keypoints, to_pixel_coords, to_normalized_coords, get_best_device + + + +class DeDoDeDetector(nn.Module): + def __init__(self, encoder, decoder, *args, remove_borders = False, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.encoder = encoder + self.decoder = decoder + import torchvision.transforms as transforms + self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.remove_borders = remove_borders + + def forward( + self, + batch, + ): + if "im_A" in batch: + images = torch.cat((batch["im_A"], batch["im_B"])) + else: + images = batch["image"] + features, sizes = self.encoder(images) + logits = 0 + context = None + scales = ["8", "4", "2", "1"] + for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)): + delta_logits, context = self.decoder(feature_map, context = context, scale = scale) + logits = logits + delta_logits.float() # ensure float (need bf16 doesnt have f.interpolate) + if idx < len(scales) - 1: + size = sizes[-(idx+2)] + logits = F.interpolate(logits, size = size, mode = "bicubic", align_corners = False) + context = F.interpolate(context.float(), size = size, mode = "bilinear", align_corners = False) + return {"keypoint_logits" : logits.float()} + + @torch.inference_mode() + def detect(self, batch, num_keypoints = 10_000): + self.train(False) + keypoint_logits = self.forward(batch)["keypoint_logits"] + B,K,H,W = keypoint_logits.shape + keypoint_p = keypoint_logits.reshape(B, K*H*W).softmax(dim=-1).reshape(B, K, H*W).sum(dim=1) + keypoints, confidence = sample_keypoints(keypoint_p.reshape(B,H,W), + use_nms = False, sample_topk = True, num_samples = num_keypoints, + return_scoremap=True, sharpen = False, upsample = False, + increase_coverage=True, remove_borders = self.remove_borders) + return {"keypoints": keypoints, 
"confidence": confidence} + + @torch.inference_mode() + def detect_dense(self, batch): + self.train(False) + keypoint_logits = self.forward(batch)["keypoint_logits"] + return {"dense_keypoint_logits": keypoint_logits} + + def read_image(self, im_path, H = 784, W = 784, device=get_best_device()): + pil_im = Image.open(im_path).resize((W, H)) + standard_im = np.array(pil_im)/255. + return self.normalizer(torch.from_numpy(standard_im).permute(2,0,1)).float().to(device)[None] + + def detect_from_path(self, im_path, num_keypoints = 30_000, H = 784, W = 784, dense = False): + batch = {"image": self.read_image(im_path, H = H, W = W)} + if dense: + return self.detect_dense(batch) + else: + return self.detect(batch, num_keypoints = num_keypoints) + + def to_pixel_coords(self, x, H, W): + return to_pixel_coords(x, H, W) + + def to_normalized_coords(self, x, H, W): + return to_normalized_coords(x, H, W) \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/detectors/keypoint_loss.py b/imcui/third_party/DeDoDe/DeDoDe/detectors/keypoint_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d8dfbd7747aedad25101dd4b59e9cf950bfe4880 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/detectors/keypoint_loss.py @@ -0,0 +1,185 @@ +import torch +import torch.nn as nn +import math + +from DeDoDe.utils import * +import DeDoDe + +class KeyPointLoss(nn.Module): + + def __init__(self, smoothing_size = 1, use_max_logit = False, entropy_target = 80, + num_matches = 1024, jacobian_density_adjustment = False, + matchability_weight = 1, device = "cuda") -> None: + super().__init__() + X = torch.linspace(-1,1,smoothing_size, device = device) + G = (-X**2 / (2 *1/2**2)).exp() + G = G/G.sum() + self.use_max_logit = use_max_logit + self.entropy_target = entropy_target + self.smoothing_kernel = G[None, None, None,:] + self.smoothing_size = smoothing_size + self.tracked_metrics = {} + self.center = None + self.num_matches = num_matches + self.jacobian_density_adjustment = jacobian_density_adjustment + self.matchability_weight = matchability_weight + + def compute_consistency(self, logits_A, logits_B_to_A, mask = None): + + masked_logits_A = torch.full_like(logits_A, -torch.inf) + masked_logits_A[mask] = logits_A[mask] + + masked_logits_B_to_A = torch.full_like(logits_B_to_A, -torch.inf) + masked_logits_B_to_A[mask] = logits_B_to_A[mask] + + log_p_A = masked_logits_A.log_softmax(dim=-1)[mask] + log_p_B_to_A = masked_logits_B_to_A.log_softmax(dim=-1)[mask] + + return self.compute_jensen_shannon_div(log_p_A, log_p_B_to_A) + + def compute_joint_neg_log_likelihood(self, logits_A, logits_B_to_A, detections_A = None, detections_B_to_A = None, mask = None, device = "cuda", dtype = torch.float32, num_matches = None): + B, K, HW = logits_A.shape + logits_A, logits_B_to_A = logits_A.to(dtype), logits_B_to_A.to(dtype) + mask = mask[:,None].expand(B, K, HW).reshape(B, K*HW) + log_p_B_to_A = self.masked_log_softmax(logits_B_to_A.reshape(B,K*HW), mask = mask) + log_p_A = self.masked_log_softmax(logits_A.reshape(B,K*HW), mask = mask) + log_p = log_p_A + log_p_B_to_A + if detections_A is None: + detections_A = torch.zeros_like(log_p_A) + if detections_B_to_A is None: + detections_B_to_A = torch.zeros_like(log_p_B_to_A) + detections_A = detections_A.reshape(B, HW) + detections_A[~mask] = 0 + detections_B_to_A = detections_B_to_A.reshape(B, HW) + detections_B_to_A[~mask] = 0 + log_p_target = log_p.detach() + 50*detections_A + 50*detections_B_to_A + num_matches = self.num_matches if num_matches is None 
else num_matches + best_k = -(-log_p_target).flatten().kthvalue(k = B * num_matches, dim=-1).values + p_target = (log_p_target > best_k[..., None]).float().reshape(B,K*HW)/num_matches + return self.compute_cross_entropy(log_p_A[mask], p_target[mask]) + self.compute_cross_entropy(log_p_B_to_A[mask], p_target[mask]) + + def compute_jensen_shannon_div(self, log_p, log_q): + return 1/2 * (self.compute_kl_div(log_p, log_q) + self.compute_kl_div(log_q, log_p)) + + def compute_kl_div(self, log_p, log_q): + return (log_p.exp()*(log_p-log_q)).sum(dim=-1) + + def masked_log_softmax(self, logits, mask): + masked_logits = torch.full_like(logits, -torch.inf) + masked_logits[mask] = logits[mask] + log_p = masked_logits.log_softmax(dim=-1) + return log_p + + def masked_softmax(self, logits, mask): + masked_logits = torch.full_like(logits, -torch.inf) + masked_logits[mask] = logits[mask] + log_p = masked_logits.softmax(dim=-1) + return log_p + + def compute_detection_img(self, detections, mask, B, H, W, device = "cuda"): + kernel_size = 5 + X = torch.linspace(-2,2,kernel_size, device = device) + G = (-X**2 / (2 * (1/2)**2)).exp() # half pixel std + G = G/G.sum() + det_smoothing_kernel = G[None, None, None,:] + det_img = torch.zeros((B,1,H,W), device = device) + for b in range(B): + valid_detections = (detections[b][mask[b]]).int() + det_img[b,0][valid_detections[:,1], valid_detections[:,0]] = 1 + det_img = F.conv2d(det_img, weight = det_smoothing_kernel, padding = (kernel_size//2, 0)) + det_img = F.conv2d(det_img, weight = det_smoothing_kernel.mT, padding = (0, kernel_size//2)) + return det_img + + def compute_cross_entropy(self, log_p_hat, p): + return -(log_p_hat * p).sum(dim=-1) + + def compute_matchability(self, keypoint_p, has_depth, B, K, H, W, device = "cuda"): + smooth_keypoint_p = F.conv2d(keypoint_p.reshape(B,1,H,W), weight = self.smoothing_kernel, padding = (self.smoothing_size//2,0)) + smooth_keypoint_p = F.conv2d(smooth_keypoint_p, weight = self.smoothing_kernel.mT, padding = (0,self.smoothing_size//2)) + log_p_hat = (smooth_keypoint_p+1e-8).log().reshape(B,H*W).log_softmax(dim=-1) + smooth_has_depth = F.conv2d(has_depth.reshape(B,1,H,W), weight = self.smoothing_kernel, padding = (0,self.smoothing_size//2)) + smooth_has_depth = F.conv2d(smooth_has_depth, weight = self.smoothing_kernel.mT, padding = (self.smoothing_size//2,0)).reshape(B,H*W) + p = smooth_has_depth/smooth_has_depth.sum(dim=-1,keepdim=True) + return self.compute_cross_entropy(log_p_hat, p) - self.compute_cross_entropy((p+1e-12).log(), p) + + def supervised_loss(self, outputs, batch): + keypoint_logits_A, keypoint_logits_B = outputs["keypoint_logits"].chunk(2) + B, K, H, W = keypoint_logits_A.shape + + detections_A, detections_B = batch["detections_A"], batch["detections_B"] + + gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp( + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + H=H, + W=W, + ) + gt_warp_B_to_A, valid_mask_B_to_A = get_gt_warp( + batch["im_B_depth"], + batch["im_A_depth"], + batch["T_1to2"].inverse(), + batch["K2"], + batch["K1"], + H=H, + W=W, + ) + keypoint_logits_A = keypoint_logits_A.reshape(B, K, H*W) + keypoint_logits_B = keypoint_logits_B.reshape(B, K, H*W) + keypoint_logits = torch.cat((keypoint_logits_A, keypoint_logits_B)) + + B = 2*B + gt_warp = torch.cat((gt_warp_A_to_B, gt_warp_B_to_A)) + valid_mask = torch.cat((valid_mask_A_to_B, valid_mask_B_to_A)) + valid_mask = valid_mask.reshape(B,H*W) + binary_mask = valid_mask == 1 + detections = 
torch.cat((detections_A, detections_B)) + legit_detections = ((detections > 0).prod(dim = -1) * (detections[...,0] < W) * (detections[...,1] < H)).bool() + det_imgs_A, det_imgs_B = self.compute_detection_img(detections, legit_detections, B, H, W).chunk(2) + det_imgs = torch.cat((det_imgs_A, det_imgs_B)) + det_imgs_backwarped = F.grid_sample(torch.cat((det_imgs_B, det_imgs_A)).reshape(B,1,H,W), + gt_warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic") + + keypoint_logits_backwarped = F.grid_sample(torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B,K,H,W), + gt_warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic") + + keypoint_logits_backwarped = (keypoint_logits_backwarped).reshape(B,K,H*W) + + + depth = F.interpolate(torch.cat((batch["im_A_depth"][:,None],batch["im_B_depth"][:,None]),dim=0), size = (H,W), mode = "bilinear", align_corners=False) + has_depth = (depth > 0).float().reshape(B,H*W) + + joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood(keypoint_logits, keypoint_logits_backwarped, + mask = binary_mask, detections_A = det_imgs, + detections_B_to_A = det_imgs_backwarped).mean() + keypoint_p = keypoint_logits.reshape(B, K*H*W).softmax(dim=-1).reshape(B, K, H*W).sum(dim=1) + matchability_loss = self.compute_matchability(keypoint_p, has_depth, B, K, H, W).mean() + B = B//2 + kpts_A = sample_keypoints(keypoint_p[:B].reshape(B,H,W), + use_nms = False, sample_topk = True, num_samples = 4*2048) + kpts_B = sample_keypoints(keypoint_p[B:].reshape(B,H,W), + use_nms = False, sample_topk = True, num_samples = 4*2048) + kpts_A_to_B = F.grid_sample(gt_warp_A_to_B[...,2:].float().permute(0,3,1,2), kpts_A[...,None,:], + align_corners=False, mode = 'bilinear')[...,0].mT + legit_A_to_B = F.grid_sample(valid_mask_A_to_B.reshape(B,1,H,W), kpts_A[...,None,:], + align_corners=False, mode = 'bilinear')[...,0,:,0] + percent_inliers = (torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0] < 0.01).float().mean() + self.tracked_metrics["mega_percent_inliers"] = (0.9 * self.tracked_metrics.get("mega_percent_inliers", percent_inliers) + 0.1 * percent_inliers) + + tot_loss = joint_log_likelihood_loss + self.matchability_weight * matchability_loss# + if torch.rand(1) > 1: + print(f"Precent Inlier: {self.tracked_metrics.get('mega_percent_inliers', 0)}") + print(f"{joint_log_likelihood_loss=} {matchability_loss=}") + print(f"Total Loss: {tot_loss.item()}") + return tot_loss + + def forward(self, outputs, batch): + + if not isinstance(outputs, list): + outputs = [outputs] + losses = 0 + for output in outputs: + losses = losses + self.supervised_loss(output, batch) + return losses \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/encoder.py b/imcui/third_party/DeDoDe/DeDoDe/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..91880e7d5e98b02259127b107a459401b99bb157 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/encoder.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +import torchvision.models as tvm + + +class VGG19(nn.Module): + def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None: + super().__init__() + self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) + # Maxpool layers: 6, 13, 26, 39 + self.amp = amp + self.amp_dtype = amp_dtype + + def forward(self, x, **kwargs): + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + feats = [] + sizes = [] + for layer in self.layers: + if 
isinstance(layer, nn.MaxPool2d): + feats.append(x) + sizes.append(x.shape[-2:]) + x = layer(x) + return feats, sizes + +class VGG(nn.Module): + def __init__(self, size = "19", pretrained=False, amp = False, amp_dtype = torch.float16) -> None: + super().__init__() + if size == "11": + self.layers = nn.ModuleList(tvm.vgg11_bn(pretrained=pretrained).features[:22]) + elif size == "13": + self.layers = nn.ModuleList(tvm.vgg13_bn(pretrained=pretrained).features[:28]) + elif size == "19": + self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) + # Maxpool layers: 6, 13, 26, 39 + self.amp = amp + self.amp_dtype = amp_dtype + + def forward(self, x, **kwargs): + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + feats = [] + sizes = [] + for layer in self.layers: + if isinstance(layer, nn.MaxPool2d): + feats.append(x) + sizes.append(x.shape[-2:]) + x = layer(x) + return feats, sizes + +class FrozenDINOv2(nn.Module): + def __init__(self, amp = True, amp_dtype = torch.float16, dinov2_weights = None): + super().__init__() + if dinov2_weights is None: + dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu") + from .transformer import vit_large + vit_kwargs = dict(img_size= 518, + patch_size= 14, + init_values = 1.0, + ffn_layer = "mlp", + block_chunks = 0, + ) + dinov2_vitl14 = vit_large(**vit_kwargs).eval() + dinov2_vitl14.load_state_dict(dinov2_weights) + self.amp = amp + self.amp_dtype = amp_dtype + if self.amp: + dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype) + self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP + def forward(self, x): + B, C, H, W = x.shape + if self.dinov2_vitl14[0].device != x.device: + self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype) + with torch.inference_mode(): + dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype)) + features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14) + return [features_16.clone()], [(H//14, W//14)] # clone from inference mode to use in autograd + +class VGG_DINOv2(nn.Module): + def __init__(self, vgg_kwargs = None, dinov2_kwargs = None): + assert vgg_kwargs is not None and dinov2_kwargs is not None, "Input kwargs pls" + super().__init__() + self.vgg = VGG(**vgg_kwargs) + self.frozen_dinov2 = FrozenDINOv2(**dinov2_kwargs) + + def forward(self, x): + feats_vgg, sizes_vgg = self.vgg(x) + feat_dinov2, size_dinov2 = self.frozen_dinov2(x) + return feats_vgg + feat_dinov2, sizes_vgg + size_dinov2 diff --git a/imcui/third_party/DeDoDe/DeDoDe/matchers/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py b/imcui/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc76cad77ee403d7d5ab729c786982a47fbe6e9 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py @@ -0,0 +1,38 @@ +import torch +from PIL import Image +import torch.nn as nn +import torchvision.models as tvm +import torch.nn.functional as F +import numpy as np +from DeDoDe.utils import dual_softmax_matcher, to_pixel_coords, to_normalized_coords + +class DualSoftMaxMatcher(nn.Module): + @torch.inference_mode() + def match(self, 
keypoints_A, descriptions_A, + keypoints_B, descriptions_B, P_A = None, P_B = None, + normalize = False, inv_temp = 1, threshold = 0.0): + if isinstance(descriptions_A, list): + matches = [self.match(k_A[None], d_A[None], k_B[None], d_B[None], normalize = normalize, + inv_temp = inv_temp, threshold = threshold) + for k_A,d_A,k_B,d_B in + zip(keypoints_A, descriptions_A, keypoints_B, descriptions_B)] + matches_A = torch.cat([m[0] for m in matches]) + matches_B = torch.cat([m[1] for m in matches]) + inds = torch.cat([m[2] + b for b, m in enumerate(matches)]) + return matches_A, matches_B, inds + + P = dual_softmax_matcher(descriptions_A, descriptions_B, + normalize = normalize, inv_temperature=inv_temp, + ) + inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values) + * (P == P.max(dim=-2, keepdim = True).values) * (P > threshold)) + batch_inds = inds[:,0] + matches_A = keypoints_A[batch_inds, inds[:,1]] + matches_B = keypoints_B[batch_inds, inds[:,2]] + return matches_A, matches_B, batch_inds + + def to_pixel_coords(self, x_A, x_B, H_A, W_A, H_B, W_B): + return to_pixel_coords(x_A, H_A, W_A), to_pixel_coords(x_B, H_B, W_B) + + def to_normalized_coords(self, x_A, x_B, H_A, W_A, H_B, W_B): + return to_normalized_coords(x_A, H_A, W_A), to_normalized_coords(x_B, H_B, W_B) \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0775d438f94b6095d094e119f788368170694c4c --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py @@ -0,0 +1,3 @@ +from .dedode_models import dedode_detector_B, dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G + + \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py b/imcui/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py new file mode 100644 index 0000000000000000000000000000000000000000..deac312b81691024c2124ebd825f374f9e8c9db1 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py @@ -0,0 +1,249 @@ +import torch +import torch.nn as nn + +from DeDoDe.detectors.dedode_detector import DeDoDeDetector +from DeDoDe.descriptors.dedode_descriptor import DeDoDeDescriptor +from DeDoDe.decoder import ConvRefiner, Decoder +from DeDoDe.encoder import VGG19, VGG, VGG_DINOv2 +from DeDoDe.utils import get_best_device + + +def dedode_detector_B(device = get_best_device(), weights = None): + residual = True + hidden_blocks = 5 + amp_dtype = torch.float16 + amp = True + NUM_PROTOTYPES = 1 + conv_refiner = nn.ModuleDict( + { + "8": ConvRefiner( + 512, + 512, + 256 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + "4": ConvRefiner( + 256+256, + 256, + 128 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "2": ConvRefiner( + 128+128, + 64, + 32 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "1": ConvRefiner( + 64 + 32, + 32, + 1 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + } + ) + encoder = VGG19(pretrained = False, amp = amp, amp_dtype = amp_dtype) + decoder = Decoder(conv_refiner) + model = DeDoDeDetector(encoder = encoder, decoder = decoder).to(device) + if weights is not None: + model.load_state_dict(weights) + return model + + +def 
dedode_detector_L(device = get_best_device(), weights = None, remove_borders = False): + if weights is None: + weights = torch.hub.load_state_dict_from_url("https://github.com/Parskatt/DeDoDe/releases/download/v2/dedode_detector_L_v2.pth", map_location = device) + NUM_PROTOTYPES = 1 + residual = True + hidden_blocks = 8 + amp_dtype = torch.float16#torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + amp = True + conv_refiner = nn.ModuleDict( + { + "8": ConvRefiner( + 512, + 512, + 256 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + "4": ConvRefiner( + 256+256, + 256, + 128 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "2": ConvRefiner( + 128+128, + 128, + 64 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "1": ConvRefiner( + 64 + 64, + 64, + 1 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + } + ) + encoder = VGG19(pretrained = False, amp = amp, amp_dtype = amp_dtype) + decoder = Decoder(conv_refiner) + model = DeDoDeDetector(encoder = encoder, decoder = decoder, remove_borders = remove_borders).to(device) + if weights is not None: + model.load_state_dict(weights) + return model + + + +def dedode_descriptor_B(device = get_best_device(), weights = None): + if weights is None: + weights = torch.hub.load_state_dict_from_url("https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth", map_location=device) + NUM_PROTOTYPES = 256 # == descriptor size + residual = True + hidden_blocks = 5 + amp_dtype = torch.float16#torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + amp = True + conv_refiner = nn.ModuleDict( + { + "8": ConvRefiner( + 512, + 512, + 256 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + "4": ConvRefiner( + 256+256, + 256, + 128 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "2": ConvRefiner( + 128+128, + 64, + 32 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "1": ConvRefiner( + 64 + 32, + 32, + 1 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + } + ) + encoder = VGG(size = "19", pretrained = False, amp = amp, amp_dtype = amp_dtype) + decoder = Decoder(conv_refiner, num_prototypes=NUM_PROTOTYPES) + model = DeDoDeDescriptor(encoder = encoder, decoder = decoder).to(device) + if weights is not None: + model.load_state_dict(weights) + return model + +def dedode_descriptor_G(device = get_best_device(), weights = None, dinov2_weights = None): + if weights is None: + weights = torch.hub.load_state_dict_from_url("https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_G.pth", map_location=device) + NUM_PROTOTYPES = 256 # == descriptor size + residual = True + hidden_blocks = 5 + amp_dtype = torch.float16#torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + amp = True + conv_refiner = nn.ModuleDict( + { + "14": ConvRefiner( + 1024, + 768, + 512 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + "8": ConvRefiner( + 512 + 512, + 
512, + 256 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + "4": ConvRefiner( + 256+256, + 256, + 128 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "2": ConvRefiner( + 128+128, + 64, + 32 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "1": ConvRefiner( + 64 + 32, + 32, + 1 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + } + ) + vgg_kwargs = dict(size = "19", pretrained = False, amp = amp, amp_dtype = amp_dtype) + dinov2_kwargs = dict(amp = amp, amp_dtype = amp_dtype, dinov2_weights = dinov2_weights) + encoder = VGG_DINOv2(vgg_kwargs = vgg_kwargs, dinov2_kwargs = dinov2_kwargs) + decoder = Decoder(conv_refiner, num_prototypes=NUM_PROTOTYPES) + model = DeDoDeDescriptor(encoder = encoder, decoder = decoder).to(device) + if weights is not None: + model.load_state_dict(weights) + return model \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/train.py b/imcui/third_party/DeDoDe/DeDoDe/train.py new file mode 100644 index 0000000000000000000000000000000000000000..342cdd636c8d5ae0b693bf6220ba088bdbc2035c --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/train.py @@ -0,0 +1,76 @@ +import torch +from tqdm import tqdm +from DeDoDe.utils import to_cuda, to_best_device + + +def train_step(train_batch, model, objective, optimizer, grad_scaler = None,**kwargs): + optimizer.zero_grad() + out = model(train_batch) + l = objective(out, train_batch) + if grad_scaler is not None: + grad_scaler.scale(l).backward() + grad_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01) + grad_scaler.step(optimizer) + grad_scaler.update() + else: + l.backward() + optimizer.step() + return {"train_out": out, "train_loss": l.item()} + + +def train_k_steps( + n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler = None, progress_bar=True +): + for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar, mininterval = 10.): + batch = next(dataloader) + model.train(True) + batch = to_best_device(batch) + train_step( + train_batch=batch, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + n=n, + grad_scaler = grad_scaler, + ) + lr_scheduler.step() + + +def train_epoch( + dataloader=None, + model=None, + objective=None, + optimizer=None, + lr_scheduler=None, + epoch=None, +): + model.train(True) + print(f"At epoch {epoch}") + for batch in tqdm(dataloader, mininterval=5.0): + batch = to_best_device(batch) + train_step( + train_batch=batch, model=model, objective=objective, optimizer=optimizer + ) + lr_scheduler.step() + return { + "model": model, + "optimizer": optimizer, + "lr_scheduler": lr_scheduler, + "epoch": epoch, + } + + +def train_k_epochs( + start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler +): + for epoch in range(start_epoch, end_epoch + 1): + train_epoch( + dataloader=dataloader, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + ) diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..031d52e998bc18f6d5264fb8b791a6339cf793b5 --- /dev/null +++ 
b/imcui/third_party/DeDoDe/DeDoDe/transformer/__init__.py @@ -0,0 +1,8 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from DeDoDe.utils import get_grid +from .layers.block import Block +from .layers.attention import MemEffAttention +from .dinov2 import vit_large \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/dinov2.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..b556c63096d17239c8603d5fe626c331963099fd --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/dinov2.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + 
self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + for param in self.parameters(): + param.requires_grad = False + + @property + def device(self): + return self.cls_token.device + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for 
x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, 
attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/attention.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9b0c94b40967dfdff4f261c127cbd21328c905 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/attention.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
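The `vit_*` factories above build `DinoVisionTransformer` backbones whose parameters are frozen at construction time (`requires_grad = False` is set in `__init__`), which is why they can be dropped into the descriptor as a fixed feature extractor. A shape-check sketch, under the assumption of this diff's package layout; `patch_size=14` is chosen here to match the DINOv2-style checkpoints, so a 224x224 input yields 16x16 = 256 patch tokens of width 1024.

```python
import torch

# Assumed import path from this diff's package layout.
from DeDoDe.transformer.dinov2 import vit_large

vit = vit_large(patch_size=14)           # parameters are frozen in __init__
x = torch.randn(1, 3, 224, 224)          # H and W must be multiples of patch_size

with torch.no_grad():
    feats = vit.forward_features(x)

print(feats["x_norm_clstoken"].shape)    # torch.Size([1, 1024])
print(feats["x_norm_patchtokens"].shape) # torch.Size([1, 256, 1024])
```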
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/block.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
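The attention module above keeps the usual `(B, N, C)` token interface: `MemEffAttention` dispatches to xFormers' `memory_efficient_attention` when the library is installed and otherwise falls back to the parent `Attention.forward`, provided no `attn_bias` is passed. A minimal shape-check sketch, assuming this diff's package layout:

```python
import torch

# Assumed import path from this diff's package layout.
from DeDoDe.transformer.layers.attention import MemEffAttention

attn = MemEffAttention(dim=64, num_heads=8, qkv_bias=True)
x = torch.randn(2, 10, 64)   # (batch, tokens, channels)

# Uses xFormers if available; otherwise silently runs plain softmax attention.
y = attn(x)
assert y.shape == x.shape
```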
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = 
torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, 
LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/dino_head.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7212db92a4fd8d4c7230e284e551a0234e9d8623 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/dino_head.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/drop_path.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
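`Block` above is a standard pre-norm transformer block: attention and MLP residual branches, each optionally gated by `LayerScale` and regularized by `DropPath`; during training with `sample_drop_ratio > 0.1` it switches to the batched `drop_add_residual_stochastic_depth` path, which drops whole samples instead of masking. A minimal sketch in eval mode (so neither stochastic-depth path is active), assuming this diff's package layout:

```python
import torch

# Assumed import path from this diff's package layout.
from DeDoDe.transformer.layers.block import Block

# Default attn_class is the plain Attention, so no xFormers dependency here.
blk = Block(dim=128, num_heads=4, mlp_ratio=4.0, init_values=1e-5, drop_path=0.1)
blk.eval()

x = torch.randn(2, 50, 128)   # (batch, tokens, dim)
y = blk(x)                    # residual attention + MLP; shape is preserved
assert y.shape == x.shape
```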
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/layer_scale.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/mlp.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
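The two small layers above behave as follows: `drop_path` zeroes an entire sample's residual with probability `drop_prob` during training and rescales survivors by `1/keep_prob`, so the expected output matches the input; `LayerScale` multiplies by a learnable per-channel gain initialised near zero, so residual branches start almost "off". A tiny numeric sketch, assuming this diff's package layout:

```python
import torch

# Assumed import paths from this diff's package layout.
from DeDoDe.transformer.layers.drop_path import DropPath
from DeDoDe.transformer.layers.layer_scale import LayerScale

x = torch.ones(10_000, 8)

dp = DropPath(drop_prob=0.2)
dp.train()                       # drop_path is a no-op in eval mode
out = dp(x)
# Roughly 20% of rows are all zero, the rest are scaled to 1 / 0.8 = 1.25,
# so the mean stays close to 1.
print(out.mean().item())

ls = LayerScale(dim=8, init_values=1e-5)
print(ls(x).abs().max().item())  # 1e-5: the branch contributes almost nothing at init
```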
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/patch_embed.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/swiglu_ffn.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/imcui/third_party/DeDoDe/DeDoDe/utils.py b/imcui/third_party/DeDoDe/DeDoDe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9475dc8927aa2256fc9d947cc3034dff9420e6c4 --- /dev/null +++ b/imcui/third_party/DeDoDe/DeDoDe/utils.py @@ -0,0 +1,717 @@ +import warnings +import numpy as np +import math +import cv2 +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import torch.nn.functional as F +from PIL import Image +from einops import rearrange +import torch +from time import perf_counter + + +def get_best_device(verbose = False): + device = torch.device('cpu') + if torch.cuda.is_available(): + device = torch.device('cuda') + elif torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') + if verbose: print (f"Fastest device found is: {device}") + return device + + +def recover_pose(E, kpts0, kpts1, K0, K1, mask): + best_num_inliers = 0 + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + + + +# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py +# --- GEOMETRY --- +def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): + if len(kpts0) < 5: + return None + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf + ) + + ret = None + if E is not None: + best_num_inliers = 0 + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > 
best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + + +def get_grid(B,H,W, device = get_best_device()): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=device + ) + for n in (B, H, W) + ] + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) + return x1_n + +@torch.no_grad() +def finite_diff_hessian(f: tuple(["B", "H", "W"]), device = get_best_device()): + dxx = torch.tensor([[0,0,0],[1,-2,1],[0,0,0]], device = device)[None,None]/2 + dxy = torch.tensor([[1,0,-1],[0,0,0],[-1,0,1]], device = device)[None,None]/4 + dyy = dxx.mT + Hxx = F.conv2d(f[:,None], dxx, padding = 1)[:,0] + Hxy = F.conv2d(f[:,None], dxy, padding = 1)[:,0] + Hyy = F.conv2d(f[:,None], dyy, padding = 1)[:,0] + H = torch.stack((Hxx, Hxy, Hxy, Hyy), dim = -1).reshape(*f.shape,2,2) + return H + +def finite_diff_grad(f: tuple(["B", "H", "W"]), device = get_best_device()): + dx = torch.tensor([[0,0,0],[-1,0,1],[0,0,0]],device = device)[None,None]/2 + dy = dx.mT + gx = F.conv2d(f[:,None], dx, padding = 1) + gy = F.conv2d(f[:,None], dy, padding = 1) + g = torch.cat((gx, gy), dim = 1) + return g + +def fast_inv_2x2(matrix: tuple[...,2,2], eps = 1e-10): + return 1/(torch.linalg.det(matrix)[...,None,None]+eps) * torch.stack((matrix[...,1,1],-matrix[...,0,1], + -matrix[...,1,0],matrix[...,0,0]),dim=-1).reshape(*matrix.shape) + +def newton_step(f:tuple["B","H","W"], inds, device = get_best_device()): + B,H,W = f.shape + Hess = finite_diff_hessian(f).reshape(B,H*W,2,2) + Hess = torch.gather(Hess, dim = 1, index = inds[...,None].expand(B,-1,2,2)) + grad = finite_diff_grad(f).reshape(B,H*W,2) + grad = torch.gather(grad, dim = 1, index = inds) + Hessinv = fast_inv_2x2(Hess-torch.eye(2, device = device)[None,None]) + step = (Hessinv @ grad[...,None]) + return step[...,0] + +@torch.no_grad() +def sample_keypoints(scoremap, num_samples = 8192, device = get_best_device(), use_nms = True, + sample_topk = False, return_scoremap = False, sharpen = False, upsample = False, + increase_coverage = False, remove_borders = False): + #scoremap = scoremap**2 + log_scoremap = (scoremap+1e-10).log() + if upsample: + log_scoremap = F.interpolate(log_scoremap[:,None], scale_factor = 3, mode = "bicubic", align_corners = False)[:,0]#.clamp(min = 0) + scoremap = log_scoremap.exp() + B,H,W = scoremap.shape + if increase_coverage: + weights = (-torch.linspace(-2, 2, steps = 51, device = device)**2).exp()[None,None] + # 10000 is just some number for maybe numerical stability, who knows. 
:), result is invariant anyway + local_density_x = F.conv2d((scoremap[:,None]+1e-6)*10000,weights[...,None,:], padding = (0,51//2)) + local_density = F.conv2d(local_density_x, weights[...,None], padding = (51//2,0))[:,0] + scoremap = scoremap * (local_density+1e-8)**(-1/2) + grid = get_grid(B,H,W, device=device).reshape(B,H*W,2) + if sharpen: + laplace_operator = torch.tensor([[[[0,1,0],[1,-4,1],[0,1,0]]]], device = device)/4 + scoremap = scoremap[:,None] - 0.5 * F.conv2d(scoremap[:,None], weight = laplace_operator, padding = 1) + scoremap = scoremap[:,0].clamp(min = 0) + if use_nms: + scoremap = scoremap * (scoremap == F.max_pool2d(scoremap, (3, 3), stride = 1, padding = 1)) + if remove_borders: + frame = torch.zeros_like(scoremap) + # we hardcode 4px, could do it nicer, but whatever + frame[...,4:-4, 4:-4] = 1 + scoremap = scoremap * frame + if sample_topk: + inds = torch.topk(scoremap.reshape(B,H*W), k = num_samples).indices + else: + inds = torch.multinomial(scoremap.reshape(B,H*W), num_samples = num_samples, replacement=False) + kps = torch.gather(grid, dim = 1, index = inds[...,None].expand(B,num_samples,2)) + if return_scoremap: + return kps, torch.gather(scoremap.reshape(B,H*W), dim = 1, index = inds) + return kps + +@torch.no_grad() +def jacobi_determinant(warp, certainty, R = 3, device = get_best_device(), dtype = torch.float32): + t = perf_counter() + *dims, _ = warp.shape + warp = warp.to(dtype) + certainty = certainty.to(dtype) + + dtype = warp.dtype + match_regions = torch.zeros((*dims, 4, R, R), device = device).to(dtype) + match_regions[:,1:-1, 1:-1] = warp.unfold(1,R,1).unfold(2,R,1) + match_regions = rearrange(match_regions,"B H W D R1 R2 -> B H W (R1 R2) D") - warp[...,None,:] + + match_regions_cert = torch.zeros((*dims, R, R), device = device).to(dtype) + match_regions_cert[:,1:-1, 1:-1] = certainty.unfold(1,R,1).unfold(2,R,1) + match_regions_cert = rearrange(match_regions_cert,"B H W R1 R2 -> B H W (R1 R2)")[..., None] + + #print("Time for unfold", perf_counter()-t) + #t = perf_counter() + *dims, N, D = match_regions.shape + # standardize: + mu, sigma = match_regions.mean(dim=(-2,-1), keepdim = True), match_regions.std(dim=(-2,-1),keepdim=True) + match_regions = (match_regions-mu)/(sigma+1e-6) + x_a, x_b = match_regions.chunk(2,-1) + + + A = torch.zeros((*dims,2*x_a.shape[-2],4), device = device).to(dtype) + A[...,::2,:2] = x_a * match_regions_cert + A[...,1::2,2:] = x_a * match_regions_cert + + a_block = A[...,::2,:2] + ata = a_block.mT @ a_block + #print("Time for ata", perf_counter()-t) + #t = perf_counter() + + #atainv = torch.linalg.inv(ata+1e-5*torch.eye(2,device=device).to(dtype)) + atainv = fast_inv_2x2(ata) + ATA_inv = torch.zeros((*dims, 4, 4), device = device, dtype = dtype) + ATA_inv[...,:2,:2] = atainv + ATA_inv[...,2:,2:] = atainv + atb = A.mT @ (match_regions_cert*x_b).reshape(*dims,N*2,1) + theta = ATA_inv @ atb + #print("Time for theta", perf_counter()-t) + #t = perf_counter() + + J = theta.reshape(*dims, 2, 2) + abs_J_det = torch.linalg.det(J+1e-8*torch.eye(2,2,device = device).expand(*dims,2,2)).abs() # Note: This should always be positive for correct warps, but still taking abs here + abs_J_logdet = (abs_J_det+1e-12).log() + B = certainty.shape[0] + # Handle outliers + robust_abs_J_logdet = abs_J_logdet.clamp(-3, 3) # Shouldn't be more that exp(3) \approx 8 times zoom + #print("Time for logdet", perf_counter()-t) + #t = perf_counter() + + return robust_abs_J_logdet + +def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 
'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None): + + if H is None: + B,H,W = depth1.shape + else: + B = depth1.shape[0] + with torch.no_grad(): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=depth1.device + ) + for n in (B, H, W) + ] + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) + mask, x2 = warp_kpts( + x1_n.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + depth_interpolation_mode = depth_interpolation_mode, + relative_depth_error_threshold = relative_depth_error_threshold, + ) + prob = mask.float().reshape(B, H, W) + x2 = x2.reshape(B, H, W, 2) + return torch.cat((x1_n.reshape(B,H,W,2),x2),dim=-1), prob + +def unnormalize_coords(x_n,h,w): + x = torch.stack( + (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + return x + + +def rotate_intrinsic(K, n): + base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + rot = np.linalg.matrix_power(base_rot, n) + return rot @ K + + +def rotate_pose_inplane(i_T_w, rot): + rotation_matrices = [ + np.array( + [ + [np.cos(r), -np.sin(r), 0.0, 0.0], + [np.sin(r), np.cos(r), 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] + ] + return np.dot(rotation_matrices[rot], i_T_w) + + +def scale_intrinsics(K, scales): + scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0]) + return np.dot(scales, K) + +def angle_error_mat(R1, R2): + cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 + cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds + return np.rad2deg(np.abs(np.arccos(cos))) + + +def angle_error_vec(v1, v2): + n = np.linalg.norm(v1) * np.linalg.norm(v2) + return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) + + +def compute_pose_error(T_0to1, R, t): + R_gt = T_0to1[:3, :3] + t_gt = T_0to1[:3, 3] + error_t = angle_error_vec(t.squeeze(), t_gt) + error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation + error_R = angle_error_mat(R, R_gt) + return error_t, error_R + + +def pose_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 1) / len(errors) + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.trapz(r, x=e) / t) + return aucs + + +# From Patch2Pix https://github.com/GrumpyZhou/patch2pix +def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR, antialias = False)) + return TupleCompose(ops) + + +def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False): + ops = [] + if resize: + ops.append(TupleResize(resize, antialias = True)) + if clahe: + ops.append(TupleClahe()) + if normalize: + ops.append(TupleToTensorScaled()) + ops.append( + TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) # Imagenet mean/std + else: + if unscale: + ops.append(TupleToTensorUnscaled()) + else: + ops.append(TupleToTensorScaled()) + return TupleCompose(ops) + +class Clahe: + def __init__(self, cliplimit = 2, blocksize = 8) -> None: + self.clahe = cv2.createCLAHE(cliplimit,(blocksize,blocksize)) + def __call__(self, im): + 
im_hsv = cv2.cvtColor(np.array(im),cv2.COLOR_RGB2HSV) + im_v = self.clahe.apply(im_hsv[:,:,2]) + im_hsv[...,2] = im_v + im_clahe = cv2.cvtColor(im_hsv,cv2.COLOR_HSV2RGB) + return Image.fromarray(im_clahe) + +class TupleClahe: + def __init__(self, cliplimit = 8, blocksize = 8) -> None: + self.clahe = Clahe(cliplimit,blocksize) + def __call__(self, ims): + return [self.clahe(im) for im in ims] + +class ToTensorScaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" + + def __call__(self, im): + if not isinstance(im, torch.Tensor): + im = np.array(im, dtype=np.float32).transpose((2, 0, 1)) + im /= 255.0 + return torch.from_numpy(im) + else: + return im + + def __repr__(self): + return "ToTensorScaled(./255)" + + +class TupleToTensorScaled(object): + def __init__(self): + self.to_tensor = ToTensorScaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorScaled(./255)" + + +class ToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __call__(self, im): + return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1))) + + def __repr__(self): + return "ToTensorUnscaled()" + + +class TupleToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __init__(self): + self.to_tensor = ToTensorUnscaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorUnscaled()" + + +class TupleResize(object): + def __init__(self, size, mode=InterpolationMode.BICUBIC, antialias = None): + self.size = size + self.resize = transforms.Resize(size, mode, antialias = antialias) + + def __call__(self, im_tuple): + return [self.resize(im) for im in im_tuple] + + def __repr__(self): + return "TupleResize(size={})".format(self.size) + +class Normalize: + def __call__(self,im): + mean = im.mean(dim=(1,2), keepdims=True) + std = im.std(dim=(1,2), keepdims=True) + return (im-mean)/std + + +class TupleNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + self.normalize = transforms.Normalize(mean=mean, std=std) + + def __call__(self, im_tuple): + c,h,w = im_tuple[0].shape + if c > 3: + warnings.warn(f"Number of channels {c=} > 3, assuming first 3 are rgb") + return [self.normalize(im[:3]) for im in im_tuple] + + def __repr__(self): + return "TupleNormalize(mean={}, std={})".format(self.mean, self.std) + + +class TupleCompose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, im_tuple): + for t in self.transforms: + im_tuple = t(im_tuple) + return im_tuple + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05): + """Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). 
+ # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here + Args: + kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1) + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + ( + n, + h, + w, + ) = depth0.shape + if depth_interpolation_mode == "combined": + # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation + if smooth_mask: + raise NotImplementedError("Combined bilinear and NN warp not implemented") + valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "bilinear", + relative_depth_error_threshold = relative_depth_error_threshold) + valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "nearest-exact", + relative_depth_error_threshold = relative_depth_error_threshold) + nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) + warp = warp_bilinear.clone() + warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid] + valid = valid_bilinear | valid_nearest + return valid, warp + + + kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[ + :, 0, :, 0 + ] + kpts0 = torch.stack( + (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + # Sample depth, get calculable_mask on depth != 0 + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) + kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + kpts0_cam = kpts0_n + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) + w_kpts0 = torch.stack( + (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 + ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] + # w_kpts0[~covisible_mask, :] = -5 # xd + + w_kpts0_depth = F.grid_sample( + depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False + )[:, 0, :, 0] + + relative_depth_error = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() + if not smooth_mask: + consistent_mask = relative_depth_error < relative_depth_error_threshold + else: + consistent_mask = (-relative_depth_error/smooth_mask).exp() + valid_mask = nonzero_mask * covisible_mask * consistent_mask + if return_relative_depth_error: + return relative_depth_error, w_kpts0 + else: + return valid_mask, w_kpts0 + +imagenet_mean = torch.tensor([0.485, 0.456, 0.406]) +imagenet_std = torch.tensor([0.229, 
0.224, 0.225]) + + +def numpy_to_pil(x: np.ndarray): + """ + Args: + x: Assumed to be of shape (h,w,c) + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if x.max() <= 1.01: + x *= 255 + x = x.astype(np.uint8) + return Image.fromarray(x) + + +def tensor_to_pil(x, unnormalize=False, autoscale = False): + if unnormalize: + x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device)) + if autoscale: + if x.max() == x.min(): + warnings.warn("x max == x min, cant autoscale") + else: + x = (x-x.min())/(x.max()-x.min()) + + x = x.detach().permute(1, 2, 0).cpu().numpy() + x = np.clip(x, 0.0, 1.0) + return numpy_to_pil(x) + + +def to_cuda(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cuda() + return batch + + +def to_best_device(batch, device=get_best_device()): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(device) + return batch + + +def to_cpu(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cpu() + return batch + + +def get_pose(calib): + w, h = np.array(calib["imsize"])[0] + return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w + + +def compute_relative_pose(R1, t1, R2, t2): + rots = R2 @ (R1.T) + trans = -rots @ t1 + t2 + return rots, trans + +def to_pixel_coords(flow, h1, w1): + flow = ( + torch.stack( + ( + w1 * (flow[..., 0] + 1) / 2, + h1 * (flow[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + return flow + +def to_normalized_coords(flow, h1, w1): + flow = ( + torch.stack( + ( + 2 * (flow[..., 0]) / w1 - 1, + 2 * (flow[..., 1]) / h1 - 1, + ), + axis=-1, + ) + ) + return flow + + +def warp_to_pixel_coords(warp, h1, w1, h2, w2): + warp1 = warp[..., :2] + warp1 = ( + torch.stack( + ( + w1 * (warp1[..., 0] + 1) / 2, + h1 * (warp1[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + warp2 = warp[..., 2:] + warp2 = ( + torch.stack( + ( + w2 * (warp2[..., 0] + 1) / 2, + h2 * (warp2[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + return torch.cat((warp1,warp2), dim=-1) + + +def to_homogeneous(x): + ones = torch.ones_like(x[...,-1:]) + return torch.cat((x, ones), dim = -1) + +def from_homogeneous(xh, eps = 1e-12): + return xh[...,:-1] / (xh[...,-1:]+eps) + +def homog_transform(Homog, x): + xh = to_homogeneous(x) + yh = (Homog @ xh.mT).mT + y = from_homogeneous(yh) + return y + +def get_homog_warp(Homog, H, W, device = get_best_device()): + grid = torch.meshgrid(torch.linspace(-1+1/H,1-1/H,H, device = device), torch.linspace(-1+1/W,1-1/W,W, device = device)) + + x_A = torch.stack((grid[1], grid[0]), dim = -1)[None] + x_A_to_B = homog_transform(Homog, x_A) + mask = ((x_A_to_B > -1) * (x_A_to_B < 1)).prod(dim=-1).float() + return torch.cat((x_A.expand(*x_A_to_B.shape), x_A_to_B),dim=-1), mask + +def dual_log_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False): + B, N, C = desc_A.shape + if normalize: + desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True) + desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True) + corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature + else: + corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature + logP = corr.log_softmax(dim = -2) + corr.log_softmax(dim= -1) + return logP + +def dual_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False): + if len(desc_A.shape) < 3: + desc_A, desc_B = 
desc_A[None], desc_B[None] + B, N, C = desc_A.shape + if normalize: + desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True) + desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True) + corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature + else: + corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature + P = corr.softmax(dim = -2) * corr.softmax(dim= -1) + return P + +def conditional_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False): + if len(desc_A.shape) < 3: + desc_A, desc_B = desc_A[None], desc_B[None] + B, N, C = desc_A.shape + if normalize: + desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True) + desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True) + corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature + else: + corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature + P_B_cond_A = corr.softmax(dim = -1) + P_A_cond_B = corr.softmax(dim = -2) + + return P_A_cond_B, P_B_cond_A \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/data_prep/prep_keypoints.py b/imcui/third_party/DeDoDe/data_prep/prep_keypoints.py new file mode 100644 index 0000000000000000000000000000000000000000..04fc3c7b110dbb3292b57028f75293325444e242 --- /dev/null +++ b/imcui/third_party/DeDoDe/data_prep/prep_keypoints.py @@ -0,0 +1,103 @@ +import argparse +import numpy as np + +import os + + +base_path = "data/megadepth" +# Remove the trailing / if need be. +if base_path[-1] in ['/', '\\']: + base_path = base_path[: - 1] + + +base_depth_path = os.path.join( + base_path, 'phoenix/S6/zl548/MegaDepth_v1' +) +base_undistorted_sfm_path = os.path.join( + base_path, 'Undistorted_SfM' +) + +scene_ids = os.listdir(base_undistorted_sfm_path) +for scene_id in scene_ids: + if os.path.exists(f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy"): + print(f"skipping {scene_id} as it exists") + continue + undistorted_sparse_path = os.path.join( + base_undistorted_sfm_path, scene_id, 'sparse-txt' + ) + if not os.path.exists(undistorted_sparse_path): + print("sparse path doesnt exist") + continue + + depths_path = os.path.join( + base_depth_path, scene_id, 'dense0', 'depths' + ) + if not os.path.exists(depths_path): + print("depths doesnt exist") + + continue + + images_path = os.path.join( + base_undistorted_sfm_path, scene_id, 'images' + ) + if not os.path.exists(images_path): + print("images path doesnt exist") + continue + + # Process cameras.txt + if not os.path.exists(os.path.join(undistorted_sparse_path, 'cameras.txt')): + print("no cameras") + continue + with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f: + raw = f.readlines()[3 :] # skip the header + + camera_intrinsics = {} + for camera in raw: + camera = camera.split(' ') + camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]] + + # Process points3D.txt + with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f: + raw = f.readlines()[3 :] # skip the header + + points3D = {} + for point3D in raw: + point3D = point3D.split(' ') + points3D[int(point3D[0])] = np.array([ + float(point3D[1]), float(point3D[2]), float(point3D[3]) + ]) + + points3D_np = np.zeros((max(points3D.keys())+1, 3)) + for idx, point in points3D.items(): + points3D_np[idx] = point + np.save(f"{base_path}/prep_scene_info/detections3D/detections3D_{scene_id}.npy", + points3D_np) + + # Process images.txt + with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f: + raw = 
f.readlines()[4 :] # skip the header + + image_id_to_idx = {} + image_names = [] + raw_pose = [] + camera = [] + points3D_id_to_2D = [] + n_points3D = [] + id_to_detections = {} + for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])): + image = image.split(' ') + points = points.split(' ') + + image_id_to_idx[int(image[0])] = idx + + image_name = image[-1].strip('\n') + image_names.append(image_name) + + raw_pose.append([float(elem) for elem in image[1 : -2]]) + camera.append(int(image[-2])) + points_np = np.array(points).astype(np.float32).reshape(len(points)//3, 3) + visible_points = points_np[points_np[:,2] != -1] + id_to_detections[idx] = visible_points + np.save(f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy", + id_to_detections) + print(f"{scene_id} done") diff --git a/imcui/third_party/DeDoDe/demo/demo_kpts.py b/imcui/third_party/DeDoDe/demo/demo_kpts.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0dddbe8d5de9b67abe2976ebc5d9c23f412340 --- /dev/null +++ b/imcui/third_party/DeDoDe/demo/demo_kpts.py @@ -0,0 +1,24 @@ +import torch +import cv2 +import numpy as np +from PIL import Image +from DeDoDe import dedode_detector_L +from DeDoDe.utils import * + +def draw_kpts(im, kpts): + kpts = [cv2.KeyPoint(x,y,1.) for x,y in kpts.cpu().numpy()] + im = np.array(im) + ret = cv2.drawKeypoints(im, kpts, None) + return ret + + +if __name__ == "__main__": + device = get_best_device() + detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth", map_location = device)) + im_path = "assets/im_A.jpg" + im = Image.open(im_path) + out = detector.detect_from_path(im_path, num_keypoints = 10_000) + W,H = im.size + kps = out["keypoints"] + kps = detector.to_pixel_coords(kps, H, W) + Image.fromarray(draw_kpts(im, kps[0])).save("demo/keypoints.png") diff --git a/imcui/third_party/DeDoDe/demo/demo_match.py b/imcui/third_party/DeDoDe/demo/demo_match.py new file mode 100644 index 0000000000000000000000000000000000000000..01143998f007ee1d2fb17adc64dcf8387510ac80 --- /dev/null +++ b/imcui/third_party/DeDoDe/demo/demo_match.py @@ -0,0 +1,46 @@ +import torch +from DeDoDe import dedode_detector_L, dedode_descriptor_B +from DeDoDe.matchers.dual_softmax_matcher import DualSoftMaxMatcher +from DeDoDe.utils import * +from PIL import Image +import cv2 +import numpy as np + + +def draw_matches(im_A, kpts_A, im_B, kpts_B): + kpts_A = [cv2.KeyPoint(x,y,1.) for x,y in kpts_A.cpu().numpy()] + kpts_B = [cv2.KeyPoint(x,y,1.) for x,y in kpts_B.cpu().numpy()] + matches_A_to_B = [cv2.DMatch(idx, idx, 0.) 
for idx in range(len(kpts_A))] + im_A, im_B = np.array(im_A), np.array(im_B) + ret = cv2.drawMatches(im_A, kpts_A, im_B, kpts_B, + matches_A_to_B, None) + return ret + +if __name__ == "__main__": + device = get_best_device() + detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth", map_location = device)) + descriptor = dedode_descriptor_B(weights = torch.load("dedode_descriptor_B.pth", map_location = device)) + matcher = DualSoftMaxMatcher() + + im_A_path = "assets/im_A.jpg" + im_B_path = "assets/im_B.jpg" + im_A = Image.open(im_A_path) + im_B = Image.open(im_B_path) + W_A, H_A = im_A.size + W_B, H_B = im_B.size + + + detections_A = detector.detect_from_path(im_A_path, num_keypoints = 10_000) + keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"] + detections_B = detector.detect_from_path(im_B_path, num_keypoints = 10_000) + keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"] + description_A = descriptor.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"] + description_B = descriptor.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"] + matches_A, matches_B, batch_ids = matcher.match(keypoints_A, description_A, + keypoints_B, description_B, + P_A = P_A, P_B = P_B, + normalize = True, inv_temp=20, threshold = 0.01)#Increasing threshold -> fewer matches, fewer outliers + + matches_A, matches_B = matcher.to_pixel_coords(matches_A, matches_B, H_A, W_A, H_B, W_B) + + Image.fromarray(draw_matches(im_A, matches_A, im_B, matches_B)).save("demo/matches.png") \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/demo/demo_match_dedode_G.py b/imcui/third_party/DeDoDe/demo/demo_match_dedode_G.py new file mode 100644 index 0000000000000000000000000000000000000000..586da9a0949067264d643bfabb29cd541c9e624a --- /dev/null +++ b/imcui/third_party/DeDoDe/demo/demo_match_dedode_G.py @@ -0,0 +1,45 @@ +import torch +from DeDoDe import dedode_detector_L, dedode_descriptor_G +from DeDoDe.matchers.dual_softmax_matcher import DualSoftMaxMatcher +from DeDoDe.utils import * +from PIL import Image +import cv2 +import numpy as np + + +def draw_matches(im_A, kpts_A, im_B, kpts_B): + kpts_A = [cv2.KeyPoint(x,y,1.) for x,y in kpts_A.cpu().numpy()] + kpts_B = [cv2.KeyPoint(x,y,1.) for x,y in kpts_B.cpu().numpy()] + matches_A_to_B = [cv2.DMatch(idx, idx, 0.) 
for idx in range(len(kpts_A))] + im_A, im_B = np.array(im_A), np.array(im_B) + ret = cv2.drawMatches(im_A, kpts_A, im_B, kpts_B, + matches_A_to_B, None) + return ret + + +if __name__ == "__main__": + device = get_best_device() + detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth", map_location = device)) + descriptor = dedode_descriptor_G(weights = torch.load("dedode_descriptor_G.pth", map_location = device)) + matcher = DualSoftMaxMatcher() + + im_A_path = "assets/im_A.jpg" + im_B_path = "assets/im_B.jpg" + im_A = Image.open(im_A_path) + im_B = Image.open(im_B_path) + W_A, H_A = im_A.size + W_B, H_B = im_B.size + + detections_A = detector.detect_from_path(im_A_path, num_keypoints = 10_000) + keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"] + detections_B = detector.detect_from_path(im_B_path, num_keypoints = 10_000) + keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"] + description_A = descriptor.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"] + description_B = descriptor.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"] + matches_A, matches_B, batch_ids = matcher.match(keypoints_A, description_A, + keypoints_B, description_B, + P_A = P_A, P_B = P_B, + normalize = True, inv_temp=20, threshold = 0.01)#Increasing threshold -> fewer matches, fewer outliers + + matches_A, matches_B = matcher.to_pixel_coords(matches_A, matches_B, H_A, W_A, H_B, W_B) + Image.fromarray(draw_matches(im_A, matches_A, im_B, matches_B)).save("demo/matches.jpg") \ No newline at end of file diff --git a/imcui/third_party/DeDoDe/demo/demo_scoremap.py b/imcui/third_party/DeDoDe/demo/demo_scoremap.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ae13a89ea18b364671a29692d47d550c8e88f0 --- /dev/null +++ b/imcui/third_party/DeDoDe/demo/demo_scoremap.py @@ -0,0 +1,23 @@ +import torch +from PIL import Image +import numpy as np + +from DeDoDe import dedode_detector_L +from DeDoDe.utils import tensor_to_pil, get_best_device + + +if __name__ == "__main__": + device = get_best_device() + detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth", map_location = device)) + H, W = 784, 784 + im_path = "assets/im_A.jpg" + + out = detector.detect_from_path(im_path, dense = True, H = H, W = W) + + logit_map = out["dense_keypoint_logits"].clone() + min = logit_map.max() - 3 + logit_map[logit_map < min] = min + logit_map = (logit_map-min)/(logit_map.max()-min) + logit_map = logit_map.cpu()[0].expand(3,H,W) + im_A = torch.tensor(np.array(Image.open(im_path).resize((W,H)))/255.).permute(2,0,1) + tensor_to_pil(logit_map * logit_map + 0.15 * (1-logit_map) * im_A).save("demo/dense_logits.png") diff --git a/imcui/third_party/DeDoDe/setup.py b/imcui/third_party/DeDoDe/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d175ab96e3493a2e53e2daaae99eb822a71b463e --- /dev/null +++ b/imcui/third_party/DeDoDe/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup, find_packages + + +setup( + name="DeDoDe", + packages=find_packages(include= ["DeDoDe*"]), + install_requires=open("requirements.txt", "r").read().split("\n"), + python_requires='>=3.9.0', + version="0.0.1", + author="Johan Edstedt", +) diff --git a/imcui/third_party/EfficientLoFTR/configs/data/__init__.py b/imcui/third_party/EfficientLoFTR/configs/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/imcui/third_party/EfficientLoFTR/configs/data/base.py b/imcui/third_party/EfficientLoFTR/configs/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..03aab160fa4137ccc04380f94854a56fbb549074 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/configs/data/base.py @@ -0,0 +1,35 @@ +""" +The data config will be the last one merged into the main config. +Setups in data configs will override all existed setups! +""" + +from yacs.config import CfgNode as CN +_CN = CN() +_CN.DATASET = CN() +_CN.TRAINER = CN() + +# training data config +_CN.DATASET.TRAIN_DATA_ROOT = None +_CN.DATASET.TRAIN_POSE_ROOT = None +_CN.DATASET.TRAIN_NPZ_ROOT = None +_CN.DATASET.TRAIN_LIST_PATH = None +_CN.DATASET.TRAIN_INTRINSIC_PATH = None +# validation set config +_CN.DATASET.VAL_DATA_ROOT = None +_CN.DATASET.VAL_POSE_ROOT = None +_CN.DATASET.VAL_NPZ_ROOT = None +_CN.DATASET.VAL_LIST_PATH = None +_CN.DATASET.VAL_INTRINSIC_PATH = None + +# testing data config +_CN.DATASET.TEST_DATA_ROOT = None +_CN.DATASET.TEST_POSE_ROOT = None +_CN.DATASET.TEST_NPZ_ROOT = None +_CN.DATASET.TEST_LIST_PATH = None +_CN.DATASET.TEST_INTRINSIC_PATH = None + +# dataset config +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 +_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val + +cfg = _CN diff --git a/imcui/third_party/EfficientLoFTR/configs/data/megadepth_test_1500.py b/imcui/third_party/EfficientLoFTR/configs/data/megadepth_test_1500.py new file mode 100644 index 0000000000000000000000000000000000000000..876bd4cad7772922d81c83ad3107ab6b8af599a3 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/configs/data/megadepth_test_1500.py @@ -0,0 +1,13 @@ +from configs.data.base import cfg + +TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info" + +cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" +cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" +cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" +cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt" + +cfg.DATASET.MGDPT_IMG_RESIZE = 832 +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 + +cfg.DATASET.NPE_NAME = 'megadepth' \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/configs/data/megadepth_trainval_832.py b/imcui/third_party/EfficientLoFTR/configs/data/megadepth_trainval_832.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ce0dd463cf09d031464176a5f28a6fe5ba2ad3 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/configs/data/megadepth_trainval_832.py @@ -0,0 +1,24 @@ +from configs.data.base import cfg + + +TRAIN_BASE_PATH = "data/megadepth/index" +cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth" +cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train" +cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" +cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0 + +TEST_BASE_PATH = "data/megadepth/index" +cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" +cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" +cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500" +cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val + +# 368 scenes in total for MegaDepth +# (with difficulty balanced (further split each scene to 3 sub-scenes)) +cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100 + +cfg.DATASET.MGDPT_IMG_RESIZE = 832 # for training on 32GB meme GPUs + 
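These data configs follow the usual yacs layering: `configs/data/base.py` declares every dataset key with a default, the per-dataset files import that `cfg` and overwrite fields, and (per the base.py docstring) the data config is merged into the main config last so its values win. Below is a minimal, self-contained sketch of that override behavior; it uses simplified defaults rather than the repo's full config tree, and only assumes `yacs` is installed.

```python
from yacs.config import CfgNode as CN

# Base config: declare keys with defaults, as configs/data/base.py does.
_C = CN()
_C.DATASET = CN()
_C.DATASET.TEST_DATA_ROOT = "data/placeholder"
_C.DATASET.MIN_OVERLAP_SCORE_TEST = 0.4

def get_cfg_defaults():
    # Return a clone so callers never mutate the module-level defaults.
    return _C.clone()

cfg = get_cfg_defaults()
# A dataset-specific config merged last overrides the defaults, which is what
# megadepth_test_1500.py achieves by assigning to the imported cfg object.
cfg.merge_from_list([
    "DATASET.TEST_DATA_ROOT", "data/megadepth/test",
    "DATASET.MIN_OVERLAP_SCORE_TEST", 0.0,
])
cfg.freeze()
assert cfg.DATASET.MIN_OVERLAP_SCORE_TEST == 0.0
```

One benefit of declaring every key in the base file is that `merge_from_list`/`merge_from_file` will raise on unknown keys instead of silently creating them.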
+cfg.DATASET.NPE_NAME = 'megadepth' \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/configs/data/scannet_test_1500.py b/imcui/third_party/EfficientLoFTR/configs/data/scannet_test_1500.py new file mode 100644 index 0000000000000000000000000000000000000000..ca98ed4b120d699f8de00016f169a83c0c8ddac8 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/configs/data/scannet_test_1500.py @@ -0,0 +1,16 @@ +from configs.data.base import cfg + +TEST_BASE_PATH = "assets/scannet_test_1500" + +cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" +cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test" +cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" +cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt" +cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" + +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 + +cfg.DATASET.SCAN_IMG_RESIZEX = 640 +cfg.DATASET.SCAN_IMG_RESIZEY = 480 + +cfg.DATASET.NPE_NAME = 'scannet' \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_full.py b/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_full.py new file mode 100644 index 0000000000000000000000000000000000000000..24ff5f33b6cf6ee11c4b564050fbe736126b8bc5 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_full.py @@ -0,0 +1,36 @@ +from src.config.default import _CN as cfg + +# training config +cfg.TRAINER.CANONICAL_LR = 8e-3 +cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +cfg.TRAINER.WARMUP_RATIO = 0.1 +cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] +cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 +cfg.TRAINER.OPTIMIZER = "adamw" +cfg.TRAINER.ADAMW_DECAY = 0.1 +cfg.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) +cfg.TRAINER.GRADIENT_CLIPPING = 0.0 +cfg.LOFTR.LOSS.FINE_TYPE = 'l2' # ['l2_with_std', 'l2'] +cfg.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = True +cfg.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = True +cfg.LOFTR.LOSS.LOCAL_WEIGHT = 0.25 +cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 +cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = True + +# model config +cfg.LOFTR.RESOLUTION = (8, 1) +cfg.LOFTR.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even +cfg.LOFTR.ALIGN_CORNER = False +cfg.LOFTR.MP = True # just for reproducing paper, FP16 is much faster on modern GPUs +cfg.LOFTR.REPLACE_NAN = True +cfg.LOFTR.EVAL_TIMES = 5 +cfg.LOFTR.COARSE.NO_FLASH = True # Not use Flash-Attention just for reproducing paper timing +cfg.LOFTR.MATCH_COARSE.THR = 0.2 # recommend 0.2 for full model and 25 for optimized model +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0 +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8 + +# dataset config +cfg.DATASET.FP16 = False + +# full model config +cfg.LOFTR.MATCH_COARSE.FP16MATMUL = False \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_optimized.py b/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_optimized.py new file mode 100644 index 0000000000000000000000000000000000000000..5c044e49db7ecb31e22570d8295d8ac617dcf64c --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_optimized.py @@ -0,0 +1,37 @@ +from src.config.default import _CN as cfg + +# training config +cfg.TRAINER.CANONICAL_LR = 8e-3 +cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +cfg.TRAINER.WARMUP_RATIO = 0.1 +cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] +cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 +cfg.TRAINER.OPTIMIZER = "adamw" +cfg.TRAINER.ADAMW_DECAY = 0.1 +cfg.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for 
MegaDepth (from SuperGlue) +cfg.TRAINER.GRADIENT_CLIPPING = 0.0 +cfg.LOFTR.LOSS.FINE_TYPE = 'l2' # ['l2_with_std', 'l2'] +cfg.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = True +cfg.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = True +cfg.LOFTR.LOSS.LOCAL_WEIGHT = 0.25 +cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 +cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = True + +# model config +cfg.LOFTR.RESOLUTION = (8, 1) +cfg.LOFTR.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even +cfg.LOFTR.ALIGN_CORNER = False +cfg.LOFTR.MP = True # just for reproducing paper, FP16 is much faster on modern GPUs +cfg.LOFTR.REPLACE_NAN = True +cfg.LOFTR.EVAL_TIMES = 5 +cfg.LOFTR.COARSE.NO_FLASH = True # Not use Flash-Attention just for reproducing paper timing +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0 +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8 + +# dataset config +cfg.DATASET.FP16 = False + +# optimized model config +cfg.LOFTR.MATCH_COARSE.FP16MATMUL = True +cfg.LOFTR.MATCH_COARSE.SKIP_SOFTMAX = True +cfg.LOFTR.MATCH_COARSE.THR = 25.0 # recommend 0.2 for full model and 25 for optimized model \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/environment.yaml b/imcui/third_party/EfficientLoFTR/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52bc6f68c1b0d7c0f020453427873370753234bc --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/environment.yaml @@ -0,0 +1,7 @@ +name: eloftr +channels: + - pytorch + - nvidia +dependencies: + - python=3.8 + - pip \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/environment_training.yaml b/imcui/third_party/EfficientLoFTR/environment_training.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e9b8f38a07da3d29cec8eb9f5e6a6379a2d9800b --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/environment_training.yaml @@ -0,0 +1,9 @@ +name: eloftr_training +channels: + - pytorch + - nvidia +dependencies: + - python=3.8 + - cudatoolkit=11.3 + - pytorch=1.12.1 + - pip \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/__init__.py b/imcui/third_party/EfficientLoFTR/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/EfficientLoFTR/src/config/default.py b/imcui/third_party/EfficientLoFTR/src/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..03d98095be47b6870cf1475bcfe239de44ee98f9 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/config/default.py @@ -0,0 +1,182 @@ +from yacs.config import CfgNode as CN +_CN = CN() + +############## ↓ LoFTR Pipeline ↓ ############## +_CN.LOFTR = CN() +_CN.LOFTR.BACKBONE_TYPE = 'RepVGG' +_CN.LOFTR.ALIGN_CORNER = False +_CN.LOFTR.RESOLUTION = (8, 1) +_CN.LOFTR.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even +_CN.LOFTR.MP = False +_CN.LOFTR.REPLACE_NAN = False +_CN.LOFTR.EVAL_TIMES = 1 +_CN.LOFTR.HALF = False + +# 1. LoFTR-backbone (local feature CNN) config +_CN.LOFTR.BACKBONE = CN() +_CN.LOFTR.BACKBONE.BLOCK_DIMS = [64, 128, 256] # s1, s2, s3 + +# 2. 
LoFTR-coarse module config +_CN.LOFTR.COARSE = CN() +_CN.LOFTR.COARSE.D_MODEL = 256 +_CN.LOFTR.COARSE.D_FFN = 256 +_CN.LOFTR.COARSE.NHEAD = 8 +_CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 +_CN.LOFTR.COARSE.AGG_SIZE0 = 4 +_CN.LOFTR.COARSE.AGG_SIZE1 = 4 +_CN.LOFTR.COARSE.NO_FLASH = False +_CN.LOFTR.COARSE.ROPE = True +_CN.LOFTR.COARSE.NPE = None # [832, 832, long_side, long_side] Suggest setting based on the long side of the input image, especially when the long_side > 832 + +# 3. Coarse-Matching config +_CN.LOFTR.MATCH_COARSE = CN() +_CN.LOFTR.MATCH_COARSE.THR = 0.2 # recommend 0.2 for full model and 25 for optimized model +_CN.LOFTR.MATCH_COARSE.BORDER_RM = 2 +_CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory +_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock +_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = True +_CN.LOFTR.MATCH_COARSE.SKIP_SOFTMAX = False +_CN.LOFTR.MATCH_COARSE.FP16MATMUL = False + +# 4. Fine-Matching config +_CN.LOFTR.MATCH_FINE = CN() +_CN.LOFTR.MATCH_FINE.SPARSE_SPVS = True +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 1.0 +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8 + +# 5. LoFTR Losses +# -- # coarse-level +_CN.LOFTR.LOSS = CN() +_CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy'] +_CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0 +_CN.LOFTR.LOSS.COARSE_SIGMOID_WEIGHT = 1.0 +_CN.LOFTR.LOSS.LOCAL_WEIGHT = 0.5 +_CN.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = False +_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = False +_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2 = False +# -- - -- # focal loss (coarse) +_CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25 +_CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0 +_CN.LOFTR.LOSS.POS_WEIGHT = 1.0 +_CN.LOFTR.LOSS.NEG_WEIGHT = 1.0 + +# -- # fine-level +_CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2'] +_CN.LOFTR.LOSS.FINE_WEIGHT = 1.0 +_CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window) + + +############## Dataset ############## +_CN.DATASET = CN() +# 1. data config +# training and validating +_CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth'] +_CN.DATASET.TRAIN_DATA_ROOT = None +_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TRAIN_NPZ_ROOT = None +_CN.DATASET.TRAIN_LIST_PATH = None +_CN.DATASET.TRAIN_INTRINSIC_PATH = None +_CN.DATASET.VAL_DATA_ROOT = None +_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.VAL_NPZ_ROOT = None +_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file +_CN.DATASET.VAL_INTRINSIC_PATH = None +_CN.DATASET.FP16 = False +# testing +_CN.DATASET.TEST_DATA_SOURCE = None +_CN.DATASET.TEST_DATA_ROOT = None +_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TEST_NPZ_ROOT = None +_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file +_CN.DATASET.TEST_INTRINSIC_PATH = None + +# 2. dataset config +# general options +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score +_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 +_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile'] + +# scanNet options +_CN.DATASET.SCAN_IMG_RESIZEX = 640 # resize the longer side, zero-pad bottom-right to square. +_CN.DATASET.SCAN_IMG_RESIZEY = 480 # resize the shorter side, zero-pad bottom-right to square. 
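`DSMAX_TEMPERATURE` and `MATCH_COARSE.THR` above parameterize dual-softmax coarse matching, the same scheme as the DeDoDe dual-softmax matcher earlier in this diff: similarities are softmaxed along both axes, multiplied, and thresholded. The helper below is a hedged, self-contained sketch of that scoring on two sparse descriptor sets; it is illustrative only, since the repo applies the idea to dense coarse feature grids rather than keypoint lists.

```python
import torch

def dual_softmax_matches(desc0, desc1, temperature=0.1, thr=0.2):
    """Mutual dual-softmax matching over two sets of L2-normalized descriptors."""
    sim = desc0 @ desc1.t() / temperature            # (N, M) similarity matrix
    conf = sim.softmax(dim=0) * sim.softmax(dim=1)   # dual-softmax confidence
    keep = conf > thr                                # confidence threshold (cf. MATCH_COARSE.THR)
    keep &= conf == conf.max(dim=1, keepdim=True).values  # mutual nearest neighbours
    keep &= conf == conf.max(dim=0, keepdim=True).values
    i, j = keep.nonzero(as_tuple=True)
    return i, j, conf[i, j]

desc0 = torch.nn.functional.normalize(torch.randn(128, 256), dim=1)
desc1 = torch.nn.functional.normalize(torch.randn(100, 256), dim=1)
idx0, idx1, scores = dual_softmax_matches(desc0, desc1)
```

Raising `thr` trades recall for precision, in line with the note in the demo scripts above ("Increasing threshold -> fewer matches, fewer outliers"). The optimized EfficientLoFTR config sets `SKIP_SOFTMAX = True` with `THR = 25.0`, presumably because it thresholds the raw, un-normalized similarities for speed, which is why the recommended threshold differs so much between the full and optimized models.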
+ +# MegaDepth options +_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. +_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE +_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 +_CN.DATASET.MGDPT_DF = 8 + +_CN.DATASET.NPE_NAME = None + +############## Trainer ############## +_CN.TRAINER = CN() +_CN.TRAINER.WORLD_SIZE = 1 +_CN.TRAINER.CANONICAL_BS = 64 +_CN.TRAINER.CANONICAL_LR = 6e-3 +_CN.TRAINER.SCALING = None # this will be calculated automatically +_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning + +# optimizer +_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] +_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime +_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam +_CN.TRAINER.ADAMW_DECAY = 0.1 + +# step-based warm-up +_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] +_CN.TRAINER.WARMUP_RATIO = 0. +_CN.TRAINER.WARMUP_STEP = 4800 + +# learning rate scheduler +_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR] +_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] +_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR +_CN.TRAINER.MSLR_GAMMA = 0.5 +_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing +_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval + +# plotting related +_CN.TRAINER.ENABLE_PLOTTING = True +_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test pairs for plotting +_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence'] +_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic' + +# geometric metrics and pose solver +_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) +_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] +_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, LO-RANSAC] +_CN.TRAINER.RANSAC_PIXEL_THR = 0.5 +_CN.TRAINER.RANSAC_CONF = 0.99999 +_CN.TRAINER.RANSAC_MAX_ITERS = 10000 +_CN.TRAINER.USE_MAGSACPP = False + +# data sampler for train_dataloader +_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] +# 'scene_balance' config +_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200 +_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether to sample each scene with replacement or not +_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether to shuffle within the epoch or not +_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data +# 'random' config +_CN.TRAINER.RDM_REPLACEMENT = True +_CN.TRAINER.RDM_NUM_SAMPLES = None + +# gradient clipping +_CN.TRAINER.GRADIENT_CLIPPING = 0.5 + +# reproducibility +# This seed affects the data sampling. With the same seed, the data sampling is guaranteed +# to be the same. When resuming training from a checkpoint, it's better to use a different +# seed, otherwise the sampled data will be exactly the same as before resuming, which will +# cause fewer unique data items to be sampled during the entire training. +# Using different seed values might affect the final training result, since not all data items +# are used during training on ScanNet. (60M pairs of images are sampled during training from 230M pairs in total.)
+_CN.TRAINER.SEED = 66 + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _CN.clone() diff --git a/imcui/third_party/EfficientLoFTR/src/datasets/megadepth.py b/imcui/third_party/EfficientLoFTR/src/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..5f070b2b6da7ef779d41090773d4c45592e95514 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/datasets/megadepth.py @@ -0,0 +1,133 @@ +import os.path as osp +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset +from loguru import logger + +from src.utils.dataset import read_megadepth_gray, read_megadepth_depth + + +class MegaDepthDataset(Dataset): + def __init__(self, + root_dir, + npz_path, + mode='train', + min_overlap_score=0.4, + img_resize=None, + df=None, + img_padding=False, + depth_padding=False, + augment_fn=None, + fp16=False, + **kwargs): + """ + Manage one scene(npz_path) of MegaDepth dataset. + + Args: + root_dir (str): megadepth root directory that has `phoenix`. + npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. + mode (str): options are ['train', 'val', 'test'] + min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing. + img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. + This is useful during training with batches and testing with memory intensive algorithms. + df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. + img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. + depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training. + augment_fn (callable, optional): augments images with pre-defined visual effects. + """ + super().__init__() + self.root_dir = root_dir + self.mode = mode + self.scene_id = npz_path.split('.')[0] + + # prepare scene_info and pair_info + if mode == 'test' and min_overlap_score != 0: + logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.") + min_overlap_score = 0 + self.scene_info = np.load(npz_path, allow_pickle=True) + self.pair_infos = self.scene_info['pair_infos'].copy() + + del self.scene_info['pair_infos'] + self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] + + # parameters for image resizing, padding and depthmap padding + if mode == 'train': + assert img_resize is not None and img_padding and depth_padding + self.img_resize = img_resize + self.df = df + self.img_padding = img_padding + self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. + + # for training LoFTR + self.augment_fn = augment_fn if mode == 'train' else None + self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) + + self.fp16 = fp16 + + def __len__(self): + return len(self.pair_infos) + + def __getitem__(self, idx): + (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx] + + # read grayscale image and mask. 
(1, h, w) and (h, w) + img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) + img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) + + # TODO: Support augmentation & handle seeds for each worker correctly. + image0, mask0, scale0 = read_megadepth_gray( + img_name0, self.img_resize, self.df, self.img_padding, None) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + image1, mask1, scale1 = read_megadepth_gray( + img_name1, self.img_resize, self.df, self.img_padding, None) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + + # read depth. shape: (h, w) + if self.mode in ['train', 'val']: + depth0 = read_megadepth_depth( + osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) + depth1 = read_megadepth_depth( + osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) + else: + depth0 = depth1 = torch.tensor([]) + + # read intrinsics of original size + K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) + K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T0 = self.scene_info['poses'][idx0] + T1 = self.scene_info['poses'][idx1] + T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) + T_1to0 = T_0to1.inverse() + + if self.fp16: + image0, image1, depth0, depth1, scale0, scale1 = map(lambda x: x.half(), + [image0, image1, depth0, depth1, scale0, scale1]) + data = { + 'image0': image0, # (1, h, w) + 'depth0': depth0, # (h, w) + 'image1': image1, + 'depth1': depth1, + 'T_0to1': T_0to1, # (4, 4) + 'T_1to0': T_1to0, + 'K0': K_0, # (3, 3) + 'K1': K_1, + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'MegaDepth', + 'scene_id': self.scene_id, + 'pair_id': idx, + 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), + } + # for LoFTR training + if mask0 is not None: # img_padding is True + if self.coarse_scale: + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.coarse_scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/EfficientLoFTR/src/datasets/sampler.py b/imcui/third_party/EfficientLoFTR/src/datasets/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..81b6f435645632a013476f9a665a0861ab7fcb61 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/datasets/sampler.py @@ -0,0 +1,77 @@ +import torch +from torch.utils.data import Sampler, ConcatDataset + + +class RandomConcatSampler(Sampler): + """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset + in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. + However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase. + + For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not. + Args: + shuffle (bool): shuffle the random sampled indices across all sub-datsets. + repeat (int): repeatedly use the sampled indices multiple times for training. 
+ [arXiv:1902.05509, arXiv:1901.09335] + NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples) + NOTE: This sampler behaves differently with DistributedSampler. + It assume the dataset is splitted across ranks instead of replicated. + TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs. + ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 + """ + def __init__(self, + data_source: ConcatDataset, + n_samples_per_subset: int, + subset_replacement: bool=True, + shuffle: bool=True, + repeat: int=1, + seed: int=None): + if not isinstance(data_source, ConcatDataset): + raise TypeError("data_source should be torch.utils.data.ConcatDataset") + + self.data_source = data_source + self.n_subset = len(self.data_source.datasets) + self.n_samples_per_subset = n_samples_per_subset + self.n_samples = self.n_subset * self.n_samples_per_subset * repeat + self.subset_replacement = subset_replacement + self.repeat = repeat + self.shuffle = shuffle + self.generator = torch.manual_seed(seed) + assert self.repeat >= 1 + + def __len__(self): + return self.n_samples + + def __iter__(self): + indices = [] + # sample from each sub-dataset + for d_idx in range(self.n_subset): + low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1] + high = self.data_source.cumulative_sizes[d_idx] + if self.subset_replacement: + rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ), + generator=self.generator, dtype=torch.int64) + else: # sample without replacement + len_subset = len(self.data_source.datasets[d_idx]) + rand_tensor = torch.randperm(len_subset, generator=self.generator) + low + if len_subset >= self.n_samples_per_subset: + rand_tensor = rand_tensor[:self.n_samples_per_subset] + else: # padding with replacement + rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ), + generator=self.generator, dtype=torch.int64) + rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) + indices.append(rand_tensor) + indices = torch.cat(indices) + if self.shuffle: # shuffle the sampled dataset (from multiple subsets) + rand_tensor = torch.randperm(len(indices), generator=self.generator) + indices = indices[rand_tensor] + + # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling) + if self.repeat > 1: + repeat_indices = [indices.clone() for _ in range(self.repeat - 1)] + if self.shuffle: + _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)] + repeat_indices = map(_choice, repeat_indices) + indices = torch.cat([indices, *repeat_indices], 0) + + assert indices.shape[0] == self.n_samples + return iter(indices.tolist()) diff --git a/imcui/third_party/EfficientLoFTR/src/datasets/scannet.py b/imcui/third_party/EfficientLoFTR/src/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..41743aaa0b0f6827c116ab6166ae71964515d196 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/datasets/scannet.py @@ -0,0 +1,129 @@ +from os import path as osp +from typing import Dict +from unicodedata import name + +import numpy as np +import torch +import torch.utils as utils +from numpy.linalg import inv +from src.utils.dataset import ( + read_scannet_gray, + read_scannet_depth, + read_scannet_pose, + read_scannet_intrinsic +) + + +class ScanNetDataset(utils.data.Dataset): + def __init__(self, + root_dir, + npz_path, + 
intrinsic_path, + mode='train', + min_overlap_score=0.4, + augment_fn=None, + pose_dir=None, + img_resize=None, + fp16=False, + **kwargs): + """Manage one scene of ScanNet Dataset. + Args: + root_dir (str): ScanNet root directory that contains scene folders. + npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. + intrinsic_path (str): path to depth-camera intrinsic file. + mode (str): options are ['train', 'val', 'test']. + augment_fn (callable, optional): augments images with pre-defined visual effects. + pose_dir (str): ScanNet root directory that contains all poses. + (we use a separate (optional) pose_dir since we store images and poses separately.) + """ + super().__init__() + self.root_dir = root_dir + self.pose_dir = pose_dir if pose_dir is not None else root_dir + self.mode = mode + + # prepare data_names, intrinsics and extrinsics(T) + with np.load(npz_path) as data: + self.data_names = data['name'] + if 'score' in data.keys() and mode not in ['val' or 'test']: + kept_mask = data['score'] > min_overlap_score + self.data_names = self.data_names[kept_mask] + self.intrinsics = dict(np.load(intrinsic_path)) + + # for training LoFTR + self.augment_fn = augment_fn if mode == 'train' else None + + self.fp16 = fp16 + self.img_resize = img_resize + + def __len__(self): + return len(self.data_names) + + def _read_abs_pose(self, scene_name, name): + pth = osp.join(self.pose_dir, + scene_name, + 'pose', f'{name}.txt') + return read_scannet_pose(pth) + + def _compute_rel_pose(self, scene_name, name0, name1): + pose0 = self._read_abs_pose(scene_name, name0) + pose1 = self._read_abs_pose(scene_name, name1) + + return np.matmul(pose1, inv(pose0)) # (4, 4) + + def __getitem__(self, idx): + data_name = self.data_names[idx] + scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name + scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + + # read the grayscale image which will be resized to (1, 480, 640) + img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg') + img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg') + + # TODO: Support augmentation & handle seeds for each worker correctly. 
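As an aside on how the `scale0`/`scale1`, `K0`/`K1`, and `T_0to1` entries assembled by these dataset classes are typically consumed downstream: matches predicted on the resized images are mapped back to original resolution, where the stored intrinsics are valid, before geometric metrics such as the symmetric epipolar error are evaluated. The helper below is a hedged, self-contained illustration of that bookkeeping, not code from this repository.

```python
import numpy as np

def symmetric_epipolar_error(kp0, kp1, scale0, scale1, K0, K1, T_0to1):
    """kp0/kp1: (x, y) in resized-image pixels; scale = [scale_w, scale_h]."""
    p0 = np.asarray(kp0, dtype=float) * np.asarray(scale0)  # back to original resolution
    p1 = np.asarray(kp1, dtype=float) * np.asarray(scale1)
    R, t = T_0to1[:3, :3], T_0to1[:3, 3]
    tx = np.array([[0.0, -t[2], t[1]],
                   [t[2], 0.0, -t[0]],
                   [-t[1], t[0], 0.0]])                      # cross-product matrix [t]_x
    F = np.linalg.inv(K1).T @ (tx @ R) @ np.linalg.inv(K0)   # fundamental matrix
    x0 = np.array([p0[0], p0[1], 1.0])
    x1 = np.array([p1[0], p1[1], 1.0])
    l1, l0 = F @ x0, F.T @ x1                                # epipolar lines in each image
    num = float(x1 @ F @ x0) ** 2
    return num * (1.0 / (l1[0] ** 2 + l1[1] ** 2) + 1.0 / (l0[0] ** 2 + l0[1] ** 2))
```

The repo's own `compute_symmetrical_epipolar_errors` (called from the Lightning module later in this diff) appears to compute the same quantity in normalized camera coordinates via the essential matrix rather than F, but the underlying geometry is equivalent.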
+ image0 = read_scannet_gray(img_name0, resize=self.img_resize, augment_fn=None) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + image1 = read_scannet_gray(img_name1, resize=self.img_resize, augment_fn=None) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + + # read the depthmap which is stored as (480, 640) + if self.mode in ['train', 'val']: + depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png')) + depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png')) + else: + depth0 = depth1 = torch.tensor([]) + + # read the intrinsic of depthmap + K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), + dtype=torch.float32) + T_1to0 = T_0to1.inverse() + + h_new, w_new = self.img_resize[1], self.img_resize[0] + scale0 = torch.tensor([640/w_new, 480/h_new], dtype=torch.float) + scale1 = torch.tensor([640/w_new, 480/h_new], dtype=torch.float) + + if self.fp16: + image0, image1, depth0, depth1, scale0, scale1 = map(lambda x: x.half(), + [image0, image1, depth0, depth1, scale0, scale1]) + + data = { + 'image0': image0, # (1, h, w) + 'depth0': depth0, # (h, w) + 'image1': image1, + 'depth1': depth1, + 'T_0to1': T_0to1, # (4, 4) + 'T_1to0': T_1to0, + 'K0': K_0, # (3, 3) + 'K1': K_1, + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'ScanNet', + 'scene_id': scene_name, + 'pair_id': idx, + 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'), + osp.join(scene_name, 'color', f'{stem_name_1}.jpg')) + } + + return data \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/lightning/data.py b/imcui/third_party/EfficientLoFTR/src/lightning/data.py new file mode 100644 index 0000000000000000000000000000000000000000..28730fc6bf30fcd99a35b6708e26b7a9a1eca9df --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/lightning/data.py @@ -0,0 +1,357 @@ +import os +import math +from collections import abc +from loguru import logger +from torch.utils.data.dataset import Dataset +from tqdm import tqdm +from os import path as osp +from pathlib import Path +from joblib import Parallel, delayed + +import pytorch_lightning as pl +from torch import distributed as dist +from torch.utils.data import ( + Dataset, + DataLoader, + ConcatDataset, + DistributedSampler, + RandomSampler, + dataloader +) + +from src.utils.augment import build_augmentor +from src.utils.dataloader import get_local_split +from src.utils.misc import tqdm_joblib +from src.utils import comm +from src.datasets.megadepth import MegaDepthDataset +from src.datasets.scannet import ScanNetDataset +from src.datasets.sampler import RandomConcatSampler + + +class MultiSceneDataModule(pl.LightningDataModule): + """ + For distributed training, each training process is assgined + only a part of the training scenes to reduce memory overhead. + """ + def __init__(self, args, config): + super().__init__() + + # 1. 
data config + # Train and Val should from the same data source + self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE + self.test_data_source = config.DATASET.TEST_DATA_SOURCE + # training and validating + self.train_data_root = config.DATASET.TRAIN_DATA_ROOT + self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional) + self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT + self.train_list_path = config.DATASET.TRAIN_LIST_PATH + self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH + self.val_data_root = config.DATASET.VAL_DATA_ROOT + self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional) + self.val_npz_root = config.DATASET.VAL_NPZ_ROOT + self.val_list_path = config.DATASET.VAL_LIST_PATH + self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH + # testing + self.test_data_root = config.DATASET.TEST_DATA_ROOT + self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional) + self.test_npz_root = config.DATASET.TEST_NPZ_ROOT + self.test_list_path = config.DATASET.TEST_LIST_PATH + self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH + + # 2. dataset config + # general options + self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score + self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN + self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile'] + + # ScanNet options + self.scan_img_resizeX = config.DATASET.SCAN_IMG_RESIZEX # 640 + self.scan_img_resizeY = config.DATASET.SCAN_IMG_RESIZEY # 480 + + + # MegaDepth options + self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 832 + self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True + self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True + self.mgdpt_df = config.DATASET.MGDPT_DF # 8 + self.coarse_scale = 1 / config.LOFTR.RESOLUTION[0] # 0.125. for training loftr. + + self.fp16 = config.DATASET.FP16 + + # 3.loader parameters + self.train_loader_params = { + 'batch_size': args.batch_size, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + self.val_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + self.test_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': args.num_workers, + 'pin_memory': True + } + + # 4. sampler + self.data_sampler = config.TRAINER.DATA_SAMPLER + self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET + self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT + self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE + self.repeat = config.TRAINER.SB_REPEAT + + # (optional) RandomSampler for debugging + + # misc configurations + self.parallel_load_data = getattr(args, 'parallel_load_data', False) + self.seed = config.TRAINER.SEED # 66 + + def setup(self, stage=None): + """ + Setup train / val / test dataset. This method will be called by PL automatically. + Args: + stage (str): 'fit' in training phase, and 'test' in testing phase. 
+ """ + + assert stage in ['fit', 'validate', 'test'], "stage must be either fit or test" + + try: + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + logger.info(f"[rank:{self.rank}] world_size: {self.world_size}") + except AssertionError as ae: + self.world_size = 1 + self.rank = 0 + # logger.warning(" (set wolrd_size=1 and rank=0)") + logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)") + + if stage == 'fit': + self.train_dataset = self._setup_dataset( + self.train_data_root, + self.train_npz_root, + self.train_list_path, + self.train_intrinsic_path, + mode='train', + min_overlap_score=self.min_overlap_score_train, + pose_dir=self.train_pose_root) + # setup multiple (optional) validation subsets + if isinstance(self.val_list_path, (list, tuple)): + self.val_dataset = [] + if not isinstance(self.val_npz_root, (list, tuple)): + self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))] + for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root): + self.val_dataset.append(self._setup_dataset( + self.val_data_root, + npz_root, + npz_list, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root)) + else: + self.val_dataset = self._setup_dataset( + self.val_data_root, + self.val_npz_root, + self.val_list_path, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root) + logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!') + elif stage == 'validate': + if isinstance(self.val_list_path, (list, tuple)): + self.val_dataset = [] + if not isinstance(self.val_npz_root, (list, tuple)): + self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))] + for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root): + self.val_dataset.append(self._setup_dataset( + self.val_data_root, + npz_root, + npz_list, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root)) + else: + self.val_dataset = self._setup_dataset( + self.val_data_root, + self.val_npz_root, + self.val_list_path, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root) + logger.info(f'[rank:{self.rank}] Val Dataset loaded!') + else: # stage == 'test + self.test_dataset = self._setup_dataset( + self.test_data_root, + self.test_npz_root, + self.test_list_path, + self.test_intrinsic_path, + mode='test', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.test_pose_root) + logger.info(f'[rank:{self.rank}]: Test Dataset loaded!') + + def _setup_dataset(self, + data_root, + split_npz_root, + scene_list_path, + intri_path, + mode='train', + min_overlap_score=0., + pose_dir=None): + """ Setup train / val / test set""" + with open(scene_list_path, 'r') as f: + npz_names = [name.split()[0] for name in f.readlines()] + + if mode == 'train': + local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed) + else: + local_npz_names = npz_names + logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.') + + dataset_builder = self._build_concat_dataset_parallel \ + if self.parallel_load_data \ + else self._build_concat_dataset + return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path, + mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir) + + def _build_concat_dataset( + self, + data_root, + npz_names, + npz_dir, + 
intrinsic_path, + mode, + min_overlap_score=0., + pose_dir=None + ): + datasets = [] + augment_fn = self.augment_fn if mode == 'train' else None + data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source + if str(data_source).lower() == 'megadepth': + npz_names = [f'{n}.npz' for n in npz_names] + for npz_name in tqdm(npz_names, + desc=f'[rank:{self.rank}] loading {mode} datasets', + disable=int(self.rank) != 0): + # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time. + npz_path = osp.join(npz_dir, npz_name) + if data_source == 'ScanNet': + datasets.append( + ScanNetDataset(data_root, + npz_path, + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir, + img_resize=(self.scan_img_resizeX, self.scan_img_resizeY), + fp16 = self.fp16, + )) + elif data_source == 'MegaDepth': + datasets.append( + MegaDepthDataset(data_root, + npz_path, + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale, + fp16 = self.fp16, + )) + else: + raise NotImplementedError() + return ConcatDataset(datasets) + + def _build_concat_dataset_parallel( + self, + data_root, + npz_names, + npz_dir, + intrinsic_path, + mode, + min_overlap_score=0., + pose_dir=None, + ): + augment_fn = self.augment_fn if mode == 'train' else None + data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source + if str(data_source).lower() == 'megadepth': + npz_names = [f'{n}.npz' for n in npz_names] + with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets', + total=len(npz_names), disable=int(self.rank) != 0)): + if data_source == 'ScanNet': + datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( + delayed(lambda x: _build_dataset( + ScanNetDataset, + data_root, + osp.join(npz_dir, x), + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir))(name) + for name in npz_names) + elif data_source == 'MegaDepth': + # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers. + raise NotImplementedError() + datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( + delayed(lambda x: _build_dataset( + MegaDepthDataset, + data_root, + osp.join(npz_dir, x), + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale))(name) + for name in npz_names) + else: + raise ValueError(f'Unknown dataset: {data_source}') + return ConcatDataset(datasets) + + def train_dataloader(self): + """ Build training dataloader for ScanNet / MegaDepth. 
""" + assert self.data_sampler in ['scene_balance'] + logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).') + if self.data_sampler == 'scene_balance': + sampler = RandomConcatSampler(self.train_dataset, + self.n_samples_per_subset, + self.subset_replacement, + self.shuffle, self.repeat, self.seed) + else: + sampler = None + dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params) + return dataloader + + def val_dataloader(self): + """ Build validation dataloader for ScanNet / MegaDepth. """ + logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.') + if not isinstance(self.val_dataset, abc.Sequence): + sampler = DistributedSampler(self.val_dataset, shuffle=False) + return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params) + else: + dataloaders = [] + for dataset in self.val_dataset: + sampler = DistributedSampler(dataset, shuffle=False) + dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params)) + return dataloaders + + def test_dataloader(self, *args, **kwargs): + logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.') + sampler = DistributedSampler(self.test_dataset, shuffle=False) + return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params) + + +def _build_dataset(dataset: Dataset, *args, **kwargs): + return dataset(*args, **kwargs) diff --git a/imcui/third_party/EfficientLoFTR/src/lightning/lightning_loftr.py b/imcui/third_party/EfficientLoFTR/src/lightning/lightning_loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a8e3ef2725e63b633c6022ae5bfd1a138e438b --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/lightning/lightning_loftr.py @@ -0,0 +1,272 @@ + +from collections import defaultdict +import pprint +from loguru import logger +from pathlib import Path + +import torch +import numpy as np +import pytorch_lightning as pl +from matplotlib import pyplot as plt + +from src.loftr import LoFTR +from src.loftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine +from src.losses.loftr_loss import LoFTRLoss +from src.optimizers import build_optimizer, build_scheduler +from src.utils.metrics import ( + compute_symmetrical_epipolar_errors, + compute_pose_errors, + aggregate_metrics +) +from src.utils.plotting import make_matching_figures +from src.utils.comm import gather, all_gather +from src.utils.misc import lower_config, flattenList +from src.utils.profiler import PassThroughProfiler + +from torch.profiler import profile + +def reparameter(matcher): + module = matcher.backbone.layer0 + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + for modules in [matcher.backbone.layer1, matcher.backbone.layer2, matcher.backbone.layer3]: + for module in modules: + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + for modules in [matcher.fine_preprocess.layer2_outconv2, matcher.fine_preprocess.layer1_outconv2]: + for module in modules: + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + return matcher + + +class PL_LoFTR(pl.LightningModule): + def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None): + """ + TODO: + - use the new version of PL logging API. 
+ """ + super().__init__() + # Misc + self.config = config # full config + _config = lower_config(self.config) + self.loftr_cfg = lower_config(_config['loftr']) + self.profiler = profiler or PassThroughProfiler() + self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1) + + # Matcher: LoFTR + self.matcher = LoFTR(config=_config['loftr'], profiler=self.profiler) + self.loss = LoFTRLoss(_config) + + # Pretrained weights + if pretrained_ckpt: + state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict'] + msg=self.matcher.load_state_dict(state_dict, strict=False) + logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") + + # Testing + self.warmup = False + self.reparameter = False + self.start_event = torch.cuda.Event(enable_timing=True) + self.end_event = torch.cuda.Event(enable_timing=True) + self.total_ms = 0 + + def configure_optimizers(self): + # FIXME: The scheduler did not work properly when `--resume_from_checkpoint` + optimizer = build_optimizer(self, self.config) + scheduler = build_scheduler(self.config, optimizer) + return [optimizer], [scheduler] + + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + # learning rate warm up + warmup_step = self.config.TRAINER.WARMUP_STEP + if self.trainer.global_step < warmup_step: + if self.config.TRAINER.WARMUP_TYPE == 'linear': + base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR + lr = base_lr + \ + (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \ + abs(self.config.TRAINER.TRUE_LR - base_lr) + for pg in optimizer.param_groups: + pg['lr'] = lr + elif self.config.TRAINER.WARMUP_TYPE == 'constant': + pass + else: + raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}') + + # update params + optimizer.step(closure=optimizer_closure) + optimizer.zero_grad() + + def _trainval_inference(self, batch): + with self.profiler.profile("Compute coarse supervision"): + with torch.autocast(enabled=False, device_type='cuda'): + compute_supervision_coarse(batch, self.config) + + with self.profiler.profile("LoFTR"): + with torch.autocast(enabled=self.config.LOFTR.MP, device_type='cuda'): + self.matcher(batch) + + with self.profiler.profile("Compute fine supervision"): + with torch.autocast(enabled=False, device_type='cuda'): + compute_supervision_fine(batch, self.config, self.logger) + + with self.profiler.profile("Compute losses"): + with torch.autocast(enabled=self.config.LOFTR.MP, device_type='cuda'): + self.loss(batch) + + def _compute_metrics(self, batch): + compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match + compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair + + rel_pair_names = list(zip(*batch['pair_names'])) + bs = batch['image0'].size(0) + metrics = { + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': [(batch['epi_errs'].reshape(-1,1))[batch['m_bids'] == b].reshape(-1).cpu().numpy() for b in range(bs)], + 'R_errs': batch['R_errs'], + 't_errs': batch['t_errs'], + 'inliers': batch['inliers'], + 'num_matches': [batch['mconf'].shape[0]], # batch size = 1 only + } + ret_dict = {'metrics': metrics} + return ret_dict, rel_pair_names + + def training_step(self, batch, batch_idx): + self._trainval_inference(batch) + + # logging + if self.trainer.global_rank == 0 and self.global_step % 
self.trainer.log_every_n_steps == 0: + # scalars + for k, v in batch['loss_scalars'].items(): + self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step) + + # figures + if self.config.TRAINER.ENABLE_PLOTTING: + compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match + figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE) + for k, v in figures.items(): + self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step) + return {'loss': batch['loss']} + + def training_epoch_end(self, outputs): + avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + if self.trainer.global_rank == 0: + self.logger.experiment.add_scalar( + 'train/avg_loss_on_epoch', avg_loss, + global_step=self.current_epoch) + + def on_validation_epoch_start(self): + self.matcher.fine_matching.validate = True + + def validation_step(self, batch, batch_idx): + self._trainval_inference(batch) + + ret_dict, _ = self._compute_metrics(batch) + + val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1) + figures = {self.config.TRAINER.PLOT_MODE: []} + if batch_idx % val_plot_interval == 0: + figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE) + + return { + **ret_dict, + 'loss_scalars': batch['loss_scalars'], + 'figures': figures, + } + + def validation_epoch_end(self, outputs): + self.matcher.fine_matching.validate = False + # handle multiple validation sets + multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs + multi_val_metrics = defaultdict(list) + + for valset_idx, outputs in enumerate(multi_outputs): + # since pl performs sanity_check at the very beginning of the training + cur_epoch = self.trainer.current_epoch + if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check: + cur_epoch = -1 + + # 1. loss_scalars: dict of list, on cpu + _loss_scalars = [o['loss_scalars'] for o in outputs] + loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]} + + # 2. val metrics: dict of list, numpy + _metrics = [o['metrics'] for o in outputs] + metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0 + val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, config=self.config) + for thr in [5, 10, 20]: + multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}']) + + # 3.
figures + _figures = [o['figures'] for o in outputs] + figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]} + + # tensorboard records only on rank 0 + if self.trainer.global_rank == 0: + for k, v in loss_scalars.items(): + mean_v = torch.stack(v).mean() + self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch) + + for k, v in val_metrics_4tb.items(): + self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch) + + for k, v in figures.items(): + if self.trainer.global_rank == 0: + for plot_idx, fig in enumerate(v): + self.logger.experiment.add_figure( + f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True) + plt.close('all') + + for thr in [5, 10, 20]: + # log on all ranks for ModelCheckpoint callback to work properly + self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this + + def test_step(self, batch, batch_idx): + if (self.config.LOFTR.BACKBONE_TYPE == 'RepVGG') and not self.reparameter: + self.matcher = reparameter(self.matcher) + if self.config.LOFTR.HALF: + self.matcher = self.matcher.eval().half() + self.reparameter = True + + if not self.warmup: + if self.config.LOFTR.HALF: + for i in range(50): + self.matcher(batch) + else: + with torch.autocast(enabled=self.config.LOFTR.MP, device_type='cuda'): + for i in range(50): + self.matcher(batch) + self.warmup = True + torch.cuda.synchronize() + + if self.config.LOFTR.HALF: + self.start_event.record() + self.matcher(batch) + self.end_event.record() + torch.cuda.synchronize() + self.total_ms += self.start_event.elapsed_time(self.end_event) + else: + with torch.autocast(enabled=self.config.LOFTR.MP, device_type='cuda'): + self.start_event.record() + self.matcher(batch) + self.end_event.record() + torch.cuda.synchronize() + self.total_ms += self.start_event.elapsed_time(self.end_event) + + ret_dict, rel_pair_names = self._compute_metrics(batch) + return ret_dict + + def test_epoch_end(self, outputs): + # metrics: dict of list, numpy + _metrics = [o['metrics'] for o in outputs] + metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + + # [{key: [{...}, *#bs]}, *#batch] + if self.trainer.global_rank == 0: + print('Averaged Matching time over 1500 pairs: {:.2f} ms'.format(self.total_ms / 1500)) + val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, config=self.config) + logger.info('\n' + pprint.pformat(val_metrics_4tb)) \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/__init__.py b/imcui/third_party/EfficientLoFTR/src/loftr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..362c0d04fa437c0073016a9ceac607ec5cffdfb7 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/__init__.py @@ -0,0 +1,4 @@ +from .loftr import LoFTR +from .utils.full_config import full_default_cfg +from .utils.opt_config import opt_default_cfg +from .loftr import reparameter \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/backbone/__init__.py b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2c02f486958a542550c0943ce284cf97e44050 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/__init__.py @@ -0,0 +1,11 @@ +from .backbone import RepVGG_8_1_align + +def build_backbone(config): + if config['backbone_type'] == 
'RepVGG': + if config['align_corner'] is False: + if config['resolution'] == (8, 1): + return RepVGG_8_1_align(config['backbone']) + else: + raise ValueError(f"LOFTR.ALIGN_CORNER {config['align_corner']} not supported.") + else: + raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/backbone/backbone.py b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d921971914d4d465e5e4c7fe87c90e773d4ef1 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/backbone.py @@ -0,0 +1,37 @@ +import torch.nn as nn +import torch.nn.functional as F +from .repvgg import create_RepVGG + +class RepVGG_8_1_align(nn.Module): + """ + RepVGG backbone, output resolution are 1/8 and 1. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + backbone = create_RepVGG(False) + + self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 + + for layer in [self.layer0, self.layer1, self.layer2, self.layer3]: + for m in layer.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + out = self.layer0(x) # 1/2 + for module in self.layer1: + out = module(out) # 1/2 + x1 = out + for module in self.layer2: + out = module(out) # 1/4 + x2 = out + for module in self.layer3: + out = module(out) # 1/8 + x3 = out + + return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1} diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/backbone/repvgg.py b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/repvgg.py new file mode 100644 index 0000000000000000000000000000000000000000..45b038c82511e902947b65d7c63ff461dc24c3e7 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/repvgg.py @@ -0,0 +1,224 @@ +# -------------------------------------------------------- +# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf) +# Github source: https://github.com/DingXiaoH/RepVGG +# Licensed under The MIT License [see LICENSE for details] +# Modified from: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py +# -------------------------------------------------------- +import torch.nn as nn +import numpy as np +import torch +import copy +# from se_block import SEBlock +import torch.utils.checkpoint as checkpoint +from loguru import logger + +def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): + result = nn.Sequential() + result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False)) + result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) + return result + +class RepVGGBlock(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False): + super(RepVGGBlock, self).__init__() + self.deploy = deploy + self.groups = groups + self.in_channels = in_channels + + assert kernel_size == 3 + assert padding == 1 + + padding_11 = padding - kernel_size // 2 + + self.nonlinearity = 
nn.ReLU() + + if use_se: + # Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models uses SE after nonlinearity. + # self.se = SEBlock(out_channels, internal_neurons=out_channels // 16) + raise ValueError(f"SEBlock not supported") + else: + self.se = nn.Identity() + + if deploy: + self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode) + else: + self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None + self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups) + self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups) + + def forward(self, inputs): + if hasattr(self, 'rbr_reparam'): + return self.nonlinearity(self.se(self.rbr_reparam(inputs))) + + if self.rbr_identity is None: + id_out = 0 + else: + id_out = self.rbr_identity(inputs) + + return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)) + + + # Optional. This may improve the accuracy and facilitates quantization in some cases. + # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight. + # 2. Use like this. + # loss = criterion(....) + # for every RepVGGBlock blk: + # loss += weight_decay_coefficient * 0.5 * blk.get_cust_L2() + # optimizer.zero_grad() + # loss.backward() + def get_custom_L2(self): + K3 = self.rbr_dense.conv.weight + K1 = self.rbr_1x1.conv.weight + t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach() + t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach() + + l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them. + eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel. + l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2. + return l2_loss_eq_kernel + l2_loss_circle + + + +# This func derives the equivalent kernel and bias in a DIFFERENTIABLE way. +# You can get the equivalent kernel and bias at any time and do whatever you want, + # for example, apply some penalties or constraints during training, just like you do to the other models. +# May be useful for quantization or pruning. 
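# --- Editorial aside, not part of the patch: a minimal, self-contained sketch of the
# conv + BN fusion that `_fuse_bn_tensor` / `get_equivalent_kernel_bias` below rely on.
# The helper name `fuse_conv_bn` and the layer sizes are illustrative only.
import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # W_fused = W * gamma / sqrt(var + eps);  b_fused = beta - mean * gamma / sqrt(var + eps)
    std = (bn.running_var + bn.eps).sqrt()
    t = (bn.weight / std).reshape(-1, 1, 1, 1)
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    fused.weight.data = conv.weight * t
    fused.bias.data = bn.bias - bn.running_mean * bn.weight / std
    return fused

conv, bn = nn.Conv2d(8, 16, 3, padding=1, bias=False), nn.BatchNorm2d(16)
with torch.no_grad():  # give BN non-trivial statistics so the check below is meaningful
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()  # fusion is exact only against frozen (inference-mode) BN statistics
x = torch.randn(2, 8, 32, 32)
assert torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5)
# `get_equivalent_kernel_bias` below additionally pads the fused 1x1 branch to 3x3 and folds
# the identity (pure BN) branch into an identity-like 3x3 kernel before summing the branches.
# --- End of editorial aside ---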
+ def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) + kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) + return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1,1,1,1]) + + def _fuse_bn_tensor(self, branch): + if branch is None: + return 0, 0 + if isinstance(branch, nn.Sequential): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + input_dim = self.in_channels // self.groups + kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, 1, 1] = 1 + self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device) + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def switch_to_deploy(self): + if hasattr(self, 'rbr_reparam'): + return + kernel, bias = self.get_equivalent_kernel_bias() + self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels, + kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride, + padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True) + self.rbr_reparam.weight.data = kernel + self.rbr_reparam.bias.data = bias + self.__delattr__('rbr_dense') + self.__delattr__('rbr_1x1') + if hasattr(self, 'rbr_identity'): + self.__delattr__('rbr_identity') + if hasattr(self, 'id_tensor'): + self.__delattr__('id_tensor') + self.deploy = True + + + +class RepVGG(nn.Module): + + def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False): + super(RepVGG, self).__init__() + assert len(width_multiplier) == 4 + self.deploy = deploy + self.override_groups_map = override_groups_map or dict() + assert 0 not in self.override_groups_map + self.use_se = use_se + self.use_checkpoint = use_checkpoint + + self.in_planes = min(64, int(64 * width_multiplier[0])) + self.stage0 = RepVGGBlock(in_channels=1, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se) + self.cur_layer_idx = 1 + self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=1) + self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2) + self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2) + + def _make_stage(self, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + blocks = [] + for stride in strides: + cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1) + blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3, + stride=stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se)) + self.in_planes = planes + self.cur_layer_idx += 1 + 
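# note: the blocks are returned as an nn.ModuleList and iterated one by one in forward(),
# which allows per-block gradient checkpointing when use_checkpoint is set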
return nn.ModuleList(blocks) + + def forward(self, x): + out = self.stage0(x) + for stage in (self.stage1, self.stage2, self.stage3): + for block in stage: + if self.use_checkpoint: + out = checkpoint.checkpoint(block, out) + else: + out = block(out) + out = self.gap(out) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] +g2_map = {l: 2 for l in optional_groupwise_layers} +g4_map = {l: 4 for l in optional_groupwise_layers} + +def create_RepVGG(deploy=False, use_checkpoint=False): + return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000, + width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint) + +# Use this for converting a RepVGG model or a bigger model with RepVGG as its component +# Use like this +# model = create_RepVGG_A0(deploy=False) +# train model or load weights +# repvgg_model_convert(model, save_path='repvgg_deploy.pth') +# If you want to preserve the original model, call with do_copy=True + +# ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like +# train_backbone = create_RepVGG_B2(deploy=False) +# train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth')) +# train_pspnet = build_pspnet(backbone=train_backbone) +# segmentation_train(train_pspnet) +# deploy_pspnet = repvgg_model_convert(train_pspnet) +# segmentation_test(deploy_pspnet) +# ===================== example_pspnet.py shows an example + +def repvgg_model_convert(model:torch.nn.Module, save_path=None, do_copy=True): + if do_copy: + model = copy.deepcopy(model) + for module in model.modules(): + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + if save_path is not None: + torch.save(model.state_dict(), save_path) + return model diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..8f76939a1b0c68504f535d4c9eb4ef91d19cd63a --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr.py @@ -0,0 +1,124 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange + +from .backbone import build_backbone +from .loftr_module import LocalFeatureTransformer, FinePreprocess +from .utils.coarse_matching import CoarseMatching +from .utils.fine_matching import FineMatching +from ..utils.misc import detect_NaN + +from loguru import logger + +def reparameter(matcher): + module = matcher.backbone.layer0 + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + for modules in [matcher.backbone.layer1, matcher.backbone.layer2, matcher.backbone.layer3]: + for module in modules: + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + for modules in [matcher.fine_preprocess.layer2_outconv2, matcher.fine_preprocess.layer1_outconv2]: + for module in modules: + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + return matcher + +class LoFTR(nn.Module): + def __init__(self, config, profiler=None): + super().__init__() + # Misc + self.config = config + self.profiler = profiler + + # Modules + self.backbone = build_backbone(config) + self.loftr_coarse = LocalFeatureTransformer(config) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_preprocess = FinePreprocess(config) + self.fine_matching = FineMatching(config) + + def forward(self, data): + """ + Update: + data (dict): { + 
'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + # 1. Local Feature CNN + data.update({ + 'bs': data['image0'].size(0), + 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] + }) + + if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + ret_dict = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) + feats_c = ret_dict['feats_c'] + data.update({ + 'feats_x2': ret_dict['feats_x2'], + 'feats_x1': ret_dict['feats_x1'], + }) + (feat_c0, feat_c1) = feats_c.split(data['bs']) + else: # handle different input shapes + ret_dict0, ret_dict1 = self.backbone(data['image0']), self.backbone(data['image1']) + feat_c0 = ret_dict0['feats_c'] + feat_c1 = ret_dict1['feats_c'] + data.update({ + 'feats_x2_0': ret_dict0['feats_x2'], + 'feats_x1_0': ret_dict0['feats_x1'], + 'feats_x2_1': ret_dict1['feats_x2'], + 'feats_x1_1': ret_dict1['feats_x1'], + }) + + + mul = self.config['resolution'][0] // self.config['resolution'][1] + data.update({ + 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], + 'hw0_f': [feat_c0.shape[2] * mul, feat_c0.shape[3] * mul] , + 'hw1_f': [feat_c1.shape[2] * mul, feat_c1.shape[3] * mul] + }) + + # 2. coarse-level loftr module + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'], data['mask1'] + + feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) + + feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c') + feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c') + + # detect NaN during mixed precision training + if self.config['replace_nan'] and (torch.any(torch.isnan(feat_c0)) or torch.any(torch.isnan(feat_c1))): + detect_NaN(feat_c0, feat_c1) + + # 3. match coarse-level + self.coarse_matching(feat_c0, feat_c1, data, + mask_c0=mask_c0.view(mask_c0.size(0), -1) if mask_c0 is not None else mask_c0, + mask_c1=mask_c1.view(mask_c1.size(0), -1) if mask_c1 is not None else mask_c1 + ) + + # prevent fp16 overflow during mixed precision training + feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, + [feat_c0, feat_c1]) + + # 4. fine-level refinement + feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_c0, feat_c1, data) + + # detect NaN during mixed precision training + if self.config['replace_nan'] and (torch.any(torch.isnan(feat_f0_unfold)) or torch.any(torch.isnan(feat_f1_unfold))): + detect_NaN(feat_f0_unfold, feat_f1_unfold) + + del feat_c0, feat_c1, mask_c0, mask_c1 + + # 5. 
match fine-level + self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/__init__.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca51db4f50a0c4f3dcd795e74b83e633ab2e990a --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/__init__.py @@ -0,0 +1,2 @@ +from .transformer import LocalFeatureTransformer +from .fine_preprocess import FinePreprocess diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/fine_preprocess.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/fine_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..ca37e02a0e709650d8133db04e84e9cfccbd6bf0 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/fine_preprocess.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange, repeat + +from loguru import logger + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + +class FinePreprocess(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + block_dims = config['backbone']['block_dims'] + self.W = self.config['fine_window_size'] + self.fine_d_model = block_dims[0] + + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") + + def inter_fpn(self, feat_c, x2, x1, stride): + feat_c = self.layer3_outconv(feat_c) + feat_c = F.interpolate(feat_c, scale_factor=2., mode='bilinear', align_corners=False) + + x2 = self.layer2_outconv(x2) + x2 = self.layer2_outconv2(x2+feat_c) + x2 = F.interpolate(x2, scale_factor=2., mode='bilinear', align_corners=False) + + x1 = self.layer1_outconv(x1) + x1 = self.layer1_outconv2(x1+x2) + x1 = F.interpolate(x1, scale_factor=2., mode='bilinear', align_corners=False) + return x1 + + def forward(self, feat_c0, feat_c1, data): + W = self.W + stride = data['hw0_f'][0] // data['hw0_c'][0] + + data.update({'W': W}) + if data['b_ids'].shape[0] == 0: + feat0 = torch.empty(0, self.W**2, self.fine_d_model, device=feat_c0.device) + feat1 = torch.empty(0, self.W**2, self.fine_d_model, device=feat_c0.device) + return feat0, feat1 + + if data['hw0_i'] == data['hw1_i']: + feat_c 
= rearrange(torch.cat([feat_c0, feat_c1], 0), 'b (h w) c -> b c h w', h=data['hw0_c'][0]) # 1/8 feat + x2 = data['feats_x2'] # 1/4 feat + x1 = data['feats_x1'] # 1/2 feat + del data['feats_x2'], data['feats_x1'] + + # 1. fine feature extraction + x1 = self.inter_fpn(feat_c, x2, x1, stride) + feat_f0, feat_f1 = torch.chunk(x1, 2, dim=0) + + # 2. unfold(crop) all local windows + feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0) + feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1) + feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2) + + # 3. select only the predicted matches + feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1 = feat_f1[data['b_ids'], data['j_ids']] + + return feat_f0, feat_f1 + else: # handle different input shapes + feat_c0, feat_c1 = rearrange(feat_c0, 'b (h w) c -> b c h w', h=data['hw0_c'][0]), rearrange(feat_c1, 'b (h w) c -> b c h w', h=data['hw1_c'][0]) # 1/8 feat + x2_0, x2_1 = data['feats_x2_0'], data['feats_x2_1'] # 1/4 feat + x1_0, x1_1 = data['feats_x1_0'], data['feats_x1_1'] # 1/2 feat + del data['feats_x2_0'], data['feats_x1_0'], data['feats_x2_1'], data['feats_x1_1'] + + # 1. fine feature extraction + feat_f0, feat_f1 = self.inter_fpn(feat_c0, x2_0, x1_0, stride), self.inter_fpn(feat_c1, x2_1, x1_1, stride) + + # 2. unfold(crop) all local windows + feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0) + feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1) + feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2) + + # 3. select only the predicted matches + feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1 = feat_f1[data['b_ids'], data['j_ids']] + + return feat_f0, feat_f1 \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/linear_attention.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4d95414e35441522c7df65c88a403672d3aa227b --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/linear_attention.py @@ -0,0 +1,103 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module +import torch.nn.functional as F +from einops.einops import rearrange + +if hasattr(F, 'scaled_dot_product_attention'): + FLASH_AVAILABLE = True + from torch.backends.cuda import sdp_kernel +else: + FLASH_AVAILABLE = False + +def crop_feature(query, key, value, x_mask, source_mask): + mask_h0, mask_w0, mask_h1, mask_w1 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0], source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0] + query = query[:, :mask_h0, :mask_w0, :] + key = key[:, :mask_h1, :mask_w1, :] + value = value[:, :mask_h1, :mask_w1, :] + return query, key, value, mask_h0, mask_w0 + +def pad_feature(m, mask_h0, mask_w0, x_mask): + bs, L, H, D = m.size() + m = m.view(bs, mask_h0, mask_w0, H, D) + if mask_h0 != x_mask.size(-2): + m = torch.cat([m, torch.zeros(m.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), H, D, device=m.device, dtype=m.dtype)], dim=1) + elif mask_w0 != x_mask.size(-1): + m = 
torch.cat([m, torch.zeros(m.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, H, D, device=m.device, dtype=m.dtype)], dim=2) + return m + +class Attention(Module): + def __init__(self, no_flash=False, nhead=8, dim=256, fp32=False): + super().__init__() + self.flash = FLASH_AVAILABLE and not no_flash + self.nhead = nhead + self.dim = dim + self.fp32 = fp32 + + def attention(self, query, key, value, q_mask=None, kv_mask=None): + assert q_mask is None and kv_mask is None, "Not support generalized attention mask yet." + if self.flash and not self.fp32: + args = [x.contiguous() for x in [query, key, value]] + with sdp_kernel(enable_math= False, enable_flash= True, enable_mem_efficient= False): + out = F.scaled_dot_product_attention(*args) + elif self.flash: + args = [x.contiguous() for x in [query, key, value]] + out = F.scaled_dot_product_attention(*args) + else: + QK = torch.einsum("nlhd,nshd->nlsh", query, key) + + # Compute the attention and the weighted average + softmax_temp = 1. / query.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + + out = torch.einsum("nlsh,nshd->nlhd", A, value) + return out + + def _forward(self, query, key, value, q_mask=None, kv_mask=None): + if q_mask is not None: + query, key, value, mask_h0, mask_w0 = crop_feature(query, key, value, q_mask, kv_mask) + + if self.flash: + query, key, value = map(lambda x: rearrange(x, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim), [query, key, value]) + else: + query, key, value = map(lambda x: rearrange(x, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim), [query, key, value]) + + m = self.attention(query, key, value, q_mask=None, kv_mask=None) + + if self.flash: + m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim) + + if q_mask is not None: + m = pad_feature(m, mask_h0, mask_w0, q_mask) + + return m + + def forward(self, query, key, value, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. 
+ Args: + if FLASH_AVAILABLE: # pytorch scaled_dot_product_attention + queries: [N, H, L, D] + keys: [N, H, S, D] + values: [N, H, S, D] + else: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + bs = query.size(0) + if bs == 1 or q_mask is None: + m = self._forward(query, key, value, q_mask=q_mask, kv_mask=kv_mask) + else: # for faster trainning with padding mask while batch size > 1 + m_list = [] + for i in range(bs): + m_list.append(self._forward(query[i:i+1], key[i:i+1], value[i:i+1], q_mask=q_mask[i:i+1], kv_mask=kv_mask[i:i+1])) + m = torch.cat(m_list, dim=0) + return m \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/transformer.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e97a033a185049539a9f2fd29483333a839a3bcd --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/transformer.py @@ -0,0 +1,164 @@ +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F +from .linear_attention import Attention, crop_feature, pad_feature +from einops.einops import rearrange +from collections import OrderedDict +from ..utils.position_encoding import RoPEPositionEncodingSine +import numpy as np +from loguru import logger + +class AG_RoPE_EncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + agg_size0=4, + agg_size1=4, + no_flash=False, + rope=False, + npe=None, + fp32=False, + ): + super(AG_RoPE_EncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + self.agg_size0, self.agg_size1 = agg_size0, agg_size1 + self.rope = rope + + # aggregate and position encoding + self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=agg_size0, padding=0, stride=agg_size0, bias=False, groups=d_model) if self.agg_size0 != 1 else nn.Identity() + self.max_pool = torch.nn.MaxPool2d(kernel_size=self.agg_size1, stride=self.agg_size1) if self.agg_size1 != 1 else nn.Identity() + if self.rope: + self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True) + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = Attention(no_flash, self.nhead, self.dim, fp32) + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.LeakyReLU(inplace = True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, C, H0, W0] + source (torch.Tensor): [N, C, H1, W1] + x_mask (torch.Tensor): [N, H0, W0] (optional) (L = H0*W0) + source_mask (torch.Tensor): [N, H1, W1] (optional) (S = H1*W1) + """ + bs, C, H0, W0 = x.size() + H1, W1 = source.size(-2), source.size(-1) + + # Aggragate feature + query, source = self.norm1(self.aggregate(x).permute(0,2,3,1)), self.norm1(self.max_pool(source).permute(0,2,3,1)) # [N, H, W, C] + if x_mask is not None: + x_mask, source_mask = map(lambda x: self.max_pool(x.float()).bool(), [x_mask, source_mask]) + query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source) + + # Positional 
encoding + if self.rope: + query = self.rope_pos_enc(query) + key = self.rope_pos_enc(key) + + # multi-head attention handle padding mask + m = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) + m = self.merge(m.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C] + + # Upsample feature + m = rearrange(m, 'b (h w) c -> b c h w', h=H0 // self.agg_size0, w=W0 // self.agg_size0) # [N, C, H0, W0] + if self.agg_size0 != 1: + m = torch.nn.functional.interpolate(m, scale_factor=self.agg_size0, mode='bilinear', align_corners=False) # [N, C, H0, W0] + + # feed-forward network + m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1)) # [N, H0, W0, C] + m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H0, W0] + + return x + m + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.full_config = config + self.fp32 = not (config['mp'] or config['half']) + config = config['coarse'] + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + self.agg_size0, self.agg_size1 = config['agg_size0'], config['agg_size1'] + self.rope = config['rope'] + + self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['agg_size0'], config['agg_size1'], + config['no_flash'], config['rope'], config['npe'], self.fp32) + cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['agg_size0'], config['agg_size1'], + config['no_flash'], False, config['npe'], self.fp32) + self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None, data=None): + """ + Args: + feat0 (torch.Tensor): [N, C, H, W] + feat1 (torch.Tensor): [N, C, H, W] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1) + bs = feat0.shape[0] + + feature_cropped = False + if bs == 1 and mask0 is not None and mask1 is not None: + mask_H0, mask_W0, mask_H1, mask_W1 = mask0.size(-2), mask0.size(-1), mask1.size(-2), mask1.size(-1) + mask_h0, mask_w0, mask_h1, mask_w1 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0], mask1[0].sum(-2)[0], mask1[0].sum(-1)[0] + mask_h0, mask_w0, mask_h1, mask_w1 = mask_h0//self.agg_size0*self.agg_size0, mask_w0//self.agg_size0*self.agg_size0, mask_h1//self.agg_size1*self.agg_size1, mask_w1//self.agg_size1*self.agg_size1 + feat0 = feat0[:, :, :mask_h0, :mask_w0] + feat1 = feat1[:, :, :mask_h1, :mask_w1] + feature_cropped = True + + for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)): + if feature_cropped: + mask0, mask1 = None, None + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + if feature_cropped: + # padding feature + bs, c, mask_h0, mask_w0 = feat0.size() + if mask_h0 != mask_H0: + feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0-mask_h0, mask_W0, device=feat0.device, dtype=feat0.dtype)], dim=-2) + elif mask_w0 != mask_W0: + feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0, mask_W0-mask_w0, 
device=feat0.device, dtype=feat0.dtype)], dim=-1) + + bs, c, mask_h1, mask_w1 = feat1.size() + if mask_h1 != mask_H1: + feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1-mask_h1, mask_W1, device=feat1.device, dtype=feat1.dtype)], dim=-2) + elif mask_w1 != mask_W1: + feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1, mask_W1-mask_w1, device=feat1.device, dtype=feat1.dtype)], dim=-1) + + return feat0, feat1 \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/coarse_matching.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/coarse_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..156c9eecf8c2cfb54b8eb22a8663d5cda5afa6a8 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/coarse_matching.py @@ -0,0 +1,241 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange, repeat + +from loguru import logger +import numpy as np + +INF = 1e9 + +def mask_border(m, b: int, v): + """ Mask borders with value + Args: + m (torch.Tensor): [N, H0, W0, H1, W1] + b (int) + v (m.dtype) + """ + if b <= 0: + return + + m[:, :b] = v + m[:, :, :b] = v + m[:, :, :, :b] = v + m[:, :, :, :, :b] = v + m[:, -b:] = v + m[:, :, -b:] = v + m[:, :, :, -b:] = v + m[:, :, :, :, -b:] = v + + +def mask_border_with_padding(m, bd, v, p_m0, p_m1): + if bd <= 0: + return + + m[:, :bd] = v + m[:, :, :bd] = v + m[:, :, :, :bd] = v + m[:, :, :, :, :bd] = v + + h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() + h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() + for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): + m[b_idx, h0 - bd:] = v + m[b_idx, :, w0 - bd:] = v + m[b_idx, :, :, h1 - bd:] = v + m[b_idx, :, :, :, w1 - bd:] = v + + +def compute_max_candidates(p_m0, p_m1): + """Compute the max candidates of all pairs within a batch + + Args: + p_m0, p_m1 (torch.Tensor): padded masks + """ + h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] + h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] + max_cand = torch.sum( + torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + return max_cand + +class CoarseMatching(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # general config + self.thr = config['thr'] + self.border_rm = config['border_rm'] + self.temperature = config['dsmax_temperature'] + self.skip_softmax = config['skip_softmax'] + self.fp16matmul = config['fp16matmul'] + # -- # for trainig fine-level LoFTR + self.train_coarse_percent = config['train_coarse_percent'] + self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + + def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + data (dict) + mask_c0 (torch.Tensor): [N, L] (optional) + mask_c1 (torch.Tensor): [N, S] (optional) + Update: + data (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + NOTE: M' != M during training. 
+ """ + N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) + + # normalize + feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, + [feat_c0, feat_c1]) + + if self.fp16matmul: + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, + feat_c1) / self.temperature + del feat_c0, feat_c1 + if mask_c0 is not None: + sim_matrix = sim_matrix.masked_fill( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -1e4 + ) + else: + with torch.autocast(enabled=False, device_type='cuda'): + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, + feat_c1) / self.temperature + del feat_c0, feat_c1 + if mask_c0 is not None: + sim_matrix = sim_matrix.float().masked_fill( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF + ) + if self.skip_softmax: + sim_matrix = sim_matrix + else: + sim_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) + + data.update({'conf_matrix': sim_matrix}) + + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match(sim_matrix, data)) + + @torch.no_grad() + def get_coarse_match(self, conf_matrix, data): + """ + Args: + conf_matrix (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + _device = conf_matrix.device + # 1. confidence thresholding + mask = conf_matrix > self.thr + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + **axes_lengths) + + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # 2. mutual nearest + mask = mask \ + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + + # 3. find all valid coarse matches + # this only works when at most one `True` in each row + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + mconf = conf_matrix[b_ids, i_ids, j_ids] + + # 4. Random sampling of training samples for fine-level LoFTR + # (optional) pad samples with gt coarse-level matches + if self.training: + # NOTE: + # The sampling is performed across all pairs in a batch without manually balancing + # #samples for fine-level increases w.r.t. batch_size + if 'mask0' not in data: + num_candidates_max = mask.size(0) * max( + mask.size(1), mask.size(2)) + else: + num_candidates_max = compute_max_candidates( + data['mask0'], data['mask1']) + num_matches_train = int(num_candidates_max * + self.train_coarse_percent) + num_matches_pred = len(b_ids) + assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + + # pred_indices is to select from prediction + if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: + pred_indices = torch.arange(num_matches_pred, device=_device) + else: + pred_indices = torch.randint( + num_matches_pred, + (num_matches_train - self.train_pad_num_gt_min, ), + device=_device) + + # gt_pad_indices is to select from gt padding. e.g. 
max(3787-4800, 200) + gt_pad_indices = torch.randint( + len(data['spv_b_ids']), + (max(num_matches_train - num_matches_pred, + self.train_pad_num_gt_min), ), + device=_device) + mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + + b_ids, i_ids, j_ids, mconf = map( + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], + dim=0), + *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], + [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + + # These matches select patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # 4. Update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], + dim=1) * scale1 + + m_bids = b_ids[mconf != 0] + # These matches is the current prediction (for visualization) + coarse_matches.update({ + 'm_bids': m_bids, # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c[mconf != 0], + 'mkpts1_c': mkpts1_c[mconf != 0], + 'mconf': mconf[mconf != 0] + }) + + return coarse_matches \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/fine_matching.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/fine_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..6adb6f8c8a1c3d25babda3d5cbd79b44285c2eb9 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/fine_matching.py @@ -0,0 +1,156 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + +from loguru import logger + +class FineMatching(nn.Module): + """FineMatching with s2d paradigm""" + + def __init__(self, config): + super().__init__() + self.config = config + self.local_regress_temperature = config['match_fine']['local_regress_temperature'] + self.local_regress_slicedim = config['match_fine']['local_regress_slicedim'] + self.fp16 = config['half'] + self.validate = False + + def forward(self, feat_0, feat_1, data): + """ + Args: + feat0 (torch.Tensor): [M, WW, C] + feat1 (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + M, WW, C = feat_0.shape + W = int(math.sqrt(WW)) + scale = data['hw0_i'][0] / data['hw0_f'][0] + self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale + + # corner case: if no coarse matches found + if M == 0: + assert self.training == False, "M is always > 0 while training, see coarse_matching.py" + data.update({ + 'conf_matrix_f': torch.empty(0, WW, WW, device=feat_0.device), + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + # compute pixel-level confidence matrix + with torch.autocast(enabled=True if not (self.training or self.validate) else False, device_type='cuda'): + feat_f0, feat_f1 = feat_0[...,:-self.local_regress_slicedim], feat_1[...,:-self.local_regress_slicedim] + feat_ff0, feat_ff1 = feat_0[...,-self.local_regress_slicedim:], feat_1[...,-self.local_regress_slicedim:] + feat_f0, feat_f1 = feat_f0 / C**.5, feat_f1 / C**.5 + conf_matrix_f = 
torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1) + conf_matrix_ff = torch.einsum('mlc,mrc->mlr', feat_ff0, feat_ff1 / (self.local_regress_slicedim)**.5) + + softmax_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2) + softmax_matrix_f = softmax_matrix_f.reshape(M, self.WW, self.W+2, self.W+2) + softmax_matrix_f = softmax_matrix_f[...,1:-1,1:-1].reshape(M, self.WW, self.WW) + + # for fine-level supervision + if self.training or self.validate: + data.update({'sim_matrix_ff': conf_matrix_ff}) + data.update({'conf_matrix_f': softmax_matrix_f}) + + # compute pixel-level absolute kpt coords + self.get_fine_ds_match(softmax_matrix_f, data) + + # generate seconde-stage 3x3 grid + idx_l, idx_r = data['idx_l'], data['idx_r'] + m_ids = torch.arange(M, device=idx_l.device, dtype=torch.long).unsqueeze(-1) + m_ids = m_ids[:len(data['mconf'])] + idx_r_iids, idx_r_jids = idx_r // W, idx_r % W + + m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1) + delta = create_meshgrid(3, 3, True, conf_matrix_ff.device).to(torch.long) # [1, 3, 3, 2] + + m_ids = m_ids[...,None,None].expand(-1, 3, 3) + idx_l = idx_l[...,None,None].expand(-1, 3, 3) # [m, k, 3, 3] + + idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1] + idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0] + + if idx_l.numel() == 0: + data.update({ + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + # compute second-stage heatmap + conf_matrix_ff = conf_matrix_ff.reshape(M, self.WW, self.W+2, self.W+2) + conf_matrix_ff = conf_matrix_ff[m_ids, idx_l, idx_r_iids, idx_r_jids] + conf_matrix_ff = conf_matrix_ff.reshape(-1, 9) + conf_matrix_ff = F.softmax(conf_matrix_ff / self.local_regress_temperature, -1) + heatmap = conf_matrix_ff.reshape(-1, 3, 3) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] + + if data['bs'] == 1: + scale1 = scale * data['scale1'] if 'scale0' in data else scale + else: + scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']), ...][:,None,:].expand(-1, -1, 2).reshape(-1, 2) if 'scale0' in data else scale + + # compute subpixel-level absolute kpt coords + self.get_fine_match_local(coords_normalized, data, scale1) + + def get_fine_match_local(self, coords_normed, data, scale1): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + + mkpts0_c, mkpts1_c = data['mkpts0_c'], data['mkpts1_c'] + + # mkpts0_f and mkpts1_f + mkpts0_f = mkpts0_c + mkpts1_f = mkpts1_c + (coords_normed * (3 // 2) * scale1) + + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f + }) + + @torch.no_grad() + def get_fine_ds_match(self, conf_matrix, data): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + m, _, _ = conf_matrix.shape + + conf_matrix = conf_matrix.reshape(m, -1)[:len(data['mconf']),...] 
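+ # The [m, WW, WW] window correlation was flattened to [m, WW*WW] and truncated to the kept
+ # matches above, so the argmax below picks, for every coarse match, the single strongest
+ # (cell in window0, cell in window1) pair: idx // WW is the cell index in image0's window
+ # and idx % WW the cell index in image1's window (e.g. with W=8, WW=64, idx = 9*64+13
+ # selects cell 9 and cell 13). The meshgrid further down converts these cell indices into
+ # pixel offsets from the window center.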
+ val, idx = torch.max(conf_matrix, dim = -1) + idx = idx[:,None] + idx_l, idx_r = idx // WW, idx % WW + + data.update({'idx_l': idx_l, 'idx_r': idx_r}) + + if self.fp16: + grid = create_meshgrid(W, W, False, conf_matrix.device, dtype=torch.float16) - W // 2 + 0.5 # kornia >= 0.5.1 + else: + grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5 + grid = grid.reshape(1, -1, 2).expand(m, -1, -1) + delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2)) + delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2)) + + scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale + scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale + + if torch.is_tensor(scale0) and scale0.numel() > 1: # scale0 is a tensor + mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:len(data['mconf']),...][:,None,:])).reshape(-1, 2) + mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:len(data['mconf']),...][:,None,:])).reshape(-1, 2) + else: # scale0 is a float + mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)).reshape(-1, 2) + mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)).reshape(-1, 2) + + data.update({ + "mkpts0_c": mkpts0_f, + "mkpts1_c": mkpts1_f + }) \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/full_config.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/full_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ccf84b48f6693be4bf9306c8bd987d87a6f43792 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/full_config.py @@ -0,0 +1,50 @@ +from yacs.config import CfgNode as CN + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +_CN = CN() +_CN.BACKBONE_TYPE = 'RepVGG' +_CN.ALIGN_CORNER = False +_CN.RESOLUTION = (8, 1) +_CN.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even +_CN.MP = False +_CN.REPLACE_NAN = True +_CN.HALF = False + +# 1. LoFTR-backbone (local feature CNN) config +_CN.BACKBONE = CN() +_CN.BACKBONE.BLOCK_DIMS = [64, 128, 256] # s1, s2, s3 + +# 2. LoFTR-coarse module config +_CN.COARSE = CN() +_CN.COARSE.D_MODEL = 256 +_CN.COARSE.D_FFN = 256 +_CN.COARSE.NHEAD = 8 +_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 +_CN.COARSE.AGG_SIZE0 = 4 +_CN.COARSE.AGG_SIZE1 = 4 +_CN.COARSE.NO_FLASH = False +_CN.COARSE.ROPE = True +_CN.COARSE.NPE = [832, 832, 832, 832] # [832, 832, long_side, long_side] Suggest setting based on the long side of the input image, especially when the long_side > 832 + +# 3. Coarse-Matching config +_CN.MATCH_COARSE = CN() +_CN.MATCH_COARSE.THR = 0.2 # recommend 0.2 for full model and 25 for optimized model +_CN.MATCH_COARSE.BORDER_RM = 2 +_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.MATCH_COARSE.SKIP_SOFTMAX = False # False for full model and True for optimized model +_CN.MATCH_COARSE.FP16MATMUL = False # False for full model and True for optimized model +_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory +_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock + +# 4. 
Fine-Matching config +_CN.MATCH_FINE = CN() +_CN.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0 # use 10.0 as fine local regress temperature, not 1.0 +_CN.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8 + +full_default_cfg = lower_config(_CN) diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/geometry.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/geometry.py @@ -0,0 +1,54 @@ +import torch + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """ Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + + Args: + kpts0 (torch.Tensor): [N, L, 2] - , + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + kpts0_long = kpts0.round().long() + + # Sample depth, get calculable_mask on depth != 0 + kpts0_depth = torch.stack( + [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + ) # (N, L) + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ + (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + w_kpts0_long = w_kpts0.long() + w_kpts0_long[~covisible_mask, :] = 0 + + w_kpts0_depth = torch.stack( + [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + ) # (N, L) + consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/opt_config.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/opt_config.py new file mode 100644 index 0000000000000000000000000000000000000000..61b7fa1e88a72db226dbbbf3b47c2b4f40e7aff7 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/opt_config.py @@ -0,0 +1,50 @@ +from yacs.config import CfgNode as CN + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +_CN = CN() +_CN.BACKBONE_TYPE = 'RepVGG' +_CN.ALIGN_CORNER = False +_CN.RESOLUTION = (8, 1) +_CN.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even +_CN.MP = False +_CN.REPLACE_NAN = True +_CN.HALF = False + +# 1. LoFTR-backbone (local feature CNN) config +_CN.BACKBONE = CN() +_CN.BACKBONE.BLOCK_DIMS = [64, 128, 256] # s1, s2, s3 + +# 2. 
LoFTR-coarse module config +_CN.COARSE = CN() +_CN.COARSE.D_MODEL = 256 +_CN.COARSE.D_FFN = 256 +_CN.COARSE.NHEAD = 8 +_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 +_CN.COARSE.AGG_SIZE0 = 4 +_CN.COARSE.AGG_SIZE1 = 4 +_CN.COARSE.NO_FLASH = False +_CN.COARSE.ROPE = True +_CN.COARSE.NPE = [832, 832, 832, 832] # [832, 832, long_side, long_side] Suggest setting based on the long side of the input image, especially when the long_side > 832 + +# 3. Coarse-Matching config +_CN.MATCH_COARSE = CN() +_CN.MATCH_COARSE.THR = 25 # recommend 0.2 for full model and 25 for optimized model +_CN.MATCH_COARSE.BORDER_RM = 2 +_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.MATCH_COARSE.SKIP_SOFTMAX = True # False for full model and True for optimized model +_CN.MATCH_COARSE.FP16MATMUL = True # False for full model and True for optimized model +_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory +_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock + +# 4. Fine-Matching config +_CN.MATCH_FINE = CN() +_CN.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0 # use 10.0 as fine local regress temperature, not 1.0 +_CN.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8 + +opt_default_cfg = lower_config(_CN) diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/position_encoding.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..6431d0e2ce468fad5b0f0f6838c2ae2e5c089b32 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/position_encoding.py @@ -0,0 +1,50 @@ +import math +import torch +from torch import nn + +class RoPEPositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256), npe=None, ropefp16=True): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + i_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1) # [H, 1] + j_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1) # [W, 1] + + assert npe is not None + train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3] # train_res_H, train_res_W, test_res_H, test_res_W + i_position, j_position = i_position * train_res_H / test_res_H, j_position * train_res_W / test_res_W + + div_term = torch.exp(torch.arange(0, d_model//4, 1).float() * (-math.log(10000.0) / (d_model//4))) + div_term = div_term[None, None, :] # [1, 1, C//4] + + sin = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32) + cos = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32) + sin[:, :, 0::2] = torch.sin(i_position * div_term).half() if ropefp16 else torch.sin(i_position * div_term) + sin[:, :, 1::2] = torch.sin(j_position * div_term).half() if ropefp16 else torch.sin(j_position * div_term) + cos[:, :, 0::2] = torch.cos(i_position * div_term).half() if ropefp16 else torch.cos(i_position * div_term) + cos[:, :, 1::2] = torch.cos(j_position * div_term).half() if ropefp16 else torch.cos(j_position * div_term) + + sin = sin.repeat_interleave(2, dim=-1) + cos = cos.repeat_interleave(2, dim=-1) + + self.register_buffer('sin', sin.unsqueeze(0), persistent=False) # [1, H, W, C//2] + self.register_buffer('cos', cos.unsqueeze(0), persistent=False) # [1, H, W, C//2] + + def forward(self, x, ratio=1): + """ + Args: + x: [N, H, W, C] + """ + return (x 
* self.cos[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin[:, :x.size(1), :x.size(2), :]) + + def rotate_half(self, x): + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/supervision.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/supervision.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ae8036ef5a8108ddb7eab21bdc5efb26d356d8 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/supervision.py @@ -0,0 +1,275 @@ +from math import log +from loguru import logger as loguru_logger + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from kornia.utils import create_meshgrid +from src.utils.plotting import make_matching_figures + +from .geometry import warp_kpts + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + +def static_vars(**kwargs): + def decorate(func): + for k in kwargs: + setattr(func, k, kwargs[k]) + return func + return decorate + +############## ↓ Coarse-Level supervision ↓ ############## + + +@torch.no_grad() +def mask_pts_at_padded_regions(grid_pt, mask): + """For megadepth dataset, zero-padding exists in images""" + mask = repeat(mask, 'n h w -> n (h w) c', c=2) + grid_pt[~mask.bool()] = 0 + return grid_pt + + +@torch.no_grad() +def spvs_coarse(data, config): + """ + Update: + data (dict): { + "conf_matrix_gt": [N, hw0, hw1], + 'spv_b_ids': [M] + 'spv_i_ids': [M] + 'spv_j_ids': [M] + 'spv_w_pt0_i': [N, hw0, 2], in original image resolution + 'spv_pt1_i': [N, hw1, 2], in original image resolution + } + + NOTE: + - for scannet dataset, there're 3 kinds of resolution {i, c, f} + - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} + """ + # 1. misc + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + scale = config['LOFTR']['RESOLUTION'][0] + scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale + scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + + # 2. warp grids + # create kpts in meshgrid and resize them to image resolution + grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_i = scale0 * grid_pt0_c + grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_i = scale1 * grid_pt1_c + + # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt + if 'mask0' in data: + grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) + grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) + + # warp kpts bi-directionally and resize them to coarse-level resolution + # (no depth consistency check, since it leads to worse results experimentally) + # (unhandled edge case: points with 0-depth will be warped to the left-up corner) + _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) + _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) + w_pt0_c = w_pt0_i / scale1 + w_pt1_c = w_pt1_i / scale0 + + # 3. check if mutual nearest neighbor + w_pt0_c_round = w_pt0_c[:, :, :].round() + # calculate the overlap area between warped patch and grid patch as the loss weight. 
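A quick numeric sketch of this overlap weight (illustration only; the offsets below are made up). The weight is the product over x and y of `1 - 2 * |w - round(w)|`, so a warped point that lands exactly on a grid center gets weight 1, and one that lands on the midpoint between two cells gets weight 0.

```python
import torch

# Sketch of the overlap weight (1 - 2*|w - round(w)|).prod(-1) used for w_pt0_c_error
w_pt0_c = torch.tensor([[2.00, 5.00],   # exactly on a grid center
                        [2.25, 5.10],   # slightly off-center
                        [2.50, 5.50]])  # on the midpoint between two cells
weight = (1.0 - 2.0 * (w_pt0_c - w_pt0_c.round()).abs()).prod(-1)
print(weight)  # approximately [1.0, 0.4, 0.0]
```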
+ # (larger overlap area between warped patches and grid patch with higher weight) + # (overlap area range from [0, 1] rather than [0.25, 1] as the penalty of warped kpts fall on midpoint of two grid kpts) + if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT: + w_pt0_c_error = (1.0 - 2*torch.abs(w_pt0_c - w_pt0_c_round)).prod(-1) + w_pt0_c_round = w_pt0_c_round[:, :, :].long() + nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 + + w_pt1_c_round = w_pt1_c[:, :, :].round().long() + nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0 + + # corner case: out of boundary + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 + nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 + + loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) + correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1) + correct_0to1[:, 0] = False # ignore the top-left corner + + # 4. construct a gt conf_matrix + conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) + b_ids, i_ids = torch.where(correct_0to1 != 0) + j_ids = nearest_index1[b_ids, i_ids] + + conf_matrix_gt[b_ids, i_ids, j_ids] = 1 + data.update({'conf_matrix_gt': conf_matrix_gt}) + + # use overlap area as loss weight + if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT: + conf_matrix_error_gt = w_pt0_c_error[b_ids, i_ids] # weight range: [0.0, 1.0] + data.update({'conf_matrix_error_gt': conf_matrix_error_gt}) + + + # 5. save coarse matches(gt) for training fine level + if len(b_ids) == 0: + loguru_logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}") + # this won't affect fine-level loss calculation + b_ids = torch.tensor([0], device=device) + i_ids = torch.tensor([0], device=device) + j_ids = torch.tensor([0], device=device) + + data.update({ + 'spv_b_ids': b_ids, + 'spv_i_ids': i_ids, + 'spv_j_ids': j_ids + }) + + # 6. save intermediate results (for fast fine-level computation) + data.update({ + 'spv_w_pt0_i': w_pt0_i, + 'spv_pt1_i': grid_pt1_i + }) + + +def compute_supervision_coarse(data, config): + assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_coarse(data, config) + else: + raise ValueError(f'Unknown data source: {data_source}') + + +############## ↓ Fine-Level supervision ↓ ############## + +@static_vars(counter = 0) +@torch.no_grad() +def spvs_fine(data, config, logger = None): + """ + Update: + data (dict):{ + "expec_f_gt": [M, 2], used as subpixel-level gt + "conf_matrix_f_gt": [M, WW, WW], M is the number of all coarse-level gt matches + "conf_matrix_f_error_gt": [Mp], Mp is the number of all pixel-level gt matches + "m_ids_f": [Mp] + "i_ids_f": [Mp] + "j_ids_f_di": [Mp] + "j_ids_f_dj": [Mp] + } + """ + # 1. misc + pt1_i = data['spv_pt1_i'] + W = config['LOFTR']['FINE_WINDOW_SIZE'] + WW = W*W + scale = config['LOFTR']['RESOLUTION'][1] + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + hf0, wf0, hf1, wf1 = data['hw0_f'][0], data['hw0_f'][1], data['hw1_f'][0], data['hw1_f'][1] # h, w of fine feature + assert not config.LOFTR.ALIGN_CORNER, 'only support training with align_corner=False for now.' + + # 2. 
get coarse prediction + b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + scalei0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scalei1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + + # 3. compute gt + m = b_ids.shape[0] + if m == 0: # special case: there is no coarse gt + conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device) + + data.update({'conf_matrix_f_gt': conf_matrix_f_gt}) + if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT: + conf_matrix_f_error_gt = torch.zeros(1, device=device) + data.update({'conf_matrix_f_error_gt': conf_matrix_f_error_gt}) + + data.update({'expec_f': torch.zeros(1, 2, device=device)}) + data.update({'expec_f_gt': torch.zeros(1, 2, device=device)}) + else: + grid_pt0_f = create_meshgrid(hf0, wf0, False, device) - W // 2 + 0.5 # [1, hf0, wf0, 2] # use fine coordinates + grid_pt0_f = rearrange(grid_pt0_f, 'n h w c -> n c h w') + # 1. unfold(crop) all local windows + if config.LOFTR.ALIGN_CORNER is False: # even windows + assert W==8 + grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W, padding=0) + grid_pt0_f_unfold = rearrange(grid_pt0_f_unfold, 'n (c ww) l -> n l ww c', ww=W**2) # [1, hc0*wc0, W*W, 2] + grid_pt0_f_unfold = repeat(grid_pt0_f_unfold[0], 'l ww c -> N l ww c', N=N) + + # 2. select only the predicted matches + grid_pt0_f_unfold = grid_pt0_f_unfold[data['b_ids'], data['i_ids']] # [m, ww, 2] + grid_pt0_f_unfold = scalei0[:,None,:] * grid_pt0_f_unfold # [m, ww, 2] + + # 3. warp grids and get covisible & depth_consistent mask + correct_0to1_f = torch.zeros(m, WW, device=device, dtype=torch.bool) + w_pt0_i = torch.zeros(m, WW, 2, device=device, dtype=torch.float32) + for b in range(N): + mask = b_ids == b # mask of each batch + match = int(mask.sum()) + correct_0to1_f_mask, w_pt0_i_mask = warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['depth0'][[b],...], + data['depth1'][[b],...], data['T_0to1'][[b],...], + data['K0'][[b],...], data['K1'][[b],...]) # [k, WW], [k, WW, 2] + correct_0to1_f[mask] = correct_0to1_f_mask.reshape(match, WW) + w_pt0_i[mask] = w_pt0_i_mask.reshape(match, WW, 2) + + # 4. calculate the gt index of pixel-level refinement + delta_w_pt0_i = w_pt0_i - pt1_i[b_ids, j_ids][:,None,:] # [m, WW, 2] + del b_ids, i_ids, j_ids + delta_w_pt0_f = delta_w_pt0_i / scalei1[:,None,:] + W // 2 - 0.5 + delta_w_pt0_f_round = delta_w_pt0_f[:, :, :].round() + if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT: + # calculate the overlap area between warped patch and grid patch as the loss weight. 
+ w_pt0_f_error = (1.0 - 2*torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1) # [0, 1] + delta_w_pt0_f_round = delta_w_pt0_f_round.long() + + nearest_index1 = delta_w_pt0_f_round[..., 0] + delta_w_pt0_f_round[..., 1] * W # [m, WW] + + # corner case: out of fine windows + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + ob_mask = out_bound_mask(delta_w_pt0_f_round, W, W) + nearest_index1[ob_mask] = 0 + correct_0to1_f[ob_mask] = 0 + + m_ids, i_ids = torch.where(correct_0to1_f != 0) + j_ids = nearest_index1[m_ids, i_ids] # i_ids, j_ids range from [0, WW-1] + j_ids_di, j_ids_dj = j_ids // W, j_ids % W # further get the (i, j) index in fine windows of image1 (right image); j_ids_di, j_ids_dj range from [0, W-1] + m_ids, i_ids, j_ids_di, j_ids_dj = m_ids.to(torch.long), i_ids.to(torch.long), j_ids_di.to(torch.long), j_ids_dj.to(torch.long) + + # expec_f_gt will be used as the gt of subpixel-level refinement + expec_f_gt = delta_w_pt0_f - delta_w_pt0_f_round + + if m_ids.numel() == 0: # special case: there is no pixel-level gt + loguru_logger.warning(f"No groundtruth fine match found for local regress: {data['pair_names']}") + # this won't affect fine-level loss calculation + data.update({'expec_f': torch.zeros(1, 2, device=device)}) + data.update({'expec_f_gt': torch.zeros(1, 2, device=device)}) + else: + expec_f_gt = expec_f_gt[m_ids, i_ids] + data.update({"expec_f_gt": expec_f_gt}) + data.update({"m_ids_f": m_ids, + "i_ids_f": i_ids, + "j_ids_f_di": j_ids_di, + "j_ids_f_dj": j_ids_dj + }) + + # 5. construct a pixel-level gt conf_matrix + conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device, dtype=torch.bool) + conf_matrix_f_gt[m_ids, i_ids, j_ids] = 1 + data.update({'conf_matrix_f_gt': conf_matrix_f_gt}) + if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT: + # calculate the overlap area between warped pixel and grid pixel as the loss weight. 
+ w_pt0_f_error = w_pt0_f_error[m_ids, i_ids] + data.update({'conf_matrix_f_error_gt': w_pt0_f_error}) + + if conf_matrix_f_gt.sum() == 0: + loguru_logger.info(f'no fine matches to supervise') + +def compute_supervision_fine(data, config, logger=None): + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_fine(data, config, logger) + else: + raise NotImplementedError \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/losses/loftr_loss.py b/imcui/third_party/EfficientLoFTR/src/losses/loftr_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..eea71e59c1b43111bfb0d24f704df1a90bb66a03 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/losses/loftr_loss.py @@ -0,0 +1,229 @@ +from loguru import logger + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + + +class LoFTRLoss(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config # config under the global namespace + + self.loss_config = config['loftr']['loss'] + self.match_type = 'dual_softmax' + self.sparse_spvs = self.config['loftr']['match_coarse']['sparse_spvs'] + self.fine_sparse_spvs = self.config['loftr']['match_fine']['sparse_spvs'] + + # coarse-level + self.correct_thr = self.loss_config['fine_correct_thr'] + self.c_pos_w = self.loss_config['pos_weight'] + self.c_neg_w = self.loss_config['neg_weight'] + # coarse_overlap_weight + self.overlap_weightc = self.config['loftr']['loss']['coarse_overlap_weight'] + self.overlap_weightf = self.config['loftr']['loss']['fine_overlap_weight'] + # subpixel-level + self.local_regressw = self.config['loftr']['fine_window_size'] + self.local_regress_temperature = self.config['loftr']['match_fine']['local_regress_temperature'] + + + def compute_coarse_loss(self, conf, conf_gt, weight=None, overlap_weight=None): + """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt. + Args: + conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1) + conf_gt (torch.Tensor): (N, HW0, HW1) + weight (torch.Tensor): (N, HW0, HW1) + """ + pos_mask, neg_mask = conf_gt == 1, conf_gt == 0 + del conf_gt + # logger.info(f'real sum of conf_matrix_c_gt: {pos_mask.sum().item()}') + c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w + # corner case: no gt coarse-level match at all + if not pos_mask.any(): # assign a wrong gt + pos_mask[0, 0, 0] = True + if weight is not None: + weight[0, 0, 0] = 0. + c_pos_w = 0. + if not neg_mask.any(): + neg_mask[0, 0, 0] = True + if weight is not None: + weight[0, 0, 0] = 0. + c_neg_w = 0. + + if self.loss_config['coarse_type'] == 'focal': + conf = torch.clamp(conf, 1e-6, 1-1e-6) + alpha = self.loss_config['focal_alpha'] + gamma = self.loss_config['focal_gamma'] + + if self.sparse_spvs: + pos_conf = conf[pos_mask] + loss_pos = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log() + # handle loss weights + if weight is not None: + # Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out, + # but only through manually setting corresponding regions in sim_matrix to '-inf'. 
+ loss_pos = loss_pos * weight[pos_mask] + if self.overlap_weightc: + loss_pos = loss_pos * overlap_weight # already been masked slice in supervision + + loss = c_pos_w * loss_pos.mean() + return loss + else: # dense supervision + loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log() + loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log() + logger.info("conf_pos_c: {loss_pos}, conf_neg_c: {loss_neg}".format(loss_pos=conf[pos_mask].mean(), loss_neg=conf[neg_mask].mean())) + if weight is not None: + loss_pos = loss_pos * weight[pos_mask] + loss_neg = loss_neg * weight[neg_mask] + if self.overlap_weightc: + loss_pos = loss_pos * overlap_weight # already been masked slice in supervision + + loss_pos_mean, loss_neg_mean = loss_pos.mean(), loss_neg.mean() + logger.info("conf_pos_c: {loss_pos}, conf_neg_c: {loss_neg}".format(loss_pos=conf[pos_mask].mean(), loss_neg=conf[neg_mask].mean())) + return c_pos_w * loss_pos_mean + c_neg_w * loss_neg_mean + # each negative element occupy a smaller propotion than positive elements. => higher negative loss weight needed + else: + raise ValueError('Unknown coarse loss: {type}'.format(type=self.loss_config['coarse_type'])) + + def compute_fine_loss(self, conf_matrix_f, conf_matrix_f_gt, overlap_weight=None): + """ + Args: + conf_matrix_f (torch.Tensor): [m, WW, WW] + conf_matrix_f_gt (torch.Tensor): [m, WW, WW] + """ + if conf_matrix_f_gt.shape[0] == 0: + if self.training: # this seldomly happen during training, since we pad prediction with gt + # sometimes there is not coarse-level gt at all. + logger.warning("assign a false supervision to avoid ddp deadlock") + pass + else: + return None + pos_mask, neg_mask = conf_matrix_f_gt == 1, conf_matrix_f_gt == 0 + del conf_matrix_f_gt + c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w + + if not pos_mask.any(): # assign a wrong gt + pos_mask[0, 0, 0] = True + c_pos_w = 0. + if not neg_mask.any(): + neg_mask[0, 0, 0] = True + c_neg_w = 0. 
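Before the focal term is applied below, a minimal sketch of how `-alpha * (1 - p)^gamma * log(p)` weights easy versus hard positive matches; the alpha and gamma values here are assumed, the real ones come from `loss_config['focal_alpha']` and `loss_config['focal_gamma']`.

```python
import torch

# Focal loss on positive matches: -alpha * (1 - p)**gamma * log(p)
alpha, gamma = 0.25, 2.0                        # assumed values, not necessarily the project defaults
p = torch.tensor([0.95, 0.50, 0.10])            # predicted confidences of three gt positives
loss_pos = -alpha * (1.0 - p).pow(gamma) * p.log()
print(loss_pos)  # roughly [3e-5, 0.043, 0.47]: hard positives dominate the loss
```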
+ + conf = torch.clamp(conf_matrix_f, 1e-6, 1-1e-6) + alpha = self.loss_config['focal_alpha'] + gamma = self.loss_config['focal_gamma'] + + if self.fine_sparse_spvs: + loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log() + if self.overlap_weightf: + loss_pos = loss_pos * overlap_weight # already been masked slice in supervision + return c_pos_w * loss_pos.mean() + else: + loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log() + loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log() + logger.info("conf_pos_f: {loss_pos}, conf_neg_f: {loss_neg}".format(loss_pos=conf[pos_mask].mean(), loss_neg=conf[neg_mask].mean())) + if self.overlap_weightf: + loss_pos = loss_pos * overlap_weight # already been masked slice in supervision + + return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() + + + def _compute_local_loss_l2(self, expec_f, expec_f_gt): + """ + Args: + expec_f (torch.Tensor): [M, 2] + expec_f_gt (torch.Tensor): [M, 2] + """ + correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr + if correct_mask.sum() == 0: + if self.training: # this seldomly happen when training, since we pad prediction with gt + logger.warning("assign a false supervision to avoid ddp deadlock") + correct_mask[0] = True + else: + return None + offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask]) ** 2).sum(-1) + return offset_l2.mean() + + @torch.no_grad() + def compute_c_weight(self, data): + """ compute element-wise weights for computing coarse-level loss. """ + if 'mask0' in data: + c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]) + else: + c_weight = None + return c_weight + + def forward(self, data): + """ + Update: + data (dict): update{ + 'loss': [1] the reduced loss across a batch, + 'loss_scalars' (dict): loss scalars for tensorboard_record + } + """ + loss_scalars = {} + # 0. compute element-wise loss weight + c_weight = self.compute_c_weight(data) + + # 1. coarse-level loss + if self.overlap_weightc: + loss_c = self.compute_coarse_loss( + data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn' \ + else data['conf_matrix'], + data['conf_matrix_gt'], + weight=c_weight, overlap_weight=data['conf_matrix_error_gt']) + + else: + loss_c = self.compute_coarse_loss( + data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn' \ + else data['conf_matrix'], + data['conf_matrix_gt'], + weight=c_weight) + + loss = loss_c * self.loss_config['coarse_weight'] + loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()}) + + # 2. pixel-level loss (first-stage refinement) + if self.overlap_weightf: + loss_f = self.compute_fine_loss(data['conf_matrix_f'], data['conf_matrix_f_gt'], data['conf_matrix_f_error_gt']) + else: + loss_f = self.compute_fine_loss(data['conf_matrix_f'], data['conf_matrix_f_gt']) + if loss_f is not None: + loss += loss_f * self.loss_config['fine_weight'] + loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()}) + else: + assert self.training is False + loss_scalars.update({'loss_f': torch.tensor(1.)}) # 1 is the upper bound + + # 3. 
subpixel-level loss (second-stage refinement) + # we calculate subpixel-level loss for all pixel-level gt + if 'expec_f' not in data: + sim_matrix_f, m_ids, i_ids, j_ids_di, j_ids_dj = data['sim_matrix_ff'], data['m_ids_f'], data['i_ids_f'], data['j_ids_f_di'], data['j_ids_f_dj'] + del data['sim_matrix_ff'], data['m_ids_f'], data['i_ids_f'], data['j_ids_f_di'], data['j_ids_f_dj'] + delta = create_meshgrid(3, 3, True, sim_matrix_f.device).to(torch.long) # [1, 3, 3, 2] + m_ids = m_ids[...,None,None].expand(-1, 3, 3) + i_ids = i_ids[...,None,None].expand(-1, 3, 3) + # Note that j_ids_di & j_ids_dj in (i, j) format while delta in (x, y) format + j_ids_di = j_ids_di[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1] + j_ids_dj = j_ids_dj[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0] + + sim_matrix_f = sim_matrix_f.reshape(-1, self.local_regressw*self.local_regressw, self.local_regressw+2, self.local_regressw+2) # [M, WW, W+2, W+2] + sim_matrix_f = sim_matrix_f[m_ids, i_ids, j_ids_di, j_ids_dj] + sim_matrix_f = sim_matrix_f.reshape(-1, 9) + + sim_matrix_f = F.softmax(sim_matrix_f / self.local_regress_temperature, dim=-1) + heatmap = sim_matrix_f.reshape(-1, 3, 3) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] + data.update({'expec_f': coords_normalized}) + loss_l = self._compute_local_loss_l2(data['expec_f'], data['expec_f_gt']) + + loss += loss_l * self.loss_config['local_weight'] + loss_scalars.update({"loss_l": loss_l.clone().detach().cpu()}) + + loss_scalars.update({'loss': loss.clone().detach().cpu()}) + data.update({"loss": loss, "loss_scalars": loss_scalars}) \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/optimizers/__init__.py b/imcui/third_party/EfficientLoFTR/src/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1db2285352586c250912bdd2c4ae5029620ab5f --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/optimizers/__init__.py @@ -0,0 +1,42 @@ +import torch +from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR + + +def build_optimizer(model, config): + name = config.TRAINER.OPTIMIZER + lr = config.TRAINER.TRUE_LR + + if name == "adam": + return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY) + elif name == "adamw": + return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY) + else: + raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") + + +def build_scheduler(config, optimizer): + """ + Returns: + scheduler (dict):{ + 'scheduler': lr_scheduler, + 'interval': 'step', # or 'epoch' + 'monitor': 'val_f1', (optional) + 'frequency': x, (optional) + } + """ + scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} + name = config.TRAINER.SCHEDULER + + if name == 'MultiStepLR': + scheduler.update( + {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) + elif name == 'CosineAnnealing': + scheduler.update( + {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) + elif name == 'ExponentialLR': + scheduler.update( + {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) + else: + raise NotImplementedError() + + return scheduler diff --git a/imcui/third_party/EfficientLoFTR/src/utils/augment.py b/imcui/third_party/EfficientLoFTR/src/utils/augment.py new file mode 100644 index 
0000000000000000000000000000000000000000..d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/augment.py @@ -0,0 +1,55 @@ +import albumentations as A + + +class DarkAug(object): + """ + Extreme dark augmentation aiming at Aachen Day-Night + """ + + def __init__(self) -> None: + self.augmentor = A.Compose([ + A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), + A.Blur(p=0.1, blur_limit=(3, 9)), + A.MotionBlur(p=0.2, blur_limit=(3, 25)), + A.RandomGamma(p=0.1, gamma_limit=(15, 65)), + A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) + ], p=0.75) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + + +class MobileAug(object): + """ + Random augmentations aiming at images of mobile/handhold devices. + """ + + def __init__(self): + self.augmentor = A.Compose([ + A.MotionBlur(p=0.25), + A.ColorJitter(p=0.5), + A.RandomRain(p=0.1), # random occlusion + A.RandomSunFlare(p=0.1), + A.JpegCompression(p=0.25), + A.ISONoise(p=0.25) + ], p=1.0) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + + +def build_augmentor(method=None, **kwargs): + if method is not None: + raise NotImplementedError('Using of augmentation functions are not supported yet!') + if method == 'dark': + return DarkAug() + elif method == 'mobile': + return MobileAug() + elif method is None: + return None + else: + raise ValueError(f'Invalid augmentation method: {method}') + + +if __name__ == '__main__': + augmentor = build_augmentor('FDA') diff --git a/imcui/third_party/EfficientLoFTR/src/utils/comm.py b/imcui/third_party/EfficientLoFTR/src/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/comm.py @@ -0,0 +1,265 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +[Copied from detectron2] +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import functools +import logging +import numpy as np +import pickle +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. 
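The gloo-backed `all_gather` defined below serializes arbitrary picklable objects, pads the byte tensors to a common length and exchanges them, which is the usual way per-rank results are collected during distributed validation. A hypothetical usage sketch (the wrapper function and the import path are assumptions, not code from this repository):

```python
# Hypothetical usage of the helpers in comm.py during DDP validation
from src.utils import comm  # assumed import path inside this repo

def gather_validation_metrics(per_rank_metrics: dict):
    # every rank must call all_gather, since it is a collective operation
    gathered = comm.all_gather(per_rank_metrics)      # list with world_size entries
    if comm.is_main_process():
        # only rank 0 merges and reports
        merged = {k: [m[k] for m in gathered] for k in per_rank_metrics}
        print(f"rank {comm.get_rank()} / {comm.get_world_size()} aggregated {len(merged)} keys")
        return merged
    return None
```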
+ """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" + local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. 
By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group=group) == 1: + return [data] + rank = dist.get_rank(group=group) + + tensor = _serialize_to_tensor(data, group) + size_list, tensor = _pad_to_largest_tensor(tensor, group) + + # receiving Tensor from all ranks + if rank == dst: + max_size = max(size_list) + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.gather(tensor, tensor_list, dst=dst, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + return data_list + else: + dist.gather(tensor, [], dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2 ** 31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + + Returns: + a dict with the same keys as input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/imcui/third_party/EfficientLoFTR/src/utils/dataloader.py b/imcui/third_party/EfficientLoFTR/src/utils/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..6da37b880a290c2bb3ebb028d0c8dab592acc5c1 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/dataloader.py @@ -0,0 +1,23 @@ +import numpy as np + + +# --- PL-DATAMODULE --- + +def get_local_split(items: list, world_size: int, rank: int, seed: int): + """ The local rank only loads a split of the dataset. 
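A small worked example of this splitting arithmetic (a sketch, not the function itself): with 10 items, a world size of 4 and any fixed seed, two extra items are drawn with replacement so that every rank ends up with exactly 3.

```python
import numpy as np

# Sketch of the padding arithmetic behind get_local_split
items, world_size, seed = list(range(10)), 4, 66      # seed value is arbitrary
rng = np.random.RandomState(seed)
permuted = rng.permutation(items)
n_pad = world_size - len(items) % world_size          # 2 extra items
padded = np.concatenate([permuted, rng.choice(items, n_pad, replace=True)])
n_per_rank = len(padded) // world_size                # 12 // 4 = 3
chunks = [padded[r * n_per_rank:(r + 1) * n_per_rank] for r in range(world_size)]
assert all(len(c) == 3 for c in chunks)
```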
""" + n_items = len(items) + items_permute = np.random.RandomState(seed).permutation(items) + if n_items % world_size == 0: + padded_items = items_permute + else: + padding = np.random.RandomState(seed).choice( + items, + world_size - (n_items % world_size), + replace=True) + padded_items = np.concatenate([items_permute, padding]) + assert len(padded_items) % world_size == 0, \ + f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' + n_per_rank = len(padded_items) // world_size + local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] + + return local_items diff --git a/imcui/third_party/EfficientLoFTR/src/utils/dataset.py b/imcui/third_party/EfficientLoFTR/src/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..37831292d08b5e9f13eeb0dee64ae8882f52a63f --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/dataset.py @@ -0,0 +1,186 @@ +import io +from loguru import logger + +import cv2 +import numpy as np +import h5py +import torch +from numpy.linalg import inv + + +try: + # for internel use only + from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT +except Exception: + MEGADEPTH_CLIENT = SCANNET_CLIENT = None + +# --- DATA IO --- + +def load_array_from_s3( + path, client, cv_type, + use_h5py=False, +): + byte_str = client.Get(path) + try: + if not use_h5py: + raw_array = np.fromstring(byte_str, np.uint8) + data = cv2.imdecode(raw_array, cv_type) + else: + f = io.BytesIO(byte_str) + data = np.array(h5py.File(f, 'r')['/depth']) + except Exception as ex: + print(f"==> Data loading failure: {path}") + raise ex + + assert data is not None + return data + + +def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): + cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ + else cv2.IMREAD_COLOR + if str(path).startswith('s3://'): + image = load_array_from_s3(str(path), client, cv_type) + else: + image = cv2.imread(str(path), cv_type) + + if augment_fn is not None: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = augment_fn(image) + image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + return image # (h, w) + + +def get_resized_wh(w, h, resize=None): + if resize is not None: # resize the longer edge + scale = resize / max(h, w) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + else: + w_new, h_new = w, h + return w_new, h_new + + +def get_divisible_wh(w, h, df=None): + if df is not None: + w_new, h_new = map(lambda x: int(x // df * df), [w, h]) + else: + w_new, h_new = w, h + return w_new, h_new + + +def pad_bottom_right(inp, pad_size, ret_mask=False): + assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + mask = None + if inp.ndim == 2: + padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + if ret_mask: + mask = np.zeros((pad_size, pad_size), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + elif inp.ndim == 3: + padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) + padded[:, :inp.shape[1], :inp.shape[2]] = inp + if ret_mask: + mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) + mask[:, :inp.shape[1], :inp.shape[2]] = True + else: + raise NotImplementedError() + return padded, mask + + +# --- MEGADEPTH --- + +def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): + """ + Args: + resize (int, optional): the longer edge of resized images. 
None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) + + # resize image + w, h = image.shape[1], image.shape[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + + if padding: # padding + pad_to = max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + else: + mask = None + + image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + if mask is not None: + mask = torch.from_numpy(mask) + + return image, mask, scale + + +def read_megadepth_depth(path, pad_to=None): + if str(path).startswith('s3://'): + depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) + else: + depth = np.array(h5py.File(path, 'r')['depth']) + if pad_to is not None: + depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +# --- ScanNet --- + +def read_scannet_gray(path, resize=(640, 480), augment_fn=None): + """ + Args: + resize (tuple): align image to depthmap, in (w, h). + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read and resize image + image = imread_gray(path, augment_fn) + image = cv2.resize(image, resize) + + # (h, w) -> (1, h, w) and normalized + image = torch.from_numpy(image).float()[None] / 255 + return image + + +def read_scannet_depth(path): + if str(path).startswith('s3://'): + depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) + else: + depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) + depth = depth / 1000 + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +def read_scannet_pose(path): + """ Read ScanNet's Camera2World pose and transform it to World2Camera. + + Returns: + pose_w2c (np.ndarray): (4, 4) + """ + cam2world = np.loadtxt(path, delimiter=' ') + world2cam = inv(cam2world) + return world2cam + + +def read_scannet_intrinsic(path): + """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 
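Stepping back to the MegaDepth loader above, the resizing arithmetic is easy to trace by hand; the image size below is made up for illustration.

```python
# Sketch of the resize / scale bookkeeping in read_megadepth_gray
w, h, resize, df = 1600, 1200, 832, 8

s = resize / max(h, w)                                  # 0.52: resize the longer edge to 832
w_new, h_new = round(w * s), round(h * s)               # 832, 624
w_new, h_new = (w_new // df) * df, (h_new // df) * df   # keep dims divisible by df (unchanged here)

scale = (w / w_new, h / h_new)                          # ~ (1.923, 1.923): the scale tensor the
                                                        # supervision code multiplies coarse grids by
pad_to = max(h_new, w_new)                              # 832: zero-pad to a square, mask marks valid rows
print(w_new, h_new, scale, pad_to)
```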
+ """ + intrinsic = np.loadtxt(path, delimiter=' ') + return intrinsic[:-1, :-1] diff --git a/imcui/third_party/EfficientLoFTR/src/utils/metrics.py b/imcui/third_party/EfficientLoFTR/src/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc37896f363567d1f91d630def7bb717569a4ed --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/metrics.py @@ -0,0 +1,264 @@ +import torch +import cv2 +import numpy as np +from collections import OrderedDict +from loguru import logger +from kornia.geometry.epipolar import numeric +from kornia.geometry.conversions import convert_points_to_homogeneous +import pprint + + +# --- METRICS --- + +def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): + # angle error between 2 vectors + t_gt = T_0to1[:3, 3] + n = np.linalg.norm(t) * np.linalg.norm(t_gt) + t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) + t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity + if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging + t_err = 0 + + # angle error between 2 rotation matrices + R_gt = T_0to1[:3, :3] + cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 + cos = np.clip(cos, -1., 1.) # handle numercial errors + R_err = np.rad2deg(np.abs(np.arccos(cos))) + + return t_err, R_err + + +def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): + """Squared symmetric epipolar distance. + This can be seen as a biased estimation of the reprojection error. + Args: + pts0 (torch.Tensor): [N, 2] + E (torch.Tensor): [3, 3] + """ + pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + pts0 = convert_points_to_homogeneous(pts0) + pts1 = convert_points_to_homogeneous(pts1) + + Ep0 = pts0 @ E.T # [N, 3] + p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] + Etp1 = pts1 @ E # [N, 3] + + d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N + return d + + +def compute_symmetrical_epipolar_errors(data): + """ + Update: + data (dict):{"epi_errs": [M]} + """ + Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) + E_mat = Tx @ data['T_0to1'][:, :3, :3] + + m_bids = data['m_bids'] + pts0 = data['mkpts0_f'] + pts1 = data['mkpts1_f'] + + epi_errs = [] + for bs in range(Tx.size(0)): + mask = m_bids == bs + epi_errs.append( + symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs])) + epi_errs = torch.cat(epi_errs, dim=0) + + data.update({'epi_errs': epi_errs}) + + +def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): + if len(kpts0) < 5: + return None + # normalize keypoints + kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + + # normalize ransac threshold + ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]]) + + # compute pose with cv2 + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC) + if E is None: + print("\nE is None while trying to recover pose.\n") + return None + + # recover pose from E + best_num_inliers = 0 + ret = None + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + ret = (R, t[:, 0], mask.ravel() > 0) + best_num_inliers = n + + return ret + + +def estimate_lo_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): + from .warppers import Camera, Pose + import poselib + 
camera0, camera1 = Camera.from_calibration_matrix(K0).float(), Camera.from_calibration_matrix(K1).float() + pts0, pts1 = kpts0, kpts1 + + M, info = poselib.estimate_relative_pose( + pts0, + pts1, + camera0.to_cameradict(), + camera1.to_cameradict(), + { + "max_epipolar_error": thresh, + }, + ) + success = M is not None and ( ((M.t != [0., 0., 0.]).all()) or ((M.q != [1., 0., 0., 0.]).all()) ) + if success: + M = Pose.from_Rt(torch.tensor(M.R), torch.tensor(M.t)) # .to(pts0) + # print(M) + else: + M = Pose.from_4x4mat(torch.eye(4).numpy()) # .to(pts0) + # print(M) + + estimation = { + "success": success, + "M_0to1": M, + "inliers": torch.tensor(info.pop("inliers")), # .to(pts0), + **info, + } + return estimation + + +def compute_pose_errors(data, config): + """ + Update: + data (dict):{ + "R_errs" List[float]: [N] + "t_errs" List[float]: [N] + "inliers" List[np.ndarray]: [N] + } + """ + pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5 + conf = config.TRAINER.RANSAC_CONF # 0.99999 + RANSAC = config.TRAINER.POSE_ESTIMATION_METHOD + data.update({'R_errs': [], 't_errs': [], 'inliers': []}) + + m_bids = data['m_bids'].cpu().numpy() + pts0 = data['mkpts0_f'].cpu().numpy() + pts1 = data['mkpts1_f'].cpu().numpy() + K0 = data['K0'].cpu().numpy() + K1 = data['K1'].cpu().numpy() + T_0to1 = data['T_0to1'].cpu().numpy() + + for bs in range(K0.shape[0]): + mask = m_bids == bs + if config.LOFTR.EVAL_TIMES >= 1: + bpts0, bpts1 = pts0[mask], pts1[mask] + R_list, T_list, inliers_list = [], [], [] + # for _ in range(config.LOFTR.EVAL_TIMES): + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(bpts0))) + if _ >= config.LOFTR.EVAL_TIMES: + continue + bpts0 = bpts0[shuffling] + bpts1 = bpts1[shuffling] + + if RANSAC == 'RANSAC': + ret = estimate_pose(bpts0, bpts1, K0[bs], K1[bs], pixel_thr, conf=conf) + if ret is None: + R_list.append(np.inf) + T_list.append(np.inf) + inliers_list.append(np.array([]).astype(bool)) + else: + R, t, inliers = ret + t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) + R_list.append(R_err) + T_list.append(t_err) + inliers_list.append(inliers) + + elif RANSAC == 'LO-RANSAC': + est = estimate_lo_pose(bpts0, bpts1, K0[bs], K1[bs], pixel_thr, conf=conf) + if not est["success"]: + R_list.append(90) + T_list.append(90) + inliers_list.append(np.array([]).astype(bool)) + else: + M = est["M_0to1"] + inl = est["inliers"].numpy() + t_error, r_error = relative_pose_error(T_0to1[bs], M.R, M.t, ignore_gt_t_thr=0.0) + R_list.append(r_error) + T_list.append(t_error) + inliers_list.append(inl) + else: + raise ValueError(f"Unknown RANSAC method: {RANSAC}") + + data['R_errs'].append(R_list) + data['t_errs'].append(T_list) + data['inliers'].append(inliers_list[0]) + + +# --- METRIC AGGREGATION --- + +def error_auc(errors, thresholds): + """ + Args: + errors (list): [N,] + thresholds (list) + """ + errors = [0] + sorted(list(errors)) + recall = list(np.linspace(0, 1, len(errors))) + + aucs = [] + thresholds = [5, 10, 20] + for thr in thresholds: + last_index = np.searchsorted(errors, thr) + y = recall[:last_index] + [recall[last_index-1]] + x = errors[:last_index] + [thr] + aucs.append(np.trapz(y, x) / thr) + + return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} + + +def epidist_prec(errors, thresholds, ret_dict=False): + precs = [] + for thr in thresholds: + prec_ = [] + for errs in errors: + correct_mask = errs < thr + prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) + precs.append(np.mean(prec_) if len(prec_) > 0 else 0) + if 
ret_dict: + return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} + else: + return precs + + +def aggregate_metrics(metrics, epi_err_thr=5e-4, config=None): + """ Aggregate metrics for the whole dataset: + (This method should be called once per dataset) + 1. AUC of the pose error (angular) at the threshold [5, 10, 20] + 2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) + """ + # filter duplicates + unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) + unq_ids = list(unq_ids.values()) + logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') + + # pose auc + angular_thresholds = [5, 10, 20] + + if config.LOFTR.EVAL_TIMES >= 1: + pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0).reshape(-1, config.LOFTR.EVAL_TIMES)[unq_ids].reshape(-1) + else: + pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] + aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) + + # matching precision + dist_thresholds = [epi_err_thr] + precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) + + u_num_mathces = np.array(metrics['num_matches'], dtype=object)[unq_ids] + num_matches = {f'num_matches': u_num_mathces.mean() } + return {**aucs, **precs, **num_matches} diff --git a/imcui/third_party/EfficientLoFTR/src/utils/misc.py b/imcui/third_party/EfficientLoFTR/src/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..604b9bc93e0fb92a9750fe72f3d692edc84207b5 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/misc.py @@ -0,0 +1,106 @@ +import os +import contextlib +import joblib +from typing import Union +from loguru import _Logger, logger +from itertools import chain + +import torch +from yacs.config import CfgNode as CN +from pytorch_lightning.utilities import rank_zero_only + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +def upper_config(dict_cfg): + if not isinstance(dict_cfg, dict): + return dict_cfg + return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} + + +def log_on(condition, message, level): + if condition: + assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] + logger.log(level, message) + + +def get_rank_zero_only_logger(logger: _Logger): + if rank_zero_only.rank == 0: + return logger + else: + for _level in logger._core.levels.keys(): + level = _level.lower() + setattr(logger, level, + lambda x: None) + logger._log = lambda x: None + return logger + + +def setup_gpus(gpus: Union[str, int]) -> int: + """ A temporary fix for pytorch-lighting 1.3.x """ + gpus = str(gpus) + gpu_ids = [] + + if ',' not in gpus: + n_gpus = int(gpus) + return n_gpus if n_gpus != -1 else torch.cuda.device_count() + else: + gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] + + # setup environment variables + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + if visible_devices is None: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') + else: + logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') + return len(gpu_ids) + + +def 
flattenList(x): + return list(chain(*x)) + + +@contextlib.contextmanager +def tqdm_joblib(tqdm_object): + """Context manager to patch joblib to report into tqdm progress bar given as argument + + Usage: + with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: + Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) + + When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) + ret_vals = Parallel(n_jobs=args.world_size)( + delayed(lambda x: _compute_cov_score(pid, *x))(param) + for param in tqdm(combinations(image_ids, 2), + desc=f'Computing cov_score of [{pid}]', + total=len(image_ids)*(len(image_ids)-1)/2)) + Src: https://stackoverflow.com/a/58936697 + """ + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() + +def detect_NaN(feat_0, feat_1): + logger.info(f'NaN detected in feature') + logger.info(f"#NaN in feat_0: {torch.isnan(feat_0).int().sum()}, #NaN in feat_1: {torch.isnan(feat_1).int().sum()}") + feat_0[torch.isnan(feat_0)] = 0 + feat_1[torch.isnan(feat_1)] = 0 diff --git a/imcui/third_party/EfficientLoFTR/src/utils/plotting.py b/imcui/third_party/EfficientLoFTR/src/utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4260c7487cfbc76dda94c589957601cea972d4 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/plotting.py @@ -0,0 +1,154 @@ +import bisect +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + +import torch + +def _compute_conf_thresh(data): + dataset_name = data['dataset_name'][0].lower() + if dataset_name == 'scannet': + thr = 5e-4 + elif dataset_name == 'megadepth': + thr = 1e-4 + else: + raise ValueError(f'Unknown dataset: {dataset_name}') + return thr + + +# --- VISUALIZATION --- # + +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. 
mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0, cmap='gray') + axes[1].imshow(img1, cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=1) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + + +def _make_evaluation_figure(data, b_id, alpha='dynamic'): + b_mask = data['m_bids'] == b_id + conf_thr = _compute_conf_thresh(data) + + img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() + kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() + + # for megadepth, we visualize matches on the resized image + if 'scale0' in data: + kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]] + kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]] + + epi_errs = data['epi_errs'][b_mask].cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) + recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + ] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text) + return figure + +def _make_confidence_figure(data, b_id): + # TODO: Implement confidence figure + raise NotImplementedError() + +def make_matching_figures(data, config, mode='evaluation'): + """ Make matching figures for a batch. + + Args: + data (Dict): a batch updated by PL_LoFTR. 
+ config (Dict): matcher config + Returns: + figures (Dict[str, List[plt.figure]] + """ + assert mode in ['evaluation', 'confidence', 'gt'] # 'confidence' + figures = {mode: []} + for b_id in range(data['image0'].size(0)): + if mode == 'evaluation': + fig = _make_evaluation_figure( + data, b_id, + alpha=config.TRAINER.PLOT_MATCHES_ALPHA) + elif mode == 'confidence': + fig = _make_confidence_figure(data, b_id) + else: + raise ValueError(f'Unknown plot mode: {mode}') + figures[mode].append(fig) + return figures + + +def dynamic_alpha(n_matches, + milestones=[0, 300, 1000, 2000], + alphas=[1.0, 0.8, 0.4, 0.2]): + if n_matches == 0: + return 1.0 + ranges = list(zip(alphas, alphas[1:] + [None])) + loc = bisect.bisect_right(milestones, n_matches) - 1 + _range = ranges[loc] + if _range[1] is None: + return _range[0] + return _range[1] + (milestones[loc + 1] - n_matches) / ( + milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) + + +def error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/utils/profiler.py b/imcui/third_party/EfficientLoFTR/src/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..6d21ed79fb506ef09c75483355402c48a195aaa9 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/profiler.py @@ -0,0 +1,39 @@ +import torch +from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler +from contextlib import contextmanager +from pytorch_lightning.utilities import rank_zero_only + + +class InferenceProfiler(SimpleProfiler): + """ + This profiler records duration of actions with cuda.synchronize() + Use this in test time. + """ + + def __init__(self): + super().__init__() + self.start = rank_zero_only(self.start) + self.stop = rank_zero_only(self.stop) + self.summary = rank_zero_only(self.summary) + + @contextmanager + def profile(self, action_name: str) -> None: + try: + torch.cuda.synchronize() + self.start(action_name) + yield action_name + finally: + torch.cuda.synchronize() + self.stop(action_name) + + +def build_profiler(name): + if name == 'inference': + return InferenceProfiler() + elif name == 'pytorch': + from pytorch_lightning.profiler import PyTorchProfiler + return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) + elif name is None: + return PassThroughProfiler() + else: + raise ValueError(f'Invalid profiler: {name}') diff --git a/imcui/third_party/EfficientLoFTR/src/utils/warppers.py b/imcui/third_party/EfficientLoFTR/src/utils/warppers.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f33b78f0d1645b3eefc8c9c6dbe14f24f7b2d6 --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/warppers.py @@ -0,0 +1,426 @@ +""" +Convenience classes for an SE3 pose and a pinhole Camera with lens distortion. +Based on PyTorch tensors: differentiable, batched, with GPU support. 
+Modified from: https://github.com/cvg/glue-factory/blob/scannet1500/gluefactory/geometry/wrappers.py +""" + +import functools +import inspect +import math +from typing import Dict, List, NamedTuple, Optional, Tuple, Union + +import numpy as np +import torch + +from .warppers_utils import ( + J_distort_points, + distort_points, + skew_symmetric, + so3exp_map, + to_homogeneous, +) + + +def autocast(func): + """Cast the inputs of a TensorWrapper method to PyTorch tensors + if they are numpy arrays. Use the device and dtype of the wrapper. + """ + + @functools.wraps(func) + def wrap(self, *args): + device = torch.device("cpu") + dtype = None + if isinstance(self, TensorWrapper): + if self._data is not None: + device = self.device + dtype = self.dtype + elif not inspect.isclass(self) or not issubclass(self, TensorWrapper): + raise ValueError(self) + + cast_args = [] + for arg in args: + if isinstance(arg, np.ndarray): + arg = torch.from_numpy(arg) + arg = arg.to(device=device, dtype=dtype) + cast_args.append(arg) + return func(self, *cast_args) + + return wrap + + +class TensorWrapper: + _data = None + + @autocast + def __init__(self, data: torch.Tensor): + self._data = data + + @property + def shape(self): + return self._data.shape[:-1] + + @property + def device(self): + return self._data.device + + @property + def dtype(self): + return self._data.dtype + + def __getitem__(self, index): + return self.__class__(self._data[index]) + + def __setitem__(self, index, item): + self._data[index] = item.data + + def to(self, *args, **kwargs): + return self.__class__(self._data.to(*args, **kwargs)) + + def cpu(self): + return self.__class__(self._data.cpu()) + + def cuda(self): + return self.__class__(self._data.cuda()) + + def pin_memory(self): + return self.__class__(self._data.pin_memory()) + + def float(self): + return self.__class__(self._data.float()) + + def double(self): + return self.__class__(self._data.double()) + + def detach(self): + return self.__class__(self._data.detach()) + + @classmethod + def stack(cls, objects: List, dim=0, *, out=None): + data = torch.stack([obj._data for obj in objects], dim=dim, out=out) + return cls(data) + + @classmethod + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.stack: + return self.stack(*args, **kwargs) + else: + return NotImplemented + + +class Pose(TensorWrapper): + def __init__(self, data: torch.Tensor): + assert data.shape[-1] == 12 + super().__init__(data) + + @classmethod + @autocast + def from_Rt(cls, R: torch.Tensor, t: torch.Tensor): + """Pose from a rotation matrix and translation vector. + Accepts numpy arrays or PyTorch tensors. + + Args: + R: rotation matrix with shape (..., 3, 3). + t: translation vector with shape (..., 3). + """ + assert R.shape[-2:] == (3, 3) + assert t.shape[-1] == 3 + assert R.shape[:-2] == t.shape[:-1] + data = torch.cat([R.flatten(start_dim=-2), t], -1) + return cls(data) + + @classmethod + @autocast + def from_aa(cls, aa: torch.Tensor, t: torch.Tensor): + """Pose from an axis-angle rotation vector and translation vector. + Accepts numpy arrays or PyTorch tensors. + + Args: + aa: axis-angle rotation vector with shape (..., 3). + t: translation vector with shape (..., 3). + """ + assert aa.shape[-1] == 3 + assert t.shape[-1] == 3 + assert aa.shape[:-1] == t.shape[:-1] + return cls.from_Rt(so3exp_map(aa), t) + + @classmethod + def from_4x4mat(cls, T: torch.Tensor): + """Pose from an SE(3) transformation matrix. 
+ Args: + T: transformation matrix with shape (..., 4, 4). + """ + assert T.shape[-2:] == (4, 4) + R, t = T[..., :3, :3], T[..., :3, 3] + return cls.from_Rt(R, t) + + @classmethod + def from_colmap(cls, image: NamedTuple): + """Pose from a COLMAP Image.""" + return cls.from_Rt(image.qvec2rotmat(), image.tvec) + + @property + def R(self) -> torch.Tensor: + """Underlying rotation matrix with shape (..., 3, 3).""" + rvec = self._data[..., :9] + return rvec.reshape(rvec.shape[:-1] + (3, 3)) + + @property + def t(self) -> torch.Tensor: + """Underlying translation vector with shape (..., 3).""" + return self._data[..., -3:] + + def inv(self) -> "Pose": + """Invert an SE(3) pose.""" + R = self.R.transpose(-1, -2) + t = -(R @ self.t.unsqueeze(-1)).squeeze(-1) + return self.__class__.from_Rt(R, t) + + def compose(self, other: "Pose") -> "Pose": + """Chain two SE(3) poses: T_B2C.compose(T_A2B) -> T_A2C.""" + R = self.R @ other.R + t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1) + return self.__class__.from_Rt(R, t) + + @autocast + def transform(self, p3d: torch.Tensor) -> torch.Tensor: + """Transform a set of 3D points. + Args: + p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3). + """ + assert p3d.shape[-1] == 3 + # assert p3d.shape[:-2] == self.shape # allow broadcasting + return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2) + + def __mul__(self, p3D: torch.Tensor) -> torch.Tensor: + """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B.""" + return self.transform(p3D) + + def __matmul__( + self, other: Union["Pose", torch.Tensor] + ) -> Union["Pose", torch.Tensor]: + """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B. + or chain two SE(3) poses: T_B2C @ T_A2B -> T_A2C.""" + if isinstance(other, self.__class__): + return self.compose(other) + else: + return self.transform(other) + + @autocast + def J_transform(self, p3d_out: torch.Tensor): + # [[1,0,0,0,-pz,py], + # [0,1,0,pz,0,-px], + # [0,0,1,-py,px,0]] + J_t = torch.diag_embed(torch.ones_like(p3d_out)) + J_rot = -skew_symmetric(p3d_out) + J = torch.cat([J_t, J_rot], dim=-1) + return J # N x 3 x 6 + + def numpy(self) -> Tuple[np.ndarray]: + return self.R.numpy(), self.t.numpy() + + def magnitude(self) -> Tuple[torch.Tensor]: + """Magnitude of the SE(3) transformation. + Returns: + dr: rotation anngle in degrees. + dt: translation distance in meters. + """ + trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1) + cos = torch.clamp((trace - 1) / 2, -1, 1) + dr = torch.acos(cos).abs() / math.pi * 180 + dt = torch.norm(self.t, dim=-1) + return dr, dt + + def __repr__(self): + return f"Pose: {self.shape} {self.dtype} {self.device}" + + +class Camera(TensorWrapper): + eps = 1e-4 + + def __init__(self, data: torch.Tensor): + assert data.shape[-1] in {6, 8, 10} + super().__init__(data) + + @classmethod + def from_colmap(cls, camera: Union[Dict, NamedTuple]): + """Camera from a COLMAP Camera tuple or dictionary. 
+ We use the corner-convetion from COLMAP (center of top left pixel is (0.5, 0.5)) + """ + if isinstance(camera, tuple): + camera = camera._asdict() + + model = camera["model"] + params = camera["params"] + + if model in ["OPENCV", "PINHOLE", "RADIAL"]: + (fx, fy, cx, cy), params = np.split(params, [4]) + elif model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL"]: + (f, cx, cy), params = np.split(params, [3]) + fx = fy = f + if model == "SIMPLE_RADIAL": + params = np.r_[params, 0.0] + else: + raise NotImplementedError(model) + + data = np.r_[camera["width"], camera["height"], fx, fy, cx, cy, params] + return cls(data) + + @classmethod + @autocast + def from_calibration_matrix(cls, K: torch.Tensor): + cx, cy = K[..., 0, 2], K[..., 1, 2] + fx, fy = K[..., 0, 0], K[..., 1, 1] + data = torch.stack([2 * cx, 2 * cy, fx, fy, cx, cy], -1) + return cls(data) + + @autocast + def calibration_matrix(self): + K = torch.zeros( + *self._data.shape[:-1], + 3, + 3, + device=self._data.device, + dtype=self._data.dtype, + ) + K[..., 0, 2] = self._data[..., 4] + K[..., 1, 2] = self._data[..., 5] + K[..., 0, 0] = self._data[..., 2] + K[..., 1, 1] = self._data[..., 3] + K[..., 2, 2] = 1.0 + return K + + @property + def size(self) -> torch.Tensor: + """Size (width height) of the images, with shape (..., 2).""" + return self._data[..., :2] + + @property + def f(self) -> torch.Tensor: + """Focal lengths (fx, fy) with shape (..., 2).""" + return self._data[..., 2:4] + + @property + def c(self) -> torch.Tensor: + """Principal points (cx, cy) with shape (..., 2).""" + return self._data[..., 4:6] + + @property + def dist(self) -> torch.Tensor: + """Distortion parameters, with shape (..., {0, 2, 4}).""" + return self._data[..., 6:] + + @autocast + def scale(self, scales: torch.Tensor): + """Update the camera parameters after resizing an image.""" + s = scales + data = torch.cat([self.size * s, self.f * s, self.c * s, self.dist], -1) + return self.__class__(data) + + def crop(self, left_top: Tuple[float], size: Tuple[int]): + """Update the camera parameters after cropping an image.""" + left_top = self._data.new_tensor(left_top) + size = self._data.new_tensor(size) + data = torch.cat([size, self.f, self.c - left_top, self.dist], -1) + return self.__class__(data) + + @autocast + def in_image(self, p2d: torch.Tensor): + """Check if 2D points are within the image boundaries.""" + assert p2d.shape[-1] == 2 + # assert p2d.shape[:-2] == self.shape # allow broadcasting + size = self.size.unsqueeze(-2) + valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), -1) + return valid + + @autocast + def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]: + """Project 3D points into the camera plane and check for visibility.""" + z = p3d[..., -1] + valid = z > self.eps + z = z.clamp(min=self.eps) + p2d = p3d[..., :-1] / z.unsqueeze(-1) + return p2d, valid + + def J_project(self, p3d: torch.Tensor): + x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2] + zero = torch.zeros_like(z) + z = z.clamp(min=self.eps) + J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1) + J = J.reshape(p3d.shape[:-1] + (2, 3)) + return J # N x 2 x 3 + + @autocast + def distort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]: + """Distort normalized 2D coordinates + and check for validity of the distortion model. 
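+ Returns a tuple of the distorted points and a boolean validity mask
+ (see distort_points in warppers_utils).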
+ """ + assert pts.shape[-1] == 2 + # assert pts.shape[:-2] == self.shape # allow broadcasting + return distort_points(pts, self.dist) + + def J_distort(self, pts: torch.Tensor): + return J_distort_points(pts, self.dist) # N x 2 x 2 + + @autocast + def denormalize(self, p2d: torch.Tensor) -> torch.Tensor: + """Convert normalized 2D coordinates into pixel coordinates.""" + return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2) + + @autocast + def normalize(self, p2d: torch.Tensor) -> torch.Tensor: + """Convert normalized 2D coordinates into pixel coordinates.""" + return (p2d - self.c.unsqueeze(-2)) / self.f.unsqueeze(-2) + + def J_denormalize(self): + return torch.diag_embed(self.f).unsqueeze(-3) # 1 x 2 x 2 + + @autocast + def cam2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]: + """Transform 3D points into 2D pixel coordinates.""" + p2d, visible = self.project(p3d) + p2d, mask = self.distort(p2d) + p2d = self.denormalize(p2d) + valid = visible & mask & self.in_image(p2d) + return p2d, valid + + def J_world2image(self, p3d: torch.Tensor): + p2d_dist, valid = self.project(p3d) + J = self.J_denormalize() @ self.J_distort(p2d_dist) @ self.J_project(p3d) + return J, valid + + @autocast + def image2cam(self, p2d: torch.Tensor) -> torch.Tensor: + """Convert 2D pixel corrdinates to 3D points with z=1""" + assert self._data.shape + p2d = self.normalize(p2d) + # iterative undistortion + return to_homogeneous(p2d) + + def to_cameradict(self, camera_model: Optional[str] = None) -> List[Dict]: + data = self._data.clone() + if data.dim() == 1: + data = data.unsqueeze(0) + assert data.dim() == 2 + b, d = data.shape + if camera_model is None: + camera_model = {6: "PINHOLE", 8: "RADIAL", 10: "OPENCV"}[d] + cameras = [] + for i in range(b): + if camera_model.startswith("SIMPLE_"): + params = [x.item() for x in data[i, 3 : min(d, 7)]] + else: + params = [x.item() for x in data[i, 2:]] + cameras.append( + { + "model": camera_model, + "width": int(data[i, 0].item()), + "height": int(data[i, 1].item()), + "params": params, + } + ) + return cameras if self._data.dim() == 2 else cameras[0] + + def __repr__(self): + return f"Camera {self.shape} {self.dtype} {self.device}" \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/src/utils/warppers_utils.py b/imcui/third_party/EfficientLoFTR/src/utils/warppers_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ad3ef5c05d74cd3bd46f5b3b0d8c6d331a17dfad --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/src/utils/warppers_utils.py @@ -0,0 +1,171 @@ +""" +Modified from: https://github.com/cvg/glue-factory/blob/scannet1500/gluefactory/geometry/utils.py +""" + +import numpy as np +import torch + + +def to_homogeneous(points): + """Convert N-dimensional points to homogeneous coordinates. + Args: + points: torch.Tensor or numpy.ndarray with size (..., N). + Returns: + A torch.Tensor or numpy.ndarray with size (..., N+1). + """ + if isinstance(points, torch.Tensor): + pad = points.new_ones(points.shape[:-1] + (1,)) + return torch.cat([points, pad], dim=-1) + elif isinstance(points, np.ndarray): + pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype) + return np.concatenate([points, pad], axis=-1) + else: + raise ValueError + + +def from_homogeneous(points, eps=0.0): + """Remove the homogeneous dimension of N-dimensional points. + Args: + points: torch.Tensor or numpy.ndarray with size (..., N+1). + eps: Epsilon value to prevent zero division. + Returns: + A torch.Tensor or numpy ndarray with size (..., N). 
+ """ + return points[..., :-1] / (points[..., -1:] + eps) + + +def batched_eye_like(x: torch.Tensor, n: int): + """Create a batch of identity matrices. + Args: + x: a reference torch.Tensor whose batch dimension will be copied. + n: the size of each identity matrix. + Returns: + A torch.Tensor of size (B, n, n), with same dtype and device as x. + """ + return torch.eye(n).to(x)[None].repeat(len(x), 1, 1) + + +def skew_symmetric(v): + """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).""" + z = torch.zeros_like(v[..., 0]) + M = torch.stack( + [ + z, + -v[..., 2], + v[..., 1], + v[..., 2], + z, + -v[..., 0], + -v[..., 1], + v[..., 0], + z, + ], + dim=-1, + ).reshape(v.shape[:-1] + (3, 3)) + return M + + +def transform_points(T, points): + return from_homogeneous(to_homogeneous(points) @ T.transpose(-1, -2)) + + +def is_inside(pts, shape): + return (pts > 0).all(-1) & (pts < shape[:, None]).all(-1) + + +def so3exp_map(w, eps: float = 1e-7): + """Compute rotation matrices from batched twists. + Args: + w: batched 3D axis-angle vectors of size (..., 3). + Returns: + A batch of rotation matrices of size (..., 3, 3). + """ + theta = w.norm(p=2, dim=-1, keepdim=True) + small = theta < eps + div = torch.where(small, torch.ones_like(theta), theta) + W = skew_symmetric(w / div) + theta = theta[..., None] # ... x 1 x 1 + res = W * torch.sin(theta) + (W @ W) * (1 - torch.cos(theta)) + res = torch.where(small[..., None], W, res) # first-order Taylor approx + return torch.eye(3).to(W) + res + + +@torch.jit.script +def distort_points(pts, dist): + """Distort normalized 2D coordinates + and check for validity of the distortion model. + """ + dist = dist.unsqueeze(-2) # add point dimension + ndist = dist.shape[-1] + undist = pts + valid = torch.ones(pts.shape[:-1], device=pts.device, dtype=torch.bool) + if ndist > 0: + k1, k2 = dist[..., :2].split(1, -1) + r2 = torch.sum(pts**2, -1, keepdim=True) + radial = k1 * r2 + k2 * r2**2 + undist = undist + pts * radial + + # The distortion model is supposedly only valid within the image + # boundaries. Because of the negative radial distortion, points that + # are far outside of the boundaries might actually be mapped back + # within the image. To account for this, we discard points that are + # beyond the inflection point of the distortion model, + # e.g. 
such that d(r + k_1 r^3 + k2 r^5)/dr = 0 + limited = ((k2 > 0) & ((9 * k1**2 - 20 * k2) > 0)) | ((k2 <= 0) & (k1 > 0)) + limit = torch.abs( + torch.where( + k2 > 0, + (torch.sqrt(9 * k1**2 - 20 * k2) - 3 * k1) / (10 * k2), + 1 / (3 * k1), + ) + ) + valid = valid & torch.squeeze(~limited | (r2 < limit), -1) + + if ndist > 2: + p12 = dist[..., 2:] + p21 = p12.flip(-1) + uv = torch.prod(pts, -1, keepdim=True) + undist = undist + 2 * p12 * uv + p21 * (r2 + 2 * pts**2) + # TODO: handle tangential boundaries + + return undist, valid + + +@torch.jit.script +def J_distort_points(pts, dist): + dist = dist.unsqueeze(-2) # add point dimension + ndist = dist.shape[-1] + + J_diag = torch.ones_like(pts) + J_cross = torch.zeros_like(pts) + if ndist > 0: + k1, k2 = dist[..., :2].split(1, -1) + r2 = torch.sum(pts**2, -1, keepdim=True) + uv = torch.prod(pts, -1, keepdim=True) + radial = k1 * r2 + k2 * r2**2 + d_radial = 2 * k1 + 4 * k2 * r2 + J_diag += radial + (pts**2) * d_radial + J_cross += uv * d_radial + + if ndist > 2: + p12 = dist[..., 2:] + p21 = p12.flip(-1) + J_diag += 2 * p12 * pts.flip(-1) + 6 * p21 * pts + J_cross += 2 * p12 * pts + 2 * p21 * pts.flip(-1) + + J = torch.diag_embed(J_diag) + torch.diag_embed(J_cross).flip(-1) + return J + + +def get_image_coords(img): + h, w = img.shape[-2:] + return ( + torch.stack( + torch.meshgrid( + torch.arange(h, dtype=torch.float32, device=img.device), + torch.arange(w, dtype=torch.float32, device=img.device), + indexing="ij", + )[::-1], + dim=0, + ).permute(1, 2, 0) + )[None] + 0.5 \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/test.py b/imcui/third_party/EfficientLoFTR/test.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a04324e9cd16c74ec3affbe17e9764d2a0002b --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/test.py @@ -0,0 +1,143 @@ +import pytorch_lightning as pl +import argparse +import pprint +from loguru import logger as loguru_logger + +from src.config.default import get_cfg_defaults +from src.utils.profiler import build_profiler + +from src.lightning.data import MultiSceneDataModule +from src.lightning.lightning_loftr import PL_LoFTR + +import torch + +def parse_args(): + # init a costum parser which will be added into pl.Trainer parser + # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'data_cfg_path', type=str, help='data config path') + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint') + parser.add_argument( + '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir") + parser.add_argument( + '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset') + parser.add_argument( + '--batch_size', type=int, default=1, help='batch_size per gpu') + parser.add_argument( + '--num_workers', type=int, default=2) + parser.add_argument( + '--thr', type=float, default=None, help='modify the coarse-level matching threshold.') + parser.add_argument( + '--pixel_thr', type=float, default=None, help='modify the RANSAC threshold.') + parser.add_argument( + '--ransac', type=str, default=None, help='modify the RANSAC method') + parser.add_argument( + '--scannetX', type=int, default=None, help='ScanNet resize X') + 
parser.add_argument( + '--scannetY', type=int, default=None, help='ScanNet resize Y') + parser.add_argument( + '--megasize', type=int, default=None, help='MegaDepth resize') + parser.add_argument( + '--npe', action='store_true', default=False, help='') + parser.add_argument( + '--fp32', action='store_true', default=False, help='') + parser.add_argument( + '--ransac_times', type=int, default=None, help='repeat ransac multiple times for more robust evaluation') + parser.add_argument( + '--rmbd', type=int, default=None, help='remove border matches') + parser.add_argument( + '--deter', action='store_true', default=False, help='use deterministic mode for testing') + parser.add_argument( + '--half', action='store_true', default=False, help='pure16') + parser.add_argument( + '--flash', action='store_true', default=False, help='flash') + + parser = pl.Trainer.add_argparse_args(parser) + return parser.parse_args() + +def inplace_relu(m): + classname = m.__class__.__name__ + if classname.find('ReLU') != -1: + m.inplace=True + +if __name__ == '__main__': + # parse arguments + args = parse_args() + pprint.pprint(vars(args)) + + # init default-cfg and merge it with the main- and data-cfg + config = get_cfg_defaults() + config.merge_from_file(args.main_cfg_path) + config.merge_from_file(args.data_cfg_path) + if args.deter: + torch.backends.cudnn.deterministic = True + pl.seed_everything(config.TRAINER.SEED) # reproducibility + + # tune when testing + if args.thr is not None: + config.LOFTR.MATCH_COARSE.THR = args.thr + + if args.scannetX is not None and args.scannetY is not None: + config.DATASET.SCAN_IMG_RESIZEX = args.scannetX + config.DATASET.SCAN_IMG_RESIZEY = args.scannetY + if args.megasize is not None: + config.DATASET.MGDPT_IMG_RESIZE = args.megasize + + if args.npe: + if config.LOFTR.COARSE.ROPE: + assert config.DATASET.NPE_NAME is not None + if config.DATASET.NPE_NAME is not None: + if config.DATASET.NPE_NAME == 'megadepth': + config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.MGDPT_IMG_RESIZE, config.DATASET.MGDPT_IMG_RESIZE] # [832, 832, 1152, 1152] + elif config.DATASET.NPE_NAME == 'scannet': + config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.SCAN_IMG_RESIZEX, config.DATASET.SCAN_IMG_RESIZEX] # [832, 832, 640, 640] + else: + config.LOFTR.COARSE.NPE = [832, 832, 832, 832] + + if args.ransac_times is not None: + config.LOFTR.EVAL_TIMES = args.ransac_times + + if args.rmbd is not None: + config.LOFTR.MATCH_COARSE.BORDER_RM = args.rmbd + + if args.pixel_thr is not None: + config.TRAINER.RANSAC_PIXEL_THR = args.pixel_thr + + if args.ransac is not None: + config.TRAINER.POSE_ESTIMATION_METHOD = args.ransac + if args.ransac == 'LO-RANSAC' and config.TRAINER.RANSAC_PIXEL_THR == 0.5: + config.TRAINER.RANSAC_PIXEL_THR = 2.0 + + if args.fp32: + config.LOFTR.MP = False + + if args.half: + config.LOFTR.HALF = True + config.DATASET.FP16 = True + else: + config.LOFTR.HALF = False + config.DATASET.FP16 = False + + if args.flash: + config.LOFTR.COARSE.NO_FLASH = False + + loguru_logger.info(f"Args and config initialized!") + + # lightning module + profiler = build_profiler(args.profiler_name) + model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir) + loguru_logger.info(f"LoFTR-lightning initialized!") + + # lightning data + data_module = MultiSceneDataModule(args, config) + loguru_logger.info(f"DataModule initialized!") + + # lightning trainer + trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False) + + 
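+ # Example invocation (illustrative only; the config and checkpoint paths are
+ # placeholders that depend on the local checkout):
+ #   python test.py <data_cfg_path> <main_cfg_path> --ckpt_path <checkpoint.ckpt> --gpus 1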
loguru_logger.info(f"Start testing!") + trainer.test(model, datamodule=data_module, verbose=False) \ No newline at end of file diff --git a/imcui/third_party/EfficientLoFTR/train.py b/imcui/third_party/EfficientLoFTR/train.py new file mode 100644 index 0000000000000000000000000000000000000000..6d74512464fbf5d35cb3ee48b4683b8cb870ce6e --- /dev/null +++ b/imcui/third_party/EfficientLoFTR/train.py @@ -0,0 +1,154 @@ +import math +import argparse +import pprint +from distutils.util import strtobool +from pathlib import Path +from loguru import logger as loguru_logger + +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin + +from src.config.default import get_cfg_defaults +from src.utils.misc import get_rank_zero_only_logger, setup_gpus +from src.utils.profiler import build_profiler +from src.lightning.data import MultiSceneDataModule +from src.lightning.lightning_loftr import PL_LoFTR +import torch + +loguru_logger = get_rank_zero_only_logger(loguru_logger) + +import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024" + +def parse_args(): + # init a costum parser which will be added into pl.Trainer parser + # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'data_cfg_path', type=str, help='data config path') + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--exp_name', type=str, default='default_exp_name') + parser.add_argument( + '--batch_size', type=int, default=4, help='batch_size per gpu') + parser.add_argument( + '--num_workers', type=int, default=4) + parser.add_argument( + '--pin_memory', type=lambda x: bool(strtobool(x)), + nargs='?', default=True, help='whether loading data to pinned memory or not') + parser.add_argument( + '--ckpt_path', type=str, default=None, + help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR') + parser.add_argument( + '--disable_ckpt', action='store_true', + help='disable checkpoint saving (useful for debugging).') + parser.add_argument( + '--profiler_name', type=str, default=None, + help='options: [inference, pytorch], or leave it unset') + parser.add_argument( + '--parallel_load_data', action='store_true', + help='load datasets in with multiple processes.') + parser.add_argument( + '--thr', type=float, default=0.1) + parser.add_argument( + '--train_coarse_percent', type=float, default=0.1, help='training tricks: save GPU memory') + parser.add_argument( + '--disable_mp', action='store_true', help='disable mixed-precision training') + parser.add_argument( + '--deter', action='store_true', help='use deterministic mode for training') + + parser = pl.Trainer.add_argparse_args(parser) + return parser.parse_args() + +def inplace_relu(m): + classname = m.__class__.__name__ + if classname.find('ReLU') != -1: + m.inplace=True + +def main(): + # parse arguments + args = parse_args() + rank_zero_only(pprint.pprint)(vars(args)) + + # init default-cfg and merge it with the main- and data-cfg + get_cfg_default = get_cfg_defaults + + config = get_cfg_default() + config.merge_from_file(args.main_cfg_path) + config.merge_from_file(args.data_cfg_path) + + if 
config.LOFTR.COARSE.NPE is None: + config.LOFTR.COARSE.NPE = [832, 832, 832, 832] # training at 832 resolution on MegaDepth datasets + + if args.deter: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + pl.seed_everything(config.TRAINER.SEED) # reproducibility + # TODO: Use different seeds for each dataloader workers + # This is needed for data augmentation + + # scale lr and warmup-step automatically + args.gpus = _n_gpus = setup_gpus(args.gpus) + config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes + config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size + _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS + config.TRAINER.SCALING = _scaling + config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling + config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling) + + if args.thr is not None: + config.LOFTR.MATCH_COARSE.THR = args.thr + if args.disable_mp: + config.LOFTR.MP = False + + # lightning module + profiler = build_profiler(args.profiler_name) + model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler) + loguru_logger.info(f"LoFTR LightningModule initialized!") + + # lightning data + data_module = MultiSceneDataModule(args, config) + loguru_logger.info(f"LoFTR DataModule initialized!") + + # TensorBoard Logger + logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False) + ckpt_dir = Path(logger.log_dir) / 'checkpoints' + + # Callbacks + # TODO: update ModelCheckpoint to monitor multiple metrics + ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max', + save_last=True, + dirpath=str(ckpt_dir), + filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}') + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks = [lr_monitor] + if not args.disable_ckpt: + callbacks.append(ckpt_callback) + + # Lightning Trainer + trainer = pl.Trainer.from_argparse_args( + args, + plugins=[DDPPlugin(find_unused_parameters=False, + num_nodes=args.num_nodes, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), NativeMixedPrecisionPlugin()], + gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING, + callbacks=callbacks, + logger=logger, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, + replace_sampler_ddp=False, # use custom sampler + reload_dataloaders_every_epoch=False, # avoid repeated samples! + weights_summary='full', + profiler=profiler) + loguru_logger.info(f"Trainer initialized!") + loguru_logger.info(f"Start training!") + + trainer.fit(model, datamodule=data_module) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/imcui/third_party/GlueStick/gluestick/__init__.py b/imcui/third_party/GlueStick/gluestick/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3051821ecfb2e18f4b9b4dfb50f35064106eb57 --- /dev/null +++ b/imcui/third_party/GlueStick/gluestick/__init__.py @@ -0,0 +1,53 @@ +import collections.abc as collections +from pathlib import Path + +import torch + +GLUESTICK_ROOT = Path(__file__).parent.parent + + +def get_class(mod_name, base_path, BaseClass): + """Get the class object which inherits from BaseClass and is defined in + the module named mod_name, child of base_path. 
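+ E.g. get_class('models.gluestick', 'gluestick', BaseModel) would return the
+ GlueStick class defined in models/gluestick.py.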
+ """ + import inspect + mod_path = '{}.{}'.format(base_path, mod_name) + mod = __import__(mod_path, fromlist=['']) + classes = inspect.getmembers(mod, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == mod_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseClass)] + assert len(classes) == 1, classes + return classes[0][1] + + +def get_model(name): + from .models.base_model import BaseModel + return get_class('models.' + name, __name__, BaseModel) + + +def numpy_image_to_torch(image): + """Normalize the image tensor and reorder the dimensions.""" + if image.ndim == 3: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + elif image.ndim == 2: + image = image[None] # add channel axis + else: + raise ValueError(f'Not an image: {image.shape}') + return torch.from_numpy(image / 255.).float() + + +def map_tensor(input_, func): + if isinstance(input_, (str, bytes)): + return input_ + elif isinstance(input_, collections.Mapping): + return {k: map_tensor(sample, func) for k, sample in input_.items()} + elif isinstance(input_, collections.Sequence): + return [map_tensor(sample, func) for sample in input_] + else: + return func(input_) + + +def batch_to_np(batch): + return map_tensor(batch, lambda t: t.detach().cpu().numpy()[0]) diff --git a/imcui/third_party/GlueStick/gluestick/drawing.py b/imcui/third_party/GlueStick/gluestick/drawing.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6d24b6bfedc93449142647410057d978d733ef --- /dev/null +++ b/imcui/third_party/GlueStick/gluestick/drawing.py @@ -0,0 +1,166 @@ +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns + + +def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5, + adaptive=True): + """Plot a set of images horizontally. + Args: + imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. + adaptive: whether the figure size should fit the image aspect ratios. + """ + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + if adaptive: + ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H + else: + ratios = [4 / 3] * n + figsize = [sum(ratios) * 4.5, 4.5] + fig, ax = plt.subplots( + 1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios}) + if n == 1: + ax = [ax] + for i in range(n): + ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) + ax[i].get_yaxis().set_ticks([]) + ax[i].get_xaxis().set_ticks([]) + ax[i].set_axis_off() + for spine in ax[i].spines.values(): # remove frame + spine.set_visible(False) + if titles: + ax[i].set_title(titles[i]) + fig.tight_layout(pad=pad) + return ax + + +def plot_keypoints(kpts, colors='lime', ps=4, alpha=1): + """Plot keypoints for existing images. + Args: + kpts: list of ndarrays of size (N, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float. + """ + if not isinstance(colors, list): + colors = [colors] * len(kpts) + axes = plt.gcf().axes + for a, k, c in zip(axes, kpts, colors): + a.scatter(k[:, 0], k[:, 1], c=c, s=ps, alpha=alpha, linewidths=0) + + +def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.): + """Plot matches for a pair of existing images. + Args: + kpts0, kpts1: corresponding keypoints of size (N, 2). + color: color of each match, string or RGB tuple. 
Random if not given. + lw: width of the lines. + ps: size of the end points (no endpoint if ps=0) + indices: indices of the images to draw the matches on. + a: alpha opacity of the match lines. + """ + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + ax0, ax1 = ax[indices[0]], ax[indices[1]] + fig.canvas.draw() + + assert len(kpts0) == len(kpts1) + if color is None: + color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() + elif len(color) > 0 and not isinstance(color[0], (tuple, list)): + color = [color] * len(kpts0) + + if lw > 0: + # transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) + fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) + fig.lines += [matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), + zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, + alpha=a) + for i in range(len(kpts0))] + + # freeze the axes to prevent the transform to change + ax0.autoscale(enable=False) + ax1.autoscale(enable=False) + + if ps > 0: + ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) + ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) + + +def plot_lines(lines, line_colors='orange', point_colors='cyan', + ps=4, lw=2, alpha=1., indices=(0, 1)): + """ Plot lines and endpoints for existing images. + Args: + lines: list of ndarrays of size (N, 2, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float pixels. + lw: line width as float pixels. + alpha: transparency of the points and lines. + indices: indices of the images to draw the matches on. + """ + if not isinstance(line_colors, list): + line_colors = [line_colors] * len(lines) + if not isinstance(point_colors, list): + point_colors = [point_colors] * len(lines) + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines and junctions + for a, l, lc, pc in zip(axes, lines, line_colors, point_colors): + for i in range(len(l)): + line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]), + (l[i, 0, 1], l[i, 1, 1]), + zorder=1, c=lc, linewidth=lw, + alpha=alpha) + a.add_line(line) + pts = l.reshape(-1, 2) + a.scatter(pts[:, 0], pts[:, 1], + c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha) + + +def plot_color_line_matches(lines, correct_matches=None, + lw=2, indices=(0, 1)): + """Plot line matches for existing images with multiple colors. + Args: + lines: list of ndarrays of size (N, 2, 2). + correct_matches: bool array of size (N,) indicating correct matches. + lw: line width as float pixels. + indices: indices of the images to draw the matches on. 
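+ Assumes plot_images() was called first, so the current figure already has
+ one axis per image.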
+ """ + n_lines = len(lines[0]) + colors = sns.color_palette('husl', n_colors=n_lines) + np.random.shuffle(colors) + alphas = np.ones(n_lines) + # If correct_matches is not None, display wrong matches with a low alpha + if correct_matches is not None: + alphas[~np.array(correct_matches)] = 0.2 + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines + for a, l in zip(axes, lines): + # Transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) + fig.lines += [matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, transform=fig.transFigure, c=colors[i], + alpha=alphas[i], linewidth=lw) for i in range(n_lines)] diff --git a/imcui/third_party/GlueStick/gluestick/geometry.py b/imcui/third_party/GlueStick/gluestick/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..97853c4807d319eb9ea0377db7385e9a72fb400b --- /dev/null +++ b/imcui/third_party/GlueStick/gluestick/geometry.py @@ -0,0 +1,175 @@ +from typing import Tuple + +import numpy as np +import torch + + +def to_homogeneous(points): + """Convert N-dimensional points to homogeneous coordinates. + Args: + points: torch.Tensor or numpy.ndarray with size (..., N). + Returns: + A torch.Tensor or numpy.ndarray with size (..., N+1). + """ + if isinstance(points, torch.Tensor): + pad = points.new_ones(points.shape[:-1] + (1,)) + return torch.cat([points, pad], dim=-1) + elif isinstance(points, np.ndarray): + pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype) + return np.concatenate([points, pad], axis=-1) + else: + raise ValueError + + +def from_homogeneous(points, eps=0.): + """Remove the homogeneous dimension of N-dimensional points. + Args: + points: torch.Tensor or numpy.ndarray with size (..., N+1). + Returns: + A torch.Tensor or numpy ndarray with size (..., N). + """ + return points[..., :-1] / (points[..., -1:] + eps) + + +def skew_symmetric(v): + """Create a skew-symmetric matrix from a (batched) vector of size (..., 3). + """ + z = torch.zeros_like(v[..., 0]) + M = torch.stack([ + z, -v[..., 2], v[..., 1], + v[..., 2], z, -v[..., 0], + -v[..., 1], v[..., 0], z, + ], dim=-1).reshape(v.shape[:-1] + (3, 3)) + return M + + +def T_to_E(T): + """Convert batched poses (..., 4, 4) to batched essential matrices.""" + return skew_symmetric(T[..., :3, 3]) @ T[..., :3, :3] + + +def warp_points_torch(points, H, inverse=True): + """ + Warp a list of points with the INVERSE of the given homography. + The inverse is used to be coherent with tf.contrib.image.transform + Arguments: + points: batched list of N points, shape (B, N, 2). + homography: batched or not (shapes (B, 8) and (8,) respectively). + Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warped points. 
+ """ + # H = np.expand_dims(homography, axis=0) if len(homography.shape) == 1 else homography + + # Get the points to the homogeneous format + points = to_homogeneous(points) + + # Apply the homography + out_shape = tuple(list(H.shape[:-1]) + [3, 3]) + H_mat = torch.cat([H, torch.ones_like(H[..., :1])], axis=-1).reshape(out_shape) + if inverse: + H_mat = torch.inverse(H_mat) + warped_points = torch.einsum('...nj,...ji->...ni', points, H_mat.transpose(-2, -1)) + + warped_points = from_homogeneous(warped_points, eps=1e-5) + + return warped_points + + +def seg_equation(segs): + # calculate list of start, end and midpoints points from both lists + start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous(segs[..., 1, :]) + # Compute the line equations as ax + by + c = 0 , where x^2 + y^2 = 1 + lines = torch.cross(start_points, end_points, dim=-1) + lines_norm = (torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None]) + assert torch.all(lines_norm > 0), 'Error: trying to compute the equation of a line with a single point' + lines = lines / lines_norm + return lines + + +def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]): + h, w = img_shape + return (pts >= 0).all(dim=-1) & (pts[..., 0] < w) & (pts[..., 1] < h) & (~torch.isinf(pts).any(dim=-1)) + + +def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor: + """ + Shrink an array of segments to fit inside the image. + :param segs: The tensor of segments with shape (N, 2, 2) + :param img_shape: The image shape in format (H, W) + """ + EPS = 1e-4 + device = segs.device + w, h = img_shape[1], img_shape[0] + # Project the segments to the reference image + segs = segs.clone() + eqs = seg_equation(segs) + x0, y0 = torch.tensor([1., 0, 0.], device=device), torch.tensor([0., 1, 0], device=device) + x0 = x0.repeat(eqs.shape[:-1] + (1,)) + y0 = y0.repeat(eqs.shape[:-1] + (1,)) + pt_x0s = torch.cross(eqs, x0, dim=-1) + pt_x0s = pt_x0s[..., :-1] / pt_x0s[..., None, -1] + pt_x0s_valid = is_inside_img(pt_x0s, img_shape) + pt_y0s = torch.cross(eqs, y0, dim=-1) + pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1] + pt_y0s_valid = is_inside_img(pt_y0s, img_shape) + + xW, yH = torch.tensor([1., 0, EPS - w], device=device), torch.tensor([0., 1, EPS - h], device=device) + xW = xW.repeat(eqs.shape[:-1] + (1,)) + yH = yH.repeat(eqs.shape[:-1] + (1,)) + pt_xWs = torch.cross(eqs, xW, dim=-1) + pt_xWs = pt_xWs[..., :-1] / pt_xWs[..., None, -1] + pt_xWs_valid = is_inside_img(pt_xWs, img_shape) + pt_yHs = torch.cross(eqs, yH, dim=-1) + pt_yHs = pt_yHs[..., :-1] / pt_yHs[..., None, -1] + pt_yHs_valid = is_inside_img(pt_yHs, img_shape) + + # If the X coordinate of the first endpoint is out + mask = (segs[..., 0, 0] < 0) & pt_x0s_valid + segs[mask, 0, :] = pt_x0s[mask] + mask = (segs[..., 0, 0] > (w - 1)) & pt_xWs_valid + segs[mask, 0, :] = pt_xWs[mask] + # If the X coordinate of the second endpoint is out + mask = (segs[..., 1, 0] < 0) & pt_x0s_valid + segs[mask, 1, :] = pt_x0s[mask] + mask = (segs[:, 1, 0] > (w - 1)) & pt_xWs_valid + segs[mask, 1, :] = pt_xWs[mask] + # If the Y coordinate of the first endpoint is out + mask = (segs[..., 0, 1] < 0) & pt_y0s_valid + segs[mask, 0, :] = pt_y0s[mask] + mask = (segs[..., 0, 1] > (h - 1)) & pt_yHs_valid + segs[mask, 0, :] = pt_yHs[mask] + # If the Y coordinate of the second endpoint is out + mask = (segs[..., 1, 1] < 0) & pt_y0s_valid + segs[mask, 1, :] = pt_y0s[mask] + mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid + segs[mask, 1, :] = pt_yHs[mask] + + 
assert torch.all(segs >= 0) and torch.all(segs[..., 0] < w) and torch.all(segs[..., 1] < h) + return segs + + +def warp_lines_torch(lines, H, inverse=True, dst_shape: Tuple[int, int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param lines: A tensor of shape (B, N, 2, 2) where B is the batch size, N the number of lines. + :param H: The homography used to convert the lines. batched or not (shapes (B, 8) and (8,) respectively). + :param inverse: Whether to apply H or the inverse of H + :param dst_shape:If provided, lines are trimmed to be inside the image + """ + device = lines.device + batch_size, n = lines.shape[:2] + lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(lines.shape) + + if dst_shape is None: + return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device) + + out_img = torch.any((lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1) + valid = ~out_img.all(-1) + any_out_of_img = out_img.any(-1) + lines_to_trim = valid & any_out_of_img + + for b in range(batch_size): + lines_to_trim_mask_b = lines_to_trim[b] + lines_to_trim_b = lines[b][lines_to_trim_mask_b] + corrected_lines = shrink_segs_to_img(lines_to_trim_b, dst_shape) + lines[b][lines_to_trim_mask_b] = corrected_lines + + return lines, valid diff --git a/imcui/third_party/GlueStick/gluestick/models/__init__.py b/imcui/third_party/GlueStick/gluestick/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/GlueStick/gluestick/models/base_model.py b/imcui/third_party/GlueStick/gluestick/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..30ca991655a28ca88074b42312c33b360f655fab --- /dev/null +++ b/imcui/third_party/GlueStick/gluestick/models/base_model.py @@ -0,0 +1,126 @@ +""" +Base class for trainable models. +""" + +from abc import ABCMeta, abstractmethod +import omegaconf +from omegaconf import OmegaConf +from torch import nn +from copy import copy + + +class MetaModel(ABCMeta): + def __prepare__(name, bases, **kwds): + total_conf = OmegaConf.create() + for base in bases: + for key in ('base_default_conf', 'default_conf'): + update = getattr(base, key, {}) + if isinstance(update, dict): + update = OmegaConf.create(update) + total_conf = OmegaConf.merge(total_conf, update) + return dict(base_default_conf=total_conf) + + +class BaseModel(nn.Module, metaclass=MetaModel): + """ + What the child model is expect to declare: + default_conf: dictionary of the default configuration of the model. + It recursively updates the default_conf of all parent classes, and + it is updated by the user-provided configuration passed to __init__. + Configurations can be nested. + + required_data_keys: list of expected keys in the input data dictionary. + + strict_conf (optional): boolean. If false, BaseModel does not raise + an error when the user provides an unknown configuration entry. + + _init(self, conf): initialization method, where conf is the final + configuration object (also accessible with `self.conf`). Accessing + unknown configuration entries will raise an error. + + _forward(self, data): method that returns a dictionary of batched + prediction tensors based on a dictionary of batched input data tensors. + + loss(self, pred, data): method that returns a dictionary of losses, + computed from model predictions and input data. Each loss is a batch + of scalars, i.e. a torch.Tensor of shape (B,). 
+ The total loss to be optimized has the key `'total'`. + + metrics(self, pred, data): method that returns a dictionary of metrics, + each as a batch of scalars. + """ + default_conf = { + 'name': None, + 'trainable': True, # if false: do not optimize this model parameters + 'freeze_batch_normalization': False, # use test-time statistics + } + required_data_keys = [] + strict_conf = True + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + default_conf = OmegaConf.merge( + self.base_default_conf, OmegaConf.create(self.default_conf)) + if self.strict_conf: + OmegaConf.set_struct(default_conf, True) + + # fixme: backward compatibility + if 'pad' in conf and 'pad' not in default_conf: # backward compat. + with omegaconf.read_write(conf): + with omegaconf.open_dict(conf): + conf['interpolation'] = {'pad': conf.pop('pad')} + + if isinstance(conf, dict): + conf = OmegaConf.create(conf) + self.conf = conf = OmegaConf.merge(default_conf, conf) + OmegaConf.set_readonly(conf, True) + OmegaConf.set_struct(conf, True) + self.required_data_keys = copy(self.required_data_keys) + self._init(conf) + + if not conf.trainable: + for p in self.parameters(): + p.requires_grad = False + + def train(self, mode=True): + super().train(mode) + + def freeze_bn(module): + if isinstance(module, nn.modules.batchnorm._BatchNorm): + module.eval() + if self.conf.freeze_batch_normalization: + self.apply(freeze_bn) + + return self + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + def recursive_key_check(expected, given): + for key in expected: + assert key in given, f'Missing key {key} in data' + if isinstance(expected, dict): + recursive_key_check(expected[key], given[key]) + + recursive_key_check(self.required_data_keys, data) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def loss(self, pred, data): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def metrics(self, pred, data): + """To be implemented by the child class.""" + raise NotImplementedError diff --git a/imcui/third_party/GlueStick/gluestick/models/gluestick.py b/imcui/third_party/GlueStick/gluestick/models/gluestick.py new file mode 100644 index 0000000000000000000000000000000000000000..98550ff9d8918bcf49a13ae606d1d631448b8f96 --- /dev/null +++ b/imcui/third_party/GlueStick/gluestick/models/gluestick.py @@ -0,0 +1,563 @@ +import os.path +import warnings +from copy import deepcopy + +warnings.filterwarnings("ignore", category=UserWarning) +import torch +import torch.utils.checkpoint +from torch import nn +from .base_model import BaseModel + +ETH_EPS = 1e-8 + + +class GlueStick(BaseModel): + default_conf = { + 'input_dim': 256, + 'descriptor_dim': 256, + 'bottleneck_dim': None, + 'weights': None, + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, + 'num_line_iterations': 1, + 'line_attention': False, + 'filter_threshold': 0.2, + 'checkpointed': False, + 'skip_init': False, + 'inter_supervision': None, + 'loss': { + 'nll_weight': 1., + 'nll_balancing': 0.5, + 'reward_weight': 0., + 'bottleneck_l2_weight': 0., + 'dense_nll_weight': 0., + 'inter_supervision': [0.3, 0.6], + }, + } + required_data_keys = [ + 'keypoints0', 
'keypoints1', + 'descriptors0', 'descriptors1', + 'keypoint_scores0', 'keypoint_scores1'] + + DEFAULT_LOSS_CONF = {'nll_weight': 1., 'nll_balancing': 0.5, 'reward_weight': 0., 'bottleneck_l2_weight': 0.} + + def _init(self, conf): + if conf.bottleneck_dim is not None: + self.bottleneck_down = nn.Conv1d( + conf.input_dim, conf.bottleneck_dim, kernel_size=1) + self.bottleneck_up = nn.Conv1d( + conf.bottleneck_dim, conf.input_dim, kernel_size=1) + nn.init.constant_(self.bottleneck_down.bias, 0.0) + nn.init.constant_(self.bottleneck_up.bias, 0.0) + + if conf.input_dim != conf.descriptor_dim: + self.input_proj = nn.Conv1d( + conf.input_dim, conf.descriptor_dim, kernel_size=1) + nn.init.constant_(self.input_proj.bias, 0.0) + + self.kenc = KeypointEncoder(conf.descriptor_dim, + conf.keypoint_encoder) + self.lenc = EndPtEncoder(conf.descriptor_dim, conf.keypoint_encoder) + self.gnn = AttentionalGNN(conf.descriptor_dim, conf.GNN_layers, + checkpointed=conf.checkpointed, + inter_supervision=conf.inter_supervision, + num_line_iterations=conf.num_line_iterations, + line_attention=conf.line_attention) + self.final_proj = nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, + kernel_size=1) + nn.init.constant_(self.final_proj.bias, 0.0) + nn.init.orthogonal_(self.final_proj.weight, gain=1) + self.final_line_proj = nn.Conv1d( + conf.descriptor_dim, conf.descriptor_dim, kernel_size=1) + nn.init.constant_(self.final_line_proj.bias, 0.0) + nn.init.orthogonal_(self.final_line_proj.weight, gain=1) + if conf.inter_supervision is not None: + self.inter_line_proj = nn.ModuleList( + [nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1) + for _ in conf.inter_supervision]) + self.layer2idx = {} + for i, l in enumerate(conf.inter_supervision): + nn.init.constant_(self.inter_line_proj[i].bias, 0.0) + nn.init.orthogonal_(self.inter_line_proj[i].weight, gain=1) + self.layer2idx[l] = i + + bin_score = torch.nn.Parameter(torch.tensor(1.)) + self.register_parameter('bin_score', bin_score) + line_bin_score = torch.nn.Parameter(torch.tensor(1.)) + self.register_parameter('line_bin_score', line_bin_score) + + if conf.weights: + assert isinstance(conf.weights, str) + if os.path.exists(conf.weights): + state_dict = torch.load(conf.weights, map_location='cpu') + else: + weights_url = "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar" + state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu') + if 'model' in state_dict: + state_dict = {k.replace('matcher.', ''): v for k, v in state_dict['model'].items() if 'matcher.' 
in k} + state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} + self.load_state_dict(state_dict) + + def _forward(self, data): + device = data['keypoints0'].device + b_size = len(data['keypoints0']) + image_size0 = (data['image_size0'] if 'image_size0' in data + else data['image0'].shape) + image_size1 = (data['image_size1'] if 'image_size1' in data + else data['image1'].shape) + + pred = {} + desc0, desc1 = data['descriptors0'], data['descriptors1'] + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + + n_kpts0, n_kpts1 = kpts0.shape[1], kpts1.shape[1] + n_lines0, n_lines1 = data['lines0'].shape[1], data['lines1'].shape[1] + if n_kpts0 == 0 or n_kpts1 == 0: + # No detected keypoints nor lines + pred['log_assignment'] = torch.zeros( + b_size, n_kpts0, n_kpts1, dtype=torch.float, device=device) + pred['matches0'] = torch.full( + (b_size, n_kpts0), -1, device=device, dtype=torch.int64) + pred['matches1'] = torch.full( + (b_size, n_kpts1), -1, device=device, dtype=torch.int64) + pred['match_scores0'] = torch.zeros( + (b_size, n_kpts0), device=device, dtype=torch.float32) + pred['match_scores1'] = torch.zeros( + (b_size, n_kpts1), device=device, dtype=torch.float32) + pred['line_log_assignment'] = torch.zeros(b_size, n_lines0, n_lines1, + dtype=torch.float, device=device) + pred['line_matches0'] = torch.full((b_size, n_lines0), -1, + device=device, dtype=torch.int64) + pred['line_matches1'] = torch.full((b_size, n_lines1), -1, + device=device, dtype=torch.int64) + pred['line_match_scores0'] = torch.zeros( + (b_size, n_lines0), device=device, dtype=torch.float32) + pred['line_match_scores1'] = torch.zeros( + (b_size, n_kpts1), device=device, dtype=torch.float32) + return pred + + lines0 = data['lines0'].flatten(1, 2) + lines1 = data['lines1'].flatten(1, 2) + lines_junc_idx0 = data['lines_junc_idx0'].flatten(1, 2) # [b_size, num_lines * 2] + lines_junc_idx1 = data['lines_junc_idx1'].flatten(1, 2) + + if self.conf.bottleneck_dim is not None: + pred['down_descriptors0'] = desc0 = self.bottleneck_down(desc0) + pred['down_descriptors1'] = desc1 = self.bottleneck_down(desc1) + desc0 = self.bottleneck_up(desc0) + desc1 = self.bottleneck_up(desc1) + desc0 = nn.functional.normalize(desc0, p=2, dim=1) + desc1 = nn.functional.normalize(desc1, p=2, dim=1) + pred['bottleneck_descriptors0'] = desc0 + pred['bottleneck_descriptors1'] = desc1 + if self.conf.loss.nll_weight == 0: + desc0 = desc0.detach() + desc1 = desc1.detach() + + if self.conf.input_dim != self.conf.descriptor_dim: + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + + kpts0 = normalize_keypoints(kpts0, image_size0) + kpts1 = normalize_keypoints(kpts1, image_size1) + + assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1) + assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1) + desc0 = desc0 + self.kenc(kpts0, data['keypoint_scores0']) + desc1 = desc1 + self.kenc(kpts1, data['keypoint_scores1']) + + if n_lines0 != 0 and n_lines1 != 0: + # Pre-compute the line encodings + lines0 = normalize_keypoints(lines0, image_size0).reshape( + b_size, n_lines0, 2, 2) + lines1 = normalize_keypoints(lines1, image_size1).reshape( + b_size, n_lines1, 2, 2) + line_enc0 = self.lenc(lines0, data['line_scores0']) + line_enc1 = self.lenc(lines1, data['line_scores1']) + else: + line_enc0 = torch.zeros( + b_size, self.conf.descriptor_dim, n_lines0 * 2, + dtype=torch.float, device=device) + line_enc1 = torch.zeros( + b_size, self.conf.descriptor_dim, n_lines1 * 2, + dtype=torch.float, device=device) + + desc0, desc1 = 
self.gnn(desc0, desc1, line_enc0, line_enc1, + lines_junc_idx0, lines_junc_idx1) + + # Match all points (KP and line junctions) + mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) + + kp_scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) + kp_scores = kp_scores / self.conf.descriptor_dim ** .5 + kp_scores = log_double_softmax(kp_scores, self.bin_score) + m0, m1, mscores0, mscores1 = self._get_matches(kp_scores) + pred['log_assignment'] = kp_scores + pred['matches0'] = m0 + pred['matches1'] = m1 + pred['match_scores0'] = mscores0 + pred['match_scores1'] = mscores1 + + # Match the lines + if n_lines0 > 0 and n_lines1 > 0: + (line_scores, m0_lines, m1_lines, mscores0_lines, + mscores1_lines, raw_line_scores) = self._get_line_matches( + desc0[:, :, :2 * n_lines0], desc1[:, :, :2 * n_lines1], + lines_junc_idx0, lines_junc_idx1, self.final_line_proj) + if self.conf.inter_supervision: + for l in self.conf.inter_supervision: + (line_scores_i, m0_lines_i, m1_lines_i, mscores0_lines_i, + mscores1_lines_i) = self._get_line_matches( + self.gnn.inter_layers[l][0][:, :, :2 * n_lines0], + self.gnn.inter_layers[l][1][:, :, :2 * n_lines1], + lines_junc_idx0, lines_junc_idx1, + self.inter_line_proj[self.layer2idx[l]]) + pred[f'line_{l}_log_assignment'] = line_scores_i + pred[f'line_{l}_matches0'] = m0_lines_i + pred[f'line_{l}_matches1'] = m1_lines_i + pred[f'line_{l}_match_scores0'] = mscores0_lines_i + pred[f'line_{l}_match_scores1'] = mscores1_lines_i + else: + line_scores = torch.zeros(b_size, n_lines0, n_lines1, + dtype=torch.float, device=device) + m0_lines = torch.full((b_size, n_lines0), -1, + device=device, dtype=torch.int64) + m1_lines = torch.full((b_size, n_lines1), -1, + device=device, dtype=torch.int64) + mscores0_lines = torch.zeros( + (b_size, n_lines0), device=device, dtype=torch.float32) + mscores1_lines = torch.zeros( + (b_size, n_lines1), device=device, dtype=torch.float32) + raw_line_scores = torch.zeros(b_size, n_lines0, n_lines1, + dtype=torch.float, device=device) + pred['line_log_assignment'] = line_scores + pred['line_matches0'] = m0_lines + pred['line_matches1'] = m1_lines + pred['line_match_scores0'] = mscores0_lines + pred['line_match_scores1'] = mscores1_lines + pred['raw_line_scores'] = raw_line_scores + + return pred + + def _get_matches(self, scores_mat): + max0 = scores_mat[:, :-1, :-1].max(2) + max1 = scores_mat[:, :-1, :-1].max(1) + m0, m1 = max0.indices, max1.indices + mutual0 = arange_like(m0, 1)[None] == m1.gather(1, m0) + mutual1 = arange_like(m1, 1)[None] == m0.gather(1, m1) + zero = scores_mat.new_tensor(0) + mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero) + valid0 = mutual0 & (mscores0 > self.conf.filter_threshold) + valid1 = mutual1 & valid0.gather(1, m1) + m0 = torch.where(valid0, m0, m0.new_tensor(-1)) + m1 = torch.where(valid1, m1, m1.new_tensor(-1)) + return m0, m1, mscores0, mscores1 + + def _get_line_matches(self, ldesc0, ldesc1, lines_junc_idx0, + lines_junc_idx1, final_proj): + mldesc0 = final_proj(ldesc0) + mldesc1 = final_proj(ldesc1) + + line_scores = torch.einsum('bdn,bdm->bnm', mldesc0, mldesc1) + line_scores = line_scores / self.conf.descriptor_dim ** .5 + + # Get the line representation from the junction descriptors + n2_lines0 = lines_junc_idx0.shape[1] + n2_lines1 = lines_junc_idx1.shape[1] + line_scores = torch.gather( + line_scores, dim=2, + index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1)) + line_scores = torch.gather( + line_scores, 
dim=1, + index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1)) + line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2, + n2_lines1 // 2, 2)) + + # Match either in one direction or the other + raw_line_scores = 0.5 * torch.maximum( + line_scores[:, :, 0, :, 0] + line_scores[:, :, 1, :, 1], + line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0]) + line_scores = log_double_softmax(raw_line_scores, self.line_bin_score) + m0_lines, m1_lines, mscores0_lines, mscores1_lines = self._get_matches( + line_scores) + return (line_scores, m0_lines, m1_lines, mscores0_lines, + mscores1_lines, raw_line_scores) + + def loss(self, pred, data): + raise NotImplementedError() + + def metrics(self, pred, data): + raise NotImplementedError() + + +def MLP(channels, do_bn=True): + n = len(channels) + layers = [] + for i in range(1, n): + layers.append( + nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) + if i < (n - 1): + if do_bn: + layers.append(nn.BatchNorm1d(channels[i])) + layers.append(nn.ReLU()) + return nn.Sequential(*layers) + + +def normalize_keypoints(kpts, shape_or_size): + if isinstance(shape_or_size, (tuple, list)): + # it's a shape + h, w = shape_or_size[-2:] + size = kpts.new_tensor([[w, h]]) + else: + # it's a size + assert isinstance(shape_or_size, torch.Tensor) + size = shape_or_size.to(kpts) + c = size / 2 + f = size.max(1, keepdim=True).values * 0.7 # somehow we used 0.7 for SG + return (kpts - c[:, None, :]) / f[:, None, :] + + +class KeypointEncoder(nn.Module): + def __init__(self, feature_dim, layers): + super().__init__() + self.encoder = MLP([3] + list(layers) + [feature_dim], do_bn=True) + nn.init.constant_(self.encoder[-1].bias, 0.0) + + def forward(self, kpts, scores): + inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] + return self.encoder(torch.cat(inputs, dim=1)) + + +class EndPtEncoder(nn.Module): + def __init__(self, feature_dim, layers): + super().__init__() + self.encoder = MLP([5] + list(layers) + [feature_dim], do_bn=True) + nn.init.constant_(self.encoder[-1].bias, 0.0) + + def forward(self, endpoints, scores): + # endpoints should be [B, N, 2, 2] + # output is [B, feature_dim, N * 2] + b_size, n_pts, _, _ = endpoints.shape + assert tuple(endpoints.shape[-2:]) == (2, 2) + endpt_offset = (endpoints[:, :, 1] - endpoints[:, :, 0]).unsqueeze(2) + endpt_offset = torch.cat([endpt_offset, -endpt_offset], dim=2) + endpt_offset = endpt_offset.reshape(b_size, 2 * n_pts, 2).transpose(1, 2) + inputs = [endpoints.flatten(1, 2).transpose(1, 2), + endpt_offset, scores.repeat(1, 2).unsqueeze(1)] + return self.encoder(torch.cat(inputs, dim=1)) + + +@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) +def attention(query, key, value): + dim = query.shape[1] + scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5 + prob = torch.nn.functional.softmax(scores, dim=-1) + return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob + + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model): + super().__init__() + assert d_model % h == 0 + self.dim = d_model // h + self.h = h + self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) + self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) + # self.prob = [] + + def forward(self, query, key, value): + b = query.size(0) + query, key, value = [l(x).view(b, self.dim, self.h, -1) + for l, x in zip(self.proj, (query, key, value))] + x, prob = attention(query, key, value) + # self.prob.append(prob.mean(dim=1)) + return self.merge(x.contiguous().view(b, self.dim * self.h, -1)) + + 
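The `attention` helper above implements scaled dot-product attention on channels-first tensors, matching the Conv1d projections used throughout the matcher: queries, keys and values are laid out as [batch, head_dim, num_heads, num_tokens] and the softmax runs over the key tokens. A minimal shape sketch, assuming only `torch` and the `attention` function defined above are in scope (the sizes are illustrative, not taken from any GlueStick configuration):

import torch

# Illustrative sizes only: 4 heads of 64 channels, 100 query and 120 key tokens.
b, d, h, n, m = 1, 64, 4, 100, 120
query = torch.randn(b, d, h, n)
key = torch.randn(b, d, h, m)
value = torch.randn(b, d, h, m)

out, prob = attention(query, key, value)  # helper defined above
assert out.shape == (b, d, h, n)    # one aggregated message per query token
assert prob.shape == (b, h, n, m)   # attention weights, softmax over the key dim
assert torch.allclose(prob.sum(-1), torch.ones(b, h, n), atol=1e-5)

`MultiHeadedAttention` wraps this by projecting each input with a 1x1 Conv1d, reshaping to the [b, d, h, -1] layout, and merging the heads back with another 1x1 Conv1d.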
+class AttentionalPropagation(nn.Module): + def __init__(self, num_dim, num_heads, skip_init=False): + super().__init__() + self.attn = MultiHeadedAttention(num_heads, num_dim) + self.mlp = MLP([num_dim * 2, num_dim * 2, num_dim], do_bn=True) + nn.init.constant_(self.mlp[-1].bias, 0.0) + if skip_init: + self.register_parameter('scaling', nn.Parameter(torch.tensor(0.))) + else: + self.scaling = 1. + + def forward(self, x, source): + message = self.attn(x, source, source) + return self.mlp(torch.cat([x, message], dim=1)) * self.scaling + + +class GNNLayer(nn.Module): + def __init__(self, feature_dim, layer_type, skip_init): + super().__init__() + assert layer_type in ['cross', 'self'] + self.type = layer_type + self.update = AttentionalPropagation(feature_dim, 4, skip_init) + + def forward(self, desc0, desc1): + if self.type == 'cross': + src0, src1 = desc1, desc0 + elif self.type == 'self': + src0, src1 = desc0, desc1 + else: + raise ValueError("Unknown layer type: " + self.type) + # self.update.attn.prob = [] + delta0, delta1 = self.update(desc0, src0), self.update(desc1, src1) + desc0, desc1 = (desc0 + delta0), (desc1 + delta1) + return desc0, desc1 + + +class LineLayer(nn.Module): + def __init__(self, feature_dim, line_attention=False): + super().__init__() + self.dim = feature_dim + self.mlp = MLP([self.dim * 3, self.dim * 2, self.dim], do_bn=True) + self.line_attention = line_attention + if line_attention: + self.proj_node = nn.Conv1d(self.dim, self.dim, kernel_size=1) + self.proj_neigh = nn.Conv1d(2 * self.dim, self.dim, kernel_size=1) + + def get_endpoint_update(self, ldesc, line_enc, lines_junc_idx): + # ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2] + # and lines_junc_idx [bs, n_lines * 2] + # Create one message per line endpoint + b_size = lines_junc_idx.shape[0] + line_desc = torch.gather( + ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1)) + message = torch.cat([ + line_desc, + line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(), + line_enc], dim=1) + return self.mlp(message) # [b_size, D, n_lines * 2] + + def get_endpoint_attention(self, ldesc, line_enc, lines_junc_idx): + # ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2] + # and lines_junc_idx [bs, n_lines * 2] + b_size = lines_junc_idx.shape[0] + expanded_lines_junc_idx = lines_junc_idx[:, None].repeat(1, self.dim, 1) + + # Query: desc of the current node + query = self.proj_node(ldesc) # [b_size, D, n_junc] + query = torch.gather(query, 2, expanded_lines_junc_idx) + # query is [b_size, D, n_lines * 2] + + # Key: combination of neighboring desc and line encodings + line_desc = torch.gather(ldesc, 2, expanded_lines_junc_idx) + key = self.proj_neigh(torch.cat([ + line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(), + line_enc], dim=1)) # [b_size, D, n_lines * 2] + + # Compute the attention weights with a custom softmax per junction + prob = (query * key).sum(dim=1) / self.dim ** .5 # [b_size, n_lines * 2] + prob = torch.exp(prob - prob.max()) + denom = torch.zeros_like(ldesc[:, 0]).scatter_reduce_( + dim=1, index=lines_junc_idx, + src=prob, reduce='sum', include_self=False) # [b_size, n_junc] + denom = torch.gather(denom, 1, lines_junc_idx) # [b_size, n_lines * 2] + prob = prob / (denom + ETH_EPS) + return prob # [b_size, n_lines * 2] + + def forward(self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0, + lines_junc_idx1): + # Gather the endpoint updates + lupdate0 = self.get_endpoint_update(ldesc0, line_enc0, lines_junc_idx0) + lupdate1 = 
self.get_endpoint_update(ldesc1, line_enc1, lines_junc_idx1) + + update0, update1 = torch.zeros_like(ldesc0), torch.zeros_like(ldesc1) + dim = ldesc0.shape[1] + if self.line_attention: + # Compute an attention for each neighbor and do a weighted average + prob0 = self.get_endpoint_attention(ldesc0, line_enc0, + lines_junc_idx0) + lupdate0 = lupdate0 * prob0[:, None] + update0 = update0.scatter_reduce_( + dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1), + src=lupdate0, reduce='sum', include_self=False) + prob1 = self.get_endpoint_attention(ldesc1, line_enc1, + lines_junc_idx1) + lupdate1 = lupdate1 * prob1[:, None] + update1 = update1.scatter_reduce_( + dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1), + src=lupdate1, reduce='sum', include_self=False) + else: + # Average the updates for each junction (requires torch > 1.12) + update0 = update0.scatter_reduce_( + dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1), + src=lupdate0, reduce='mean', include_self=False) + update1 = update1.scatter_reduce_( + dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1), + src=lupdate1, reduce='mean', include_self=False) + + # Update + ldesc0 = ldesc0 + update0 + ldesc1 = ldesc1 + update1 + + return ldesc0, ldesc1 + + +class AttentionalGNN(nn.Module): + def __init__(self, feature_dim, layer_types, checkpointed=False, + skip=False, inter_supervision=None, num_line_iterations=1, + line_attention=False): + super().__init__() + self.checkpointed = checkpointed + self.inter_supervision = inter_supervision + self.num_line_iterations = num_line_iterations + self.inter_layers = {} + self.layers = nn.ModuleList([ + GNNLayer(feature_dim, layer_type, skip) + for layer_type in layer_types]) + self.line_layers = nn.ModuleList( + [LineLayer(feature_dim, line_attention) + for _ in range(len(layer_types) // 2)]) + + def forward(self, desc0, desc1, line_enc0, line_enc1, + lines_junc_idx0, lines_junc_idx1): + for i, layer in enumerate(self.layers): + if self.checkpointed: + desc0, desc1 = torch.utils.checkpoint.checkpoint( + layer, desc0, desc1, preserve_rng_state=False) + else: + desc0, desc1 = layer(desc0, desc1) + if (layer.type == 'self' and lines_junc_idx0.shape[1] > 0 + and lines_junc_idx1.shape[1] > 0): + # Add line self attention layers after every self layer + for _ in range(self.num_line_iterations): + if self.checkpointed: + desc0, desc1 = torch.utils.checkpoint.checkpoint( + self.line_layers[i // 2], desc0, desc1, line_enc0, + line_enc1, lines_junc_idx0, lines_junc_idx1, + preserve_rng_state=False) + else: + desc0, desc1 = self.line_layers[i // 2]( + desc0, desc1, line_enc0, line_enc1, + lines_junc_idx0, lines_junc_idx1) + + # Optionally store the line descriptor at intermediate layers + if (self.inter_supervision is not None + and (i // 2) in self.inter_supervision + and layer.type == 'cross'): + self.inter_layers[i // 2] = (desc0.clone(), desc1.clone()) + return desc0, desc1 + + +def log_double_softmax(scores, bin_score): + b, m, n = scores.shape + bin_ = bin_score[None, None, None] + scores0 = torch.cat([scores, bin_.expand(b, m, 1)], 2) + scores1 = torch.cat([scores, bin_.expand(b, 1, n)], 1) + scores0 = torch.nn.functional.log_softmax(scores0, 2) + scores1 = torch.nn.functional.log_softmax(scores1, 1) + scores = scores.new_full((b, m + 1, n + 1), 0) + scores[:, :m, :n] = (scores0[:, :, :n] + scores1[:, :m, :]) / 2 + scores[:, :-1, -1] = scores0[:, :, -1] + scores[:, -1, :-1] = scores1[:, -1, :] + return scores + + +def arange_like(x, dim): + return x.new_ones(x.shape[dim]).cumsum(0) - 
1 # traceable in 1.1 diff --git a/imcui/third_party/GlueStick/gluestick/models/superpoint.py b/imcui/third_party/GlueStick/gluestick/models/superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..872063275f4fde27f552bf2c2674dc60d5220ec9 --- /dev/null +++ b/imcui/third_party/GlueStick/gluestick/models/superpoint.py @@ -0,0 +1,229 @@ +""" +Inference model of SuperPoint, a feature detector and descriptor. + +Described in: + SuperPoint: Self-Supervised Interest Point Detection and Description, + Daniel DeTone, Tomasz Malisiewicz, Andrew Rabinovich, CVPRW 2018. + +Original code: github.com/MagicLeapResearch/SuperPointPretrainedNetwork +""" + +import torch +from torch import nn + +from .. import GLUESTICK_ROOT +from ..models.base_model import BaseModel + + +def simple_nms(scores, radius): + """Perform non maximum suppression on the heatmap using max-pooling. + This method does not suppress contiguous points that have the same score. + Args: + scores: the score heatmap of size `(B, H, W)`. + size: an interger scalar, the radius of the NMS window. + """ + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=radius * 2 + 1, stride=1, padding=radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def remove_borders(keypoints, scores, b, h, w): + mask_h = (keypoints[:, 0] >= b) & (keypoints[:, 0] < (h - b)) + mask_w = (keypoints[:, 1] >= b) & (keypoints[:, 1] < (w - b)) + mask = mask_h & mask_w + return keypoints[mask], scores[mask] + + +def top_k_keypoints(keypoints, scores, k): + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0, sorted=True) + return keypoints[indices], scores + + +def sample_descriptors(keypoints, descriptors, s): + b, c, h, w = descriptors.shape + keypoints = keypoints - s / 2 + 0.5 + keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to(keypoints)[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + args = {'align_corners': True} if torch.__version__ >= '1.3' else {} + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + return descriptors + + +class SuperPoint(BaseModel): + default_conf = { + 'has_detector': True, + 'has_descriptor': True, + 'descriptor_dim': 256, + + # Inference + 'return_all': False, + 'sparse_outputs': True, + 'nms_radius': 4, + 'detection_threshold': 0.005, + 'max_num_keypoints': -1, + 'force_num_keypoints': False, + 'remove_borders': 4, + } + required_data_keys = ['image'] + + def _init(self, conf): + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 + + self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) + self.conv4a = 
nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) + + if conf.has_detector: + self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) + + if conf.has_descriptor: + self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convDb = nn.Conv2d( + c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0) + + path = GLUESTICK_ROOT / 'resources' / 'weights' / 'superpoint_v1.pth' + if path.exists(): + weights = torch.load(str(path), map_location='cpu') + else: + weights_url = "https://github.com/cvg/GlueStick/raw/main/resources/weights/superpoint_v1.pth" + weights = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu') + self.load_state_dict(weights, strict=False) + + def _forward(self, data): + image = data['image'] + if image.shape[1] == 3: # RGB + scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) + image = (image * scale).sum(1, keepdim=True) + + # Shared Encoder + x = self.relu(self.conv1a(image)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + pred = {} + if self.conf.has_detector and self.conf.max_num_keypoints != 0: + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + b, c, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + pred['keypoint_scores'] = dense_scores = scores + if self.conf.has_descriptor: + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + all_desc = self.convDb(cDa) + all_desc = torch.nn.functional.normalize(all_desc, p=2, dim=1) + pred['descriptors'] = all_desc + + if self.conf.max_num_keypoints == 0: # Predict dense descriptors only + b_size = len(image) + device = image.device + return { + 'keypoints': torch.empty(b_size, 0, 2, device=device), + 'keypoint_scores': torch.empty(b_size, 0, device=device), + 'descriptors': torch.empty(b_size, self.conf.descriptor_dim, 0, device=device), + 'all_descriptors': all_desc + } + + if self.conf.sparse_outputs: + assert self.conf.has_detector and self.conf.has_descriptor + + scores = simple_nms(scores, self.conf.nms_radius) + + # Extract keypoints + keypoints = [ + torch.nonzero(s > self.conf.detection_threshold) + for s in scores] + scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] + + # Discard keypoints near the image borders + keypoints, scores = list(zip(*[ + remove_borders(k, s, self.conf.remove_borders, h * 8, w * 8) + for k, s in zip(keypoints, scores)])) + + # Keep the k keypoints with highest score + if self.conf.max_num_keypoints > 0: + keypoints, scores = list(zip(*[ + top_k_keypoints(k, s, self.conf.max_num_keypoints) + for k, s in zip(keypoints, scores)])) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + if self.conf.force_num_keypoints: + _, _, h, w = data['image'].shape + assert self.conf.max_num_keypoints > 0 + scores = list(scores) + for i in range(len(keypoints)): + k, s = keypoints[i], scores[i] + missing = self.conf.max_num_keypoints - len(k) + if missing > 0: + new_k = torch.rand(missing, 2).to(k) + 
new_k = new_k * k.new_tensor([[w - 1, h - 1]]) + new_s = torch.zeros(missing).to(s) + keypoints[i] = torch.cat([k, new_k], 0) + scores[i] = torch.cat([s, new_s], 0) + + # Extract descriptors + desc = [sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, all_desc)] + + if (len(keypoints) == 1) or self.conf.force_num_keypoints: + keypoints = torch.stack(keypoints, 0) + scores = torch.stack(scores, 0) + desc = torch.stack(desc, 0) + + pred = { + 'keypoints': keypoints, + 'keypoint_scores': scores, + 'descriptors': desc, + } + + if self.conf.return_all: + pred['all_descriptors'] = all_desc + pred['dense_score'] = dense_scores + else: + del all_desc + torch.cuda.empty_cache() + + return pred + + def loss(self, pred, data): + raise NotImplementedError + + def metrics(self, pred, data): + raise NotImplementedError diff --git a/imcui/third_party/GlueStick/gluestick/models/two_view_pipeline.py b/imcui/third_party/GlueStick/gluestick/models/two_view_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e21c1f62e2bd4ad573ebb87ea5635742b5032e --- /dev/null +++ b/imcui/third_party/GlueStick/gluestick/models/two_view_pipeline.py @@ -0,0 +1,176 @@ +""" +A two-view sparse feature matching pipeline. + +This model contains sub-models for each step: + feature extraction, feature matching, outlier filtering, pose estimation. +Each step is optional, and the features or matches can be provided as input. +Default: SuperPoint with nearest neighbor matching. + +Convention for the matches: m0[i] is the index of the keypoint in image 1 +that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched. +""" + +import numpy as np +import torch + +from .. import get_model +from .base_model import BaseModel + + +def keep_quadrant_kp_subset(keypoints, scores, descs, h, w): + """Keep only keypoints in one of the four quadrant of the image.""" + h2, w2 = h // 2, w // 2 + w_x = np.random.choice([0, w2]) + w_y = np.random.choice([0, h2]) + valid_mask = ((keypoints[..., 0] >= w_x) + & (keypoints[..., 0] < w_x + w2) + & (keypoints[..., 1] >= w_y) + & (keypoints[..., 1] < w_y + h2)) + keypoints = keypoints[valid_mask][None] + scores = scores[valid_mask][None] + descs = descs.permute(0, 2, 1)[valid_mask].t()[None] + return keypoints, scores, descs + + +def keep_random_kp_subset(keypoints, scores, descs, num_selected): + """Keep a random subset of keypoints.""" + num_kp = keypoints.shape[1] + selected_kp = torch.randperm(num_kp)[:num_selected] + keypoints = keypoints[:, selected_kp] + scores = scores[:, selected_kp] + descs = descs[:, :, selected_kp] + return keypoints, scores, descs + + +def keep_best_kp_subset(keypoints, scores, descs, num_selected): + """Keep the top num_selected best keypoints.""" + sorted_indices = torch.sort(scores, dim=1)[1] + selected_kp = sorted_indices[:, -num_selected:] + keypoints = torch.gather(keypoints, 1, + selected_kp[:, :, None].repeat(1, 1, 2)) + scores = torch.gather(scores, 1, selected_kp) + descs = torch.gather(descs, 2, + selected_kp[:, None].repeat(1, descs.shape[1], 1)) + return keypoints, scores, descs + + +class TwoViewPipeline(BaseModel): + default_conf = { + 'extractor': { + 'name': 'superpoint', + 'trainable': False, + }, + 'use_lines': False, + 'use_points': True, + 'randomize_num_kp': False, + 'detector': {'name': None}, + 'descriptor': {'name': None}, + 'matcher': {'name': 'nearest_neighbor_matcher'}, + 'filter': {'name': None}, + 'solver': {'name': None}, + 'ground_truth': { + 'from_pose_depth': False, + 
'from_homography': False, + 'th_positive': 3, + 'th_negative': 5, + 'reward_positive': 1, + 'reward_negative': -0.25, + 'is_likelihood_soft': True, + 'p_random_occluders': 0, + 'n_line_sampled_pts': 50, + 'line_perp_dist_th': 5, + 'overlap_th': 0.2, + 'min_visibility_th': 0.5 + }, + } + required_data_keys = ['image0', 'image1'] + strict_conf = False # need to pass new confs to children models + components = [ + 'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver'] + + def _init(self, conf): + if conf.extractor.name: + self.extractor = get_model(conf.extractor.name)(conf.extractor) + else: + if self.conf.detector.name: + self.detector = get_model(conf.detector.name)(conf.detector) + else: + self.required_data_keys += ['keypoints0', 'keypoints1'] + if self.conf.descriptor.name: + self.descriptor = get_model(conf.descriptor.name)( + conf.descriptor) + else: + self.required_data_keys += ['descriptors0', 'descriptors1'] + + if conf.matcher.name: + self.matcher = get_model(conf.matcher.name)(conf.matcher) + else: + self.required_data_keys += ['matches0'] + + if conf.filter.name: + self.filter = get_model(conf.filter.name)(conf.filter) + + if conf.solver.name: + self.solver = get_model(conf.solver.name)(conf.solver) + + def _forward(self, data): + + def process_siamese(data, i): + data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i} + if self.conf.extractor.name: + pred_i = self.extractor(data_i) + else: + pred_i = {} + if self.conf.detector.name: + pred_i = self.detector(data_i) + else: + for k in ['keypoints', 'keypoint_scores', 'descriptors', + 'lines', 'line_scores', 'line_descriptors', + 'valid_lines']: + if k in data_i: + pred_i[k] = data_i[k] + if self.conf.descriptor.name: + pred_i = { + **pred_i, **self.descriptor({**data_i, **pred_i})} + return pred_i + + pred0 = process_siamese(data, '0') + pred1 = process_siamese(data, '1') + + pred = {**{k + '0': v for k, v in pred0.items()}, + **{k + '1': v for k, v in pred1.items()}} + + if self.conf.matcher.name: + pred = {**pred, **self.matcher({**data, **pred})} + + if self.conf.filter.name: + pred = {**pred, **self.filter({**data, **pred})} + + if self.conf.solver.name: + pred = {**pred, **self.solver({**data, **pred})} + + return pred + + def loss(self, pred, data): + losses = {} + total = 0 + for k in self.components: + if self.conf[k].name: + try: + losses_ = getattr(self, k).loss(pred, {**pred, **data}) + except NotImplementedError: + continue + losses = {**losses, **losses_} + total = losses_['total'] + total + return {**losses, 'total': total} + + def metrics(self, pred, data): + metrics = {} + for k in self.components: + if self.conf[k].name: + try: + metrics_ = getattr(self, k).metrics(pred, {**pred, **data}) + except NotImplementedError: + continue + metrics = {**metrics, **metrics_} + return metrics diff --git a/imcui/third_party/GlueStick/gluestick/models/wireframe.py b/imcui/third_party/GlueStick/gluestick/models/wireframe.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3dd9873c6fdb4edcb4c75a103673ee2cb3b3fa --- /dev/null +++ b/imcui/third_party/GlueStick/gluestick/models/wireframe.py @@ -0,0 +1,274 @@ +import numpy as np +import torch +from pytlsd import lsd +from sklearn.cluster import DBSCAN + +from .base_model import BaseModel +from .superpoint import SuperPoint, sample_descriptors +from ..geometry import warp_lines_torch + + +def lines_to_wireframe(lines, line_scores, all_descs, conf): + """ Given a set of lines, their score and dense descriptors, + merge close-by endpoints and 
compute a wireframe defined by + its junctions and connectivity. + Returns: + junctions: list of [num_junc, 2] tensors listing all wireframe junctions + junc_scores: list of [num_junc] tensors with the junction score + junc_descs: list of [dim, num_junc] tensors with the junction descriptors + connectivity: list of [num_junc, num_junc] bool arrays with True when 2 junctions are connected + new_lines: the new set of [b_size, num_lines, 2, 2] lines + lines_junc_idx: a [b_size, num_lines, 2] tensor with the indices of the junctions of each endpoint + num_true_junctions: a list of the number of valid junctions for each image in the batch, + i.e. before filling with random ones + """ + b_size, _, _, _ = all_descs.shape + device = lines.device + endpoints = lines.reshape(b_size, -1, 2) + + (junctions, junc_scores, junc_descs, connectivity, new_lines, + lines_junc_idx, num_true_junctions) = [], [], [], [], [], [], [] + for bs in range(b_size): + # Cluster the junctions that are close-by + db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit( + endpoints[bs].cpu().numpy()) + clusters = db.labels_ + n_clusters = len(set(clusters)) + num_true_junctions.append(n_clusters) + + # Compute the average junction and score for each cluster + clusters = torch.tensor(clusters, dtype=torch.long, + device=device) + new_junc = torch.zeros(n_clusters, 2, dtype=torch.float, + device=device) + new_junc.scatter_reduce_(0, clusters[:, None].repeat(1, 2), + endpoints[bs], reduce='mean', + include_self=False) + junctions.append(new_junc) + new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device) + new_scores.scatter_reduce_( + 0, clusters, torch.repeat_interleave(line_scores[bs], 2), + reduce='mean', include_self=False) + junc_scores.append(new_scores) + + # Compute the new lines + new_lines.append(junctions[-1][clusters].reshape(-1, 2, 2)) + lines_junc_idx.append(clusters.reshape(-1, 2)) + + # Compute the junction connectivity + junc_connect = torch.eye(n_clusters, dtype=torch.bool, + device=device) + pairs = clusters.reshape(-1, 2) # these pairs are connected by a line + junc_connect[pairs[:, 0], pairs[:, 1]] = True + junc_connect[pairs[:, 1], pairs[:, 0]] = True + connectivity.append(junc_connect) + + # Interpolate the new junction descriptors + junc_descs.append(sample_descriptors( + junctions[-1][None], all_descs[bs:(bs + 1)], 8)[0]) + + new_lines = torch.stack(new_lines, dim=0) + lines_junc_idx = torch.stack(lines_junc_idx, dim=0) + return (junctions, junc_scores, junc_descs, connectivity, + new_lines, lines_junc_idx, num_true_junctions) + + +class SPWireframeDescriptor(BaseModel): + default_conf = { + 'sp_params': { + 'has_detector': True, + 'has_descriptor': True, + 'descriptor_dim': 256, + 'trainable': False, + + # Inference + 'return_all': True, + 'sparse_outputs': True, + 'nms_radius': 4, + 'detection_threshold': 0.005, + 'max_num_keypoints': 1000, + 'force_num_keypoints': True, + 'remove_borders': 4, + }, + 'wireframe_params': { + 'merge_points': True, + 'merge_line_endpoints': True, + 'nms_radius': 3, + 'max_n_junctions': 500, + }, + 'max_n_lines': 250, + 'min_length': 15, + } + required_data_keys = ['image'] + + def _init(self, conf): + self.conf = conf + self.sp = SuperPoint(conf.sp_params) + + def detect_lsd_lines(self, x, max_n_lines=None): + if max_n_lines is None: + max_n_lines = self.conf.max_n_lines + lines, scores, valid_lines = [], [], [] + for b in range(len(x)): + # For each image on batch + img = (x[b].squeeze().cpu().numpy() * 255).astype(np.uint8) + if max_n_lines is None: + 
b_segs = lsd(img) + else: + for s in [0.3, 0.4, 0.5, 0.7, 0.8, 1.0]: + b_segs = lsd(img, scale=s) + if len(b_segs) >= max_n_lines: + break + + segs_length = np.linalg.norm(b_segs[:, 2:4] - b_segs[:, 0:2], axis=1) + # Remove short lines + b_segs = b_segs[segs_length >= self.conf.min_length] + segs_length = segs_length[segs_length >= self.conf.min_length] + b_scores = b_segs[:, -1] * np.sqrt(segs_length) + # Take the most relevant segments with + indices = np.argsort(-b_scores) + if max_n_lines is not None: + indices = indices[:max_n_lines] + lines.append(torch.from_numpy(b_segs[indices, :4].reshape(-1, 2, 2))) + scores.append(torch.from_numpy(b_scores[indices])) + valid_lines.append(torch.ones_like(scores[-1], dtype=torch.bool)) + + lines = torch.stack(lines).to(x) + scores = torch.stack(scores).to(x) + valid_lines = torch.stack(valid_lines).to(x.device) + return lines, scores, valid_lines + + def _forward(self, data): + b_size, _, h, w = data['image'].shape + device = data['image'].device + + if not self.conf.sp_params.force_num_keypoints: + assert b_size == 1, "Only batch size of 1 accepted for non padded inputs" + + # Line detection + if 'lines' not in data or 'line_scores' not in data: + if 'original_img' in data: + # Detect more lines, because when projecting them to the image most of them will be discarded + lines, line_scores, valid_lines = self.detect_lsd_lines( + data['original_img'], self.conf.max_n_lines * 3) + # Apply the same transformation that is applied in homography_adaptation + lines, valid_lines2 = warp_lines_torch(lines, data['H'], False, data['image'].shape[-2:]) + valid_lines = valid_lines & valid_lines2 + lines[~valid_lines] = -1 + line_scores[~valid_lines] = 0 + # Re-sort the line segments to pick the ones that are inside the image and have bigger score + sorted_scores, sorting_indices = torch.sort(line_scores, dim=-1, descending=True) + line_scores = sorted_scores[:, :self.conf.max_n_lines] + sorting_indices = sorting_indices[:, :self.conf.max_n_lines] + lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1) + valid_lines = torch.take_along_dim(valid_lines, sorting_indices, 1) + else: + lines, line_scores, valid_lines = self.detect_lsd_lines(data['image']) + + else: + lines, line_scores, valid_lines = data['lines'], data['line_scores'], data['valid_lines'] + if line_scores.shape[-1] != 0: + line_scores /= (line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None]) + + # SuperPoint prediction + pred = self.sp(data) + + # Remove keypoints that are too close to line endpoints + if self.conf.wireframe_params.merge_points: + kp = pred['keypoints'] + line_endpts = lines.reshape(b_size, -1, 2) + dist_pt_lines = torch.norm( + kp[:, :, None] - line_endpts[:, None], dim=-1) + # For each keypoint, mark it as valid or to remove + pts_to_remove = torch.any( + dist_pt_lines < self.conf.sp_params.nms_radius, dim=2) + # Simply remove them (we assume batch_size = 1 here) + assert len(kp) == 1 + pred['keypoints'] = pred['keypoints'][0][~pts_to_remove[0]][None] + pred['keypoint_scores'] = pred['keypoint_scores'][0][~pts_to_remove[0]][None] + pred['descriptors'] = pred['descriptors'][0].T[~pts_to_remove[0]].T[None] + + # Connect the lines together to form a wireframe + orig_lines = lines.clone() + if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0: + # Merge first close-by endpoints to connect lines + (line_points, line_pts_scores, line_descs, line_association, + lines, lines_junc_idx, num_true_junctions) = lines_to_wireframe( + 
lines, line_scores, pred['all_descriptors'], + conf=self.conf.wireframe_params) + + # Add the keypoints to the junctions and fill the rest with random keypoints + (all_points, all_scores, all_descs, + pl_associativity) = [], [], [], [] + for bs in range(b_size): + all_points.append(torch.cat( + [line_points[bs], pred['keypoints'][bs]], dim=0)) + all_scores.append(torch.cat( + [line_pts_scores[bs], pred['keypoint_scores'][bs]], dim=0)) + all_descs.append(torch.cat( + [line_descs[bs], pred['descriptors'][bs]], dim=1)) + + associativity = torch.eye(len(all_points[-1]), dtype=torch.bool, device=device) + associativity[:num_true_junctions[bs], :num_true_junctions[bs]] = \ + line_association[bs][:num_true_junctions[bs], :num_true_junctions[bs]] + pl_associativity.append(associativity) + + all_points = torch.stack(all_points, dim=0) + all_scores = torch.stack(all_scores, dim=0) + all_descs = torch.stack(all_descs, dim=0) + pl_associativity = torch.stack(pl_associativity, dim=0) + else: + # Lines are independent + all_points = torch.cat([lines.reshape(b_size, -1, 2), + pred['keypoints']], dim=1) + n_pts = all_points.shape[1] + num_lines = lines.shape[1] + num_true_junctions = [num_lines * 2] * b_size + all_scores = torch.cat([ + torch.repeat_interleave(line_scores, 2, dim=1), + pred['keypoint_scores']], dim=1) + pred['line_descriptors'] = self.endpoints_pooling( + lines, pred['all_descriptors'], (h, w)) + all_descs = torch.cat([ + pred['line_descriptors'].reshape(b_size, self.conf.sp_params.descriptor_dim, -1), + pred['descriptors']], dim=2) + pl_associativity = torch.eye( + n_pts, dtype=torch.bool, + device=device)[None].repeat(b_size, 1, 1) + lines_junc_idx = torch.arange( + num_lines * 2, device=device).reshape(1, -1, 2).repeat(b_size, 1, 1) + + del pred['all_descriptors'] # Remove dense descriptors to save memory + torch.cuda.empty_cache() + + return {'keypoints': all_points, + 'keypoint_scores': all_scores, + 'descriptors': all_descs, + 'pl_associativity': pl_associativity, + 'num_junctions': torch.tensor(num_true_junctions), + 'lines': lines, + 'orig_lines': orig_lines, + 'lines_junc_idx': lines_junc_idx, + 'line_scores': line_scores, + 'valid_lines': valid_lines} + + @staticmethod + def endpoints_pooling(segs, all_descriptors, img_shape): + assert segs.ndim == 4 and segs.shape[-2:] == (2, 2) + filter_shape = all_descriptors.shape[-2:] + scale_x = filter_shape[1] / img_shape[1] + scale_y = filter_shape[0] / img_shape[0] + + scaled_segs = torch.round(segs * torch.tensor([scale_x, scale_y]).to(segs)).long() + scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1) + scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1) + line_descriptors = [all_descriptors[None, b, ..., torch.squeeze(b_segs[..., 1]), torch.squeeze(b_segs[..., 0])] + for b, b_segs in enumerate(scaled_segs)] + line_descriptors = torch.cat(line_descriptors) + return line_descriptors # Shape (1, 256, 308, 2) + + def loss(self, pred, data): + raise NotImplementedError + + def metrics(self, pred, data): + return {} diff --git a/imcui/third_party/GlueStick/gluestick/run.py b/imcui/third_party/GlueStick/gluestick/run.py new file mode 100644 index 0000000000000000000000000000000000000000..85fd8af801dd18936163ac1af6d331f54965bfa5 --- /dev/null +++ b/imcui/third_party/GlueStick/gluestick/run.py @@ -0,0 +1,107 @@ +import argparse +import os +from os.path import join + +import cv2 +import torch +from matplotlib import pyplot as plt + +from gluestick import batch_to_np, numpy_image_to_torch, 
GLUESTICK_ROOT +from .drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches +from .models.two_view_pipeline import TwoViewPipeline + + +def main(): + # Parse input parameters + parser = argparse.ArgumentParser( + prog='GlueStick Demo', + description='Demo app to show the point and line matches obtained by GlueStick') + parser.add_argument('-img1', default=join('resources' + os.path.sep + 'img1.jpg')) + parser.add_argument('-img2', default=join('resources' + os.path.sep + 'img2.jpg')) + parser.add_argument('--max_pts', type=int, default=1000) + parser.add_argument('--max_lines', type=int, default=300) + parser.add_argument('--skip-imshow', default=False, action='store_true') + args = parser.parse_args() + + # Evaluation config + conf = { + 'name': 'two_view_pipeline', + 'use_lines': True, + 'extractor': { + 'name': 'wireframe', + 'sp_params': { + 'force_num_keypoints': False, + 'max_num_keypoints': args.max_pts, + }, + 'wireframe_params': { + 'merge_points': True, + 'merge_line_endpoints': True, + }, + 'max_n_lines': args.max_lines, + }, + 'matcher': { + 'name': 'gluestick', + 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'), + 'trainable': False, + }, + 'ground_truth': { + 'from_pose_depth': False, + } + } + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + pipeline_model = TwoViewPipeline(conf).to(device).eval() + + gray0 = cv2.imread(args.img1, 0) + gray1 = cv2.imread(args.img2, 0) + + torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1) + torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None] + x = {'image0': torch_gray0, 'image1': torch_gray1} + pred = pipeline_model(x) + + pred = batch_to_np(pred) + kp0, kp1 = pred["keypoints0"], pred["keypoints1"] + m0 = pred["matches0"] + + line_seg0, line_seg1 = pred["lines0"], pred["lines1"] + line_matches = pred["line_matches0"] + + valid_matches = m0 != -1 + match_indices = m0[valid_matches] + matched_kps0 = kp0[valid_matches] + matched_kps1 = kp1[match_indices] + + valid_matches = line_matches != -1 + match_indices = line_matches[valid_matches] + matched_lines0 = line_seg0[valid_matches] + matched_lines1 = line_seg1[match_indices] + + # Plot the matches + img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR) + plot_images([img0, img1], ['Image 1 - detected lines', 'Image 2 - detected lines'], dpi=200, pad=2.0) + plot_lines([line_seg0, line_seg1], ps=4, lw=2) + plt.gcf().canvas.manager.set_window_title('Detected Lines') + plt.savefig('detected_lines.png') + + plot_images([img0, img1], ['Image 1 - detected points', 'Image 2 - detected points'], dpi=200, pad=2.0) + plot_keypoints([kp0, kp1], colors='c') + plt.gcf().canvas.manager.set_window_title('Detected Points') + plt.savefig('detected_points.png') + + plot_images([img0, img1], ['Image 1 - line matches', 'Image 2 - line matches'], dpi=200, pad=2.0) + plot_color_line_matches([matched_lines0, matched_lines1], lw=2) + plt.gcf().canvas.manager.set_window_title('Line Matches') + plt.savefig('line_matches.png') + + plot_images([img0, img1], ['Image 1 - point matches', 'Image 2 - point matches'], dpi=200, pad=2.0) + plot_matches(matched_kps0, matched_kps1, 'green', lw=1, ps=0) + plt.gcf().canvas.manager.set_window_title('Point Matches') + plt.savefig('point_matches.png') + if not args.skip_imshow: + plt.show() + + +if __name__ == '__main__': + main() diff --git 
a/imcui/third_party/LightGlue/.github/workflows/code-quality.yml b/imcui/third_party/LightGlue/.github/workflows/code-quality.yml new file mode 100644 index 0000000000000000000000000000000000000000..368b225f17a52121ddb6626ca7d9699c6538fcb3 --- /dev/null +++ b/imcui/third_party/LightGlue/.github/workflows/code-quality.yml @@ -0,0 +1,24 @@ +name: Format and Lint Checks +on: + push: + branches: + - main + paths: + - '*.py' + pull_request: + types: [ assigned, opened, synchronize, reopened ] +jobs: + check: + name: Format and Lint Checks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + cache: 'pip' + - run: python -m pip install --upgrade pip + - run: python -m pip install .[dev] + - run: python -m flake8 . + - run: python -m isort . --check-only --diff + - run: python -m black . --check --diff diff --git a/imcui/third_party/LightGlue/benchmark.py b/imcui/third_party/LightGlue/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..b160f3a37bf64d2a42884ea29f165fb3f325b9cf --- /dev/null +++ b/imcui/third_party/LightGlue/benchmark.py @@ -0,0 +1,255 @@ +# Benchmark script for LightGlue on real images +import argparse +import time +from collections import defaultdict +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch._dynamo + +from lightglue import LightGlue, SuperPoint +from lightglue.utils import load_image + +torch.set_grad_enabled(False) + + +def measure(matcher, data, device="cuda", r=100): + timings = np.zeros((r, 1)) + if device.type == "cuda": + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + # warmup + for _ in range(10): + _ = matcher(data) + # measurements + with torch.no_grad(): + for rep in range(r): + if device.type == "cuda": + starter.record() + _ = matcher(data) + ender.record() + # sync gpu + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + else: + start = time.perf_counter() + _ = matcher(data) + curr_time = (time.perf_counter() - start) * 1e3 + timings[rep] = curr_time + mean_syn = np.sum(timings) / r + std_syn = np.std(timings) + return {"mean": mean_syn, "std": std_syn} + + +def print_as_table(d, title, cnames): + print() + header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames]) + print(header) + print("-" * len(header)) + for k, l in d.items(): + print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l])) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark script for LightGlue") + parser.add_argument( + "--device", + choices=["auto", "cuda", "cpu", "mps"], + default="auto", + help="device to benchmark on", + ) + parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs") + parser.add_argument( + "--no_flash", action="store_true", help="disable FlashAttention" + ) + parser.add_argument( + "--no_prune_thresholds", + action="store_true", + help="disable pruning thresholds (i.e. 
always do pruning)", + ) + parser.add_argument( + "--add_superglue", + action="store_true", + help="add SuperGlue to the benchmark (requires hloc)", + ) + parser.add_argument( + "--measure", default="time", choices=["time", "log-time", "throughput"] + ) + parser.add_argument( + "--repeat", "--r", type=int, default=100, help="repetitions of measurements" + ) + parser.add_argument( + "--num_keypoints", + nargs="+", + type=int, + default=[256, 512, 1024, 2048, 4096], + help="number of keypoints (list separated by spaces)", + ) + parser.add_argument( + "--matmul_precision", default="highest", choices=["highest", "high", "medium"] + ) + parser.add_argument( + "--save", default=None, type=str, help="path where figure should be saved" + ) + args = parser.parse_intermixed_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if args.device != "auto": + device = torch.device(args.device) + + print("Running benchmark on device:", device) + + images = Path("assets") + inputs = { + "easy": ( + load_image(images / "DSC_0411.JPG"), + load_image(images / "DSC_0410.JPG"), + ), + "difficult": ( + load_image(images / "sacre_coeur1.jpg"), + load_image(images / "sacre_coeur2.jpg"), + ), + } + + configs = { + "LightGlue-full": { + "depth_confidence": -1, + "width_confidence": -1, + }, + # 'LG-prune': { + # 'width_confidence': -1, + # }, + # 'LG-depth': { + # 'depth_confidence': -1, + # }, + "LightGlue-adaptive": {}, + } + + if args.compile: + configs = {**configs, **{k + "-compile": v for k, v in configs.items()}} + + sg_configs = { + # 'SuperGlue': {}, + "SuperGlue-fast": {"sinkhorn_iterations": 5} + } + + torch.set_float32_matmul_precision(args.matmul_precision) + + results = {k: defaultdict(list) for k, v in inputs.items()} + + extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1) + extractor = extractor.eval().to(device) + figsize = (len(inputs) * 4.5, 4.5) + fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize) + axes = axes if len(inputs) > 1 else [axes] + fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})") + + for title, ax in zip(inputs.keys(), axes): + ax.set_xscale("log", base=2) + bases = [2**x for x in range(7, 16)] + ax.set_xticks(bases, bases) + ax.grid(which="major") + if args.measure == "log-time": + ax.set_yscale("log") + yticks = [10**x for x in range(6)] + ax.set_yticks(yticks, yticks) + mpos = [10**x * i for x in range(6) for i in range(2, 10)] + mlabel = [ + 10**x * i if i in [2, 5] else None + for x in range(6) + for i in range(2, 10) + ] + ax.set_yticks(mpos, mlabel, minor=True) + ax.grid(which="minor", linewidth=0.2) + ax.set_title(title) + + ax.set_xlabel("# keypoints") + if args.measure == "throughput": + ax.set_ylabel("Throughput [pairs/s]") + else: + ax.set_ylabel("Latency [ms]") + + for name, conf in configs.items(): + print("Run benchmark for:", name) + torch.cuda.empty_cache() + matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf) + if args.no_prune_thresholds: + matcher.pruning_keypoint_thresholds = { + k: -1 for k in matcher.pruning_keypoint_thresholds + } + matcher = matcher.eval().to(device) + if name.endswith("compile"): + import torch._dynamo + + torch._dynamo.reset() # avoid buffer overflow + matcher.compile() + for pair_name, ax in zip(inputs.keys(), axes): + image0, image1 = [x.to(device) for x in inputs[pair_name]] + runtimes = [] + for num_kpts in args.num_keypoints: + extractor.conf.max_num_keypoints = num_kpts + feats0 = extractor.extract(image0) + feats1 = 
extractor.extract(image1) + runtime = measure( + matcher, + {"image0": feats0, "image1": feats1}, + device=device, + r=args.repeat, + )["mean"] + results[pair_name][name].append( + 1000 / runtime if args.measure == "throughput" else runtime + ) + ax.plot( + args.num_keypoints, results[pair_name][name], label=name, marker="o" + ) + del matcher, feats0, feats1 + + if args.add_superglue: + from hloc.matchers.superglue import SuperGlue + + for name, conf in sg_configs.items(): + print("Run benchmark for:", name) + matcher = SuperGlue(conf) + matcher = matcher.eval().to(device) + for pair_name, ax in zip(inputs.keys(), axes): + image0, image1 = [x.to(device) for x in inputs[pair_name]] + runtimes = [] + for num_kpts in args.num_keypoints: + extractor.conf.max_num_keypoints = num_kpts + feats0 = extractor.extract(image0) + feats1 = extractor.extract(image1) + data = { + "image0": image0[None], + "image1": image1[None], + **{k + "0": v for k, v in feats0.items()}, + **{k + "1": v for k, v in feats1.items()}, + } + data["scores0"] = data["keypoint_scores0"] + data["scores1"] = data["keypoint_scores1"] + data["descriptors0"] = ( + data["descriptors0"].transpose(-1, -2).contiguous() + ) + data["descriptors1"] = ( + data["descriptors1"].transpose(-1, -2).contiguous() + ) + runtime = measure(matcher, data, device=device, r=args.repeat)[ + "mean" + ] + results[pair_name][name].append( + 1000 / runtime if args.measure == "throughput" else runtime + ) + ax.plot( + args.num_keypoints, results[pair_name][name], label=name, marker="o" + ) + del matcher, data, image0, image1, feats0, feats1 + + for name, runtimes in results.items(): + print_as_table(runtimes, name, args.num_keypoints) + + axes[0].legend() + fig.tight_layout() + if args.save: + plt.savefig(args.save, dpi=fig.dpi) + plt.show() diff --git a/imcui/third_party/LightGlue/lightglue/__init__.py b/imcui/third_party/LightGlue/lightglue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b84d285cf2a29e3b17c8c2c052a45f856dcf89c0 --- /dev/null +++ b/imcui/third_party/LightGlue/lightglue/__init__.py @@ -0,0 +1,7 @@ +from .aliked import ALIKED # noqa +from .disk import DISK # noqa +from .dog_hardnet import DoGHardNet # noqa +from .lightglue import LightGlue # noqa +from .sift import SIFT # noqa +from .superpoint import SuperPoint # noqa +from .utils import match_pair # noqa diff --git a/imcui/third_party/LightGlue/lightglue/aliked.py b/imcui/third_party/LightGlue/lightglue/aliked.py new file mode 100644 index 0000000000000000000000000000000000000000..1161e1fc2d0cce32583031229e8ad4bb84f9a40c --- /dev/null +++ b/imcui/third_party/LightGlue/lightglue/aliked.py @@ -0,0 +1,758 @@ +# BSD 3-Clause License + +# Copyright (c) 2022, Zhao Xiaoming +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Authors: +# Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li +# Code from https://github.com/Shiaoming/ALIKED + +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +import torchvision +from kornia.color import grayscale_to_rgb +from torch import nn +from torch.nn.modules.utils import _pair +from torchvision.models import resnet + +from .utils import Extractor + + +def get_patches( + tensor: torch.Tensor, required_corners: torch.Tensor, ps: int +) -> torch.Tensor: + c, h, w = tensor.shape + corner = (required_corners - ps / 2 + 1).long() + corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps) + corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps) + offset = torch.arange(0, ps) + + kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {} + x, y = torch.meshgrid(offset, offset, **kw) + patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2) + patches = patches.to(corner) + corner[None, None] + pts = patches.reshape(-1, 2) + sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]] + sampled = sampled.reshape(ps, ps, -1, c) + assert sampled.shape[:3] == patches.shape[:3] + return sampled.permute(2, 3, 0, 1) + + +def simple_nms(scores: torch.Tensor, nms_radius: int): + """Fast Non-maximum suppression to remove nearby points""" + + zeros = torch.zeros_like(scores) + max_mask = scores == torch.nn.functional.max_pool2d( + scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) + + for _ in range(2): + supp_mask = ( + torch.nn.functional.max_pool2d( + max_mask.float(), + kernel_size=nms_radius * 2 + 1, + stride=1, + padding=nms_radius, + ) + > 0 + ) + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == torch.nn.functional.max_pool2d( + supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +class DKD(nn.Module): + def __init__( + self, + radius: int = 2, + top_k: int = 0, + scores_th: float = 0.2, + n_limit: int = 20000, + ): + """ + Args: + radius: soft detection radius, kernel size is (2 * radius + 1) + top_k: top_k > 0: return top k keypoints + scores_th: top_k <= 0 threshold mode: + scores_th > 0: return keypoints with scores>scores_th + else: return keypoints with scores > scores.mean() + n_limit: max number of keypoint in threshold mode + """ + super().__init__() + self.radius = radius + self.top_k = top_k + self.scores_th = scores_th + self.n_limit = n_limit + self.kernel_size = 2 * self.radius + 1 + self.temperature = 0.1 # tuned temperature + self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius) + # local xy grid + x = 
torch.linspace(-self.radius, self.radius, self.kernel_size) + # (kernel_size*kernel_size) x 2 : (w,h) + kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {} + self.hw_grid = ( + torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]] + ) + + def forward( + self, + scores_map: torch.Tensor, + sub_pixel: bool = True, + image_size: Optional[torch.Tensor] = None, + ): + """ + :param scores_map: Bx1xHxW + :param descriptor_map: BxCxHxW + :param sub_pixel: whether to use sub-pixel keypoint detection + :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1~1 + """ + b, c, h, w = scores_map.shape + scores_nograd = scores_map.detach() + nms_scores = simple_nms(scores_nograd, self.radius) + + # remove border + nms_scores[:, :, : self.radius, :] = 0 + nms_scores[:, :, :, : self.radius] = 0 + if image_size is not None: + for i in range(scores_map.shape[0]): + w, h = image_size[i].long() + nms_scores[i, :, h.item() - self.radius :, :] = 0 + nms_scores[i, :, :, w.item() - self.radius :] = 0 + else: + nms_scores[:, :, -self.radius :, :] = 0 + nms_scores[:, :, :, -self.radius :] = 0 + + # detect keypoints without grad + if self.top_k > 0: + topk = torch.topk(nms_scores.view(b, -1), self.top_k) + indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k + else: + if self.scores_th > 0: + masks = nms_scores > self.scores_th + if masks.sum() == 0: + th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th + masks = nms_scores > th.reshape(b, 1, 1, 1) + else: + th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th + masks = nms_scores > th.reshape(b, 1, 1, 1) + masks = masks.reshape(b, -1) + + indices_keypoints = [] # list, B x (any size) + scores_view = scores_nograd.reshape(b, -1) + for mask, scores in zip(masks, scores_view): + indices = mask.nonzero()[:, 0] + if len(indices) > self.n_limit: + kpts_sc = scores[indices] + sort_idx = kpts_sc.sort(descending=True)[1] + sel_idx = sort_idx[: self.n_limit] + indices = indices[sel_idx] + indices_keypoints.append(indices) + + wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device) + + keypoints = [] + scoredispersitys = [] + kptscores = [] + if sub_pixel: + # detect soft keypoints with grad backpropagation + patches = self.unfold(scores_map) # B x (kernel**2) x (H*W) + self.hw_grid = self.hw_grid.to(scores_map) # to device + for b_idx in range(b): + patch = patches[b_idx].t() # (H*W) x (kernel**2) + indices_kpt = indices_keypoints[ + b_idx + ] # one dimension vector, say its size is M + patch_scores = patch[indices_kpt] # M x (kernel**2) + keypoints_xy_nms = torch.stack( + [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")], + dim=1, + ) # Mx2 + + # max is detached to prevent undesired backprop loops in the graph + max_v = patch_scores.max(dim=1).values.detach()[:, None] + x_exp = ( + (patch_scores - max_v) / self.temperature + ).exp() # M * (kernel**2), in [0, 1] + + # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} } + xy_residual = ( + x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] + ) # Soft-argmax, Mx2 + + hw_grid_dist2 = ( + torch.norm( + (self.hw_grid[None, :, :] - xy_residual[:, None, :]) + / self.radius, + dim=-1, + ) + ** 2 + ) + scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1) + + # compute result keypoints + keypoints_xy = keypoints_xy_nms + xy_residual + keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1) + + kptscore = torch.nn.functional.grid_sample( + scores_map[b_idx].unsqueeze(0), + 
keypoints_xy.view(1, 1, -1, 2), + mode="bilinear", + align_corners=True, + )[ + 0, 0, 0, : + ] # CxN + + keypoints.append(keypoints_xy) + scoredispersitys.append(scoredispersity) + kptscores.append(kptscore) + else: + for b_idx in range(b): + indices_kpt = indices_keypoints[ + b_idx + ] # one dimension vector, say its size is M + # To avoid warning: UserWarning: __floordiv__ is deprecated + keypoints_xy_nms = torch.stack( + [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")], + dim=1, + ) # Mx2 + keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1) + kptscore = torch.nn.functional.grid_sample( + scores_map[b_idx].unsqueeze(0), + keypoints_xy.view(1, 1, -1, 2), + mode="bilinear", + align_corners=True, + )[ + 0, 0, 0, : + ] # CxN + keypoints.append(keypoints_xy) + scoredispersitys.append(kptscore) # for jit.script compatability + kptscores.append(kptscore) + + return keypoints, scoredispersitys, kptscores + + +class InputPadder(object): + """Pads images such that dimensions are divisible by 8""" + + def __init__(self, h: int, w: int, divis_by: int = 8): + self.ht = h + self.wd = w + pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by + pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + + def pad(self, x: torch.Tensor): + assert x.ndim == 4 + return F.pad(x, self._pad, mode="replicate") + + def unpad(self, x: torch.Tensor): + assert x.ndim == 4 + ht = x.shape[-2] + wd = x.shape[-1] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +class DeformableConv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + mask=False, + ): + super(DeformableConv2d, self).__init__() + + self.padding = padding + self.mask = mask + + self.channel_num = ( + 3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size + ) + self.offset_conv = nn.Conv2d( + in_channels, + self.channel_num, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + bias=True, + ) + + self.regular_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + bias=bias, + ) + + def forward(self, x): + h, w = x.shape[2:] + max_offset = max(h, w) / 4.0 + + out = self.offset_conv(x) + if self.mask: + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + else: + offset = out + mask = None + offset = offset.clamp(-max_offset, max_offset) + x = torchvision.ops.deform_conv2d( + input=x, + offset=offset, + weight=self.regular_conv.weight, + bias=self.regular_conv.bias, + padding=self.padding, + mask=mask, + ) + return x + + +def get_conv( + inplanes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False, + conv_type="conv", + mask=False, +): + if conv_type == "conv": + conv = nn.Conv2d( + inplanes, + planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + elif conv_type == "dcn": + conv = DeformableConv2d( + inplanes, + planes, + kernel_size=kernel_size, + stride=stride, + padding=_pair(padding), + bias=bias, + mask=mask, + ) + else: + raise TypeError + return conv + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: 
Optional[Callable[..., nn.Module]] = None, + conv_type: str = "conv", + mask: bool = False, + ): + super().__init__() + if gate is None: + self.gate = nn.ReLU(inplace=True) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = get_conv( + in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask + ) + self.bn1 = norm_layer(out_channels) + self.conv2 = get_conv( + out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask + ) + self.bn2 = norm_layer(out_channels) + + def forward(self, x): + x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W + x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W + return x + + +# modified based on torchvision\models\resnet.py#27->BasicBlock +class ResBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + gate: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + conv_type: str = "conv", + mask: bool = False, + ) -> None: + super(ResBlock, self).__init__() + if gate is None: + self.gate = nn.ReLU(inplace=True) + else: + self.gate = gate + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("ResBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in ResBlock") + # Both self.conv1 and self.downsample layers + # downsample the input when stride != 1 + self.conv1 = get_conv( + inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask + ) + self.bn1 = norm_layer(planes) + self.conv2 = get_conv( + planes, planes, kernel_size=3, conv_type=conv_type, mask=mask + ) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.gate(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.gate(out) + + return out + + +class SDDH(nn.Module): + def __init__( + self, + dims: int, + kernel_size: int = 3, + n_pos: int = 8, + gate=nn.ReLU(), + conv2D=False, + mask=False, + ): + super(SDDH, self).__init__() + self.kernel_size = kernel_size + self.n_pos = n_pos + self.conv2D = conv2D + self.mask = mask + + self.get_patches_func = get_patches + + # estimate offsets + self.channel_num = 3 * n_pos if mask else 2 * n_pos + self.offset_conv = nn.Sequential( + nn.Conv2d( + dims, + self.channel_num, + kernel_size=kernel_size, + stride=1, + padding=0, + bias=True, + ), + gate, + nn.Conv2d( + self.channel_num, + self.channel_num, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ), + ) + + # sampled feature conv + self.sf_conv = nn.Conv2d( + dims, dims, kernel_size=1, stride=1, padding=0, bias=False + ) + + # convM + if not conv2D: + # deformable desc weights + agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims)) + self.register_parameter("agg_weights", agg_weights) + else: + self.convM = nn.Conv2d( + dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False + ) + + def forward(self, x, keypoints): + # x: [B,C,H,W] + # keypoints: list, [[N_kpts,2], ...] 
(w,h) + b, c, h, w = x.shape + wh = torch.tensor([[w - 1, h - 1]], device=x.device) + max_offset = max(h, w) / 4.0 + + offsets = [] + descriptors = [] + # get offsets for each keypoint + for ib in range(b): + xi, kptsi = x[ib], keypoints[ib] + kptsi_wh = (kptsi / 2 + 0.5) * wh + N_kpts = len(kptsi) + + if self.kernel_size > 1: + patch = self.get_patches_func( + xi, kptsi_wh.long(), self.kernel_size + ) # [N_kpts, C, K, K] + else: + kptsi_wh_long = kptsi_wh.long() + patch = ( + xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]] + .permute(1, 0) + .reshape(N_kpts, c, 1, 1) + ) + + offset = self.offset_conv(patch).clamp( + -max_offset, max_offset + ) # [N_kpts, 2*n_pos, 1, 1] + if self.mask: + offset = ( + offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1) + ) # [N_kpts, n_pos, 3] + offset = offset[:, :, :-1] # [N_kpts, n_pos, 2] + mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos] + else: + offset = ( + offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1) + ) # [N_kpts, n_pos, 2] + offsets.append(offset) # for visualization + + # get sample positions + pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2] + pos = 2.0 * pos / wh[None] - 1 + pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2) + + # sample features + features = F.grid_sample( + xi.unsqueeze(0), pos, mode="bilinear", align_corners=True + ) # [1,C,(N_kpts*n_pos),1] + features = features.reshape(c, N_kpts, self.n_pos, 1).permute( + 1, 0, 2, 3 + ) # [N_kpts, C, n_pos, 1] + if self.mask: + features = torch.einsum("ncpo,np->ncpo", features, mask_weight) + + features = torch.selu_(self.sf_conv(features)).squeeze( + -1 + ) # [N_kpts, C, n_pos] + # convM + if not self.conv2D: + descs = torch.einsum( + "ncp,pcd->nd", features, self.agg_weights + ) # [N_kpts, C] + else: + features = features.reshape(N_kpts, -1)[ + :, :, None, None + ] # [N_kpts, C*n_pos, 1, 1] + descs = self.convM(features).squeeze() # [N_kpts, C] + + # normalize + descs = F.normalize(descs, p=2.0, dim=1) + descriptors.append(descs) + + return descriptors, offsets + + +class ALIKED(Extractor): + default_conf = { + "model_name": "aliked-n16", + "max_num_keypoints": -1, + "detection_threshold": 0.2, + "nms_radius": 2, + } + + checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth" + + n_limit_max = 20000 + + # c1, c2, c3, c4, dim, K, M + cfgs = { + "aliked-t16": [8, 16, 32, 64, 64, 3, 16], + "aliked-n16": [16, 32, 64, 128, 128, 3, 16], + "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16], + "aliked-n32": [16, 32, 64, 128, 128, 3, 32], + } + preprocess_conf = { + "resize": 1024, + } + + required_data_keys = ["image"] + + def __init__(self, **conf): + super().__init__(**conf) # Update with default configuration. 
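+        # Unpack the per-variant architecture (c1-c4 encoder widths, descriptor dim, SDDH kernel size K, n_pos M) and build the encoder, score head and descriptor head below.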
+ conf = self.conf + c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name] + conv_types = ["conv", "conv", "dcn", "dcn"] + conv2D = False + mask = False + + # build model + self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) + self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4) + self.norm = nn.BatchNorm2d + self.gate = nn.SELU(inplace=True) + self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0]) + self.block2 = self.get_resblock(c1, c2, conv_types[1], mask) + self.block3 = self.get_resblock(c2, c3, conv_types[2], mask) + self.block4 = self.get_resblock(c3, c4, conv_types[3], mask) + + self.conv1 = resnet.conv1x1(c1, dim // 4) + self.conv2 = resnet.conv1x1(c2, dim // 4) + self.conv3 = resnet.conv1x1(c3, dim // 4) + self.conv4 = resnet.conv1x1(dim, dim // 4) + self.upsample2 = nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=True + ) + self.upsample4 = nn.Upsample( + scale_factor=4, mode="bilinear", align_corners=True + ) + self.upsample8 = nn.Upsample( + scale_factor=8, mode="bilinear", align_corners=True + ) + self.upsample32 = nn.Upsample( + scale_factor=32, mode="bilinear", align_corners=True + ) + self.score_head = nn.Sequential( + resnet.conv1x1(dim, 8), + self.gate, + resnet.conv3x3(8, 4), + self.gate, + resnet.conv3x3(4, 4), + self.gate, + resnet.conv3x3(4, 1), + ) + self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask) + self.dkd = DKD( + radius=conf.nms_radius, + top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints, + scores_th=conf.detection_threshold, + n_limit=conf.max_num_keypoints + if conf.max_num_keypoints > 0 + else self.n_limit_max, + ) + + state_dict = torch.hub.load_state_dict_from_url( + self.checkpoint_url.format(conf.model_name), map_location="cpu" + ) + self.load_state_dict(state_dict, strict=True) + + def get_resblock(self, c_in, c_out, conv_type, mask): + return ResBlock( + c_in, + c_out, + 1, + nn.Conv2d(c_in, c_out, 1), + gate=self.gate, + norm_layer=self.norm, + conv_type=conv_type, + mask=mask, + ) + + def extract_dense_map(self, image): + # Pads images such that dimensions are divisible by + div_by = 2**5 + padder = InputPadder(image.shape[-2], image.shape[-1], div_by) + image = padder.pad(image) + + # ================================== feature encoder + x1 = self.block1(image) # B x c1 x H x W + x2 = self.pool2(x1) + x2 = self.block2(x2) # B x c2 x H/2 x W/2 + x3 = self.pool4(x2) + x3 = self.block3(x3) # B x c3 x H/8 x W/8 + x4 = self.pool4(x3) + x4 = self.block4(x4) # B x dim x H/32 x W/32 + # ================================== feature aggregation + x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W + x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2 + x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8 + x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32 + x2_up = self.upsample2(x2) # B x dim//4 x H x W + x3_up = self.upsample8(x3) # B x dim//4 x H x W + x4_up = self.upsample32(x4) # B x dim//4 x H x W + x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1) + # ================================== score head + score_map = torch.sigmoid(self.score_head(x1234)) + feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1) + + # Unpads images + feature_map = padder.unpad(feature_map) + score_map = padder.unpad(score_map) + + return feature_map, score_map + + def forward(self, data: dict) -> dict: + image = data["image"] + if image.shape[1] == 1: + image = grayscale_to_rgb(image) + feature_map, score_map = self.extract_dense_map(image) + keypoints, kptscores, 
scoredispersitys = self.dkd( + score_map, image_size=data.get("image_size") + ) + descriptors, offsets = self.desc_head(feature_map, keypoints) + + _, _, h, w = image.shape + wh = torch.tensor([w - 1, h - 1], device=image.device) + # no padding required + # we can set detection_threshold=-1 and conf.max_num_keypoints > 0 + return { + "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2 + "descriptors": torch.stack(descriptors), # B x N x D + "keypoint_scores": torch.stack(kptscores), # B x N + } diff --git a/imcui/third_party/LightGlue/lightglue/disk.py b/imcui/third_party/LightGlue/lightglue/disk.py new file mode 100644 index 0000000000000000000000000000000000000000..8cb2195fe2f95c32959b5be4b09ad91bb51a35d5 --- /dev/null +++ b/imcui/third_party/LightGlue/lightglue/disk.py @@ -0,0 +1,55 @@ +import kornia +import torch + +from .utils import Extractor + + +class DISK(Extractor): + default_conf = { + "weights": "depth", + "max_num_keypoints": None, + "desc_dim": 128, + "nms_window_size": 5, + "detection_threshold": 0.0, + "pad_if_not_divisible": True, + } + + preprocess_conf = { + "resize": 1024, + "grayscale": False, + } + + required_data_keys = ["image"] + + def __init__(self, **conf) -> None: + super().__init__(**conf) # Update with default configuration. + self.model = kornia.feature.DISK.from_pretrained(self.conf.weights) + + def forward(self, data: dict) -> dict: + """Compute keypoints, scores, descriptors for image""" + for key in self.required_data_keys: + assert key in data, f"Missing key {key} in data" + image = data["image"] + if image.shape[1] == 1: + image = kornia.color.grayscale_to_rgb(image) + features = self.model( + image, + n=self.conf.max_num_keypoints, + window_size=self.conf.nms_window_size, + score_threshold=self.conf.detection_threshold, + pad_if_not_divisible=self.conf.pad_if_not_divisible, + ) + keypoints = [f.keypoints for f in features] + scores = [f.detection_scores for f in features] + descriptors = [f.descriptors for f in features] + del features + + keypoints = torch.stack(keypoints, 0) + scores = torch.stack(scores, 0) + descriptors = torch.stack(descriptors, 0) + + return { + "keypoints": keypoints.to(image).contiguous(), + "keypoint_scores": scores.to(image).contiguous(), + "descriptors": descriptors.to(image).contiguous(), + } diff --git a/imcui/third_party/LightGlue/lightglue/dog_hardnet.py b/imcui/third_party/LightGlue/lightglue/dog_hardnet.py new file mode 100644 index 0000000000000000000000000000000000000000..cce307ae1f11e2066312fd44ecac8884d1de3358 --- /dev/null +++ b/imcui/third_party/LightGlue/lightglue/dog_hardnet.py @@ -0,0 +1,41 @@ +import torch +from kornia.color import rgb_to_grayscale +from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori + +from .sift import SIFT + + +class DoGHardNet(SIFT): + required_data_keys = ["image"] + + def __init__(self, **conf): + super().__init__(**conf) + self.laf_desc = LAFDescriptor(HardNet(True)).eval() + + def forward(self, data: dict) -> dict: + image = data["image"] + if image.shape[1] == 3: + image = rgb_to_grayscale(image) + device = image.device + self.laf_desc = self.laf_desc.to(device) + self.laf_desc.descriptor = self.laf_desc.descriptor.eval() + pred = [] + if "image_size" in data.keys(): + im_size = data.get("image_size").long() + else: + im_size = None + for k in range(len(image)): + img = image[k] + if im_size is not None: + w, h = data["image_size"][k] + img = img[:, : h.to(torch.int32), : w.to(torch.int32)] + p = self.extract_single_image(img) + lafs = 
laf_from_center_scale_ori( + p["keypoints"].reshape(1, -1, 2), + 6.0 * p["scales"].reshape(1, -1, 1, 1), + torch.rad2deg(p["oris"]).reshape(1, -1, 1), + ).to(device) + p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128) + pred.append(p) + pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]} + return pred diff --git a/imcui/third_party/LightGlue/lightglue/lightglue.py b/imcui/third_party/LightGlue/lightglue/lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..14e6a61e25764e1a4258e94363c30211bcbc7c44 --- /dev/null +++ b/imcui/third_party/LightGlue/lightglue/lightglue.py @@ -0,0 +1,655 @@ +import warnings +from pathlib import Path +from types import SimpleNamespace +from typing import Callable, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +try: + from flash_attn.modules.mha import FlashCrossAttention +except ModuleNotFoundError: + FlashCrossAttention = None + +if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"): + FLASH_AVAILABLE = True +else: + FLASH_AVAILABLE = False + +torch.backends.cudnn.deterministic = True + + +@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) +def normalize_keypoints( + kpts: torch.Tensor, size: Optional[torch.Tensor] = None +) -> torch.Tensor: + if size is None: + size = 1 + kpts.max(-2).values - kpts.min(-2).values + elif not isinstance(size, torch.Tensor): + size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype) + size = size.to(kpts) + shift = size / 2 + scale = size.max(-1).values / 2 + kpts = (kpts - shift[..., None, :]) / scale[..., None, None] + return kpts + + +def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]: + if length <= x.shape[-2]: + return x, torch.ones_like(x[..., :1], dtype=torch.bool) + pad = torch.ones( + *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype + ) + y = torch.cat([x, pad], dim=-2) + mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device) + mask[..., : x.shape[-2], :] = True + return y, mask + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """encode position vector""" + projected = self.Wr(x) + cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class TokenConfidence(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid()) + + def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): + """get confidence tokens""" + return ( + self.token(desc0.detach()).squeeze(-1), + self.token(desc1.detach()).squeeze(-1), + ) + + +class Attention(nn.Module): + def __init__(self, allow_flash: bool) -> None: + super().__init__() + if allow_flash and not FLASH_AVAILABLE: + 
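+            # FlashAttention was requested but is unavailable: warn and fall back to PyTorch SDPA or the plain einsum attention path in forward().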
warnings.warn( + "FlashAttention is not available. For optimal speed, " + "consider installing torch >= 2.0 or flash-attn.", + stacklevel=2, + ) + self.enable_flash = allow_flash and FLASH_AVAILABLE + self.has_sdp = hasattr(F, "scaled_dot_product_attention") + if allow_flash and FlashCrossAttention: + self.flash_ = FlashCrossAttention() + if self.has_sdp: + torch.backends.cuda.enable_flash_sdp(allow_flash) + + def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if q.shape[-2] == 0 or k.shape[-2] == 0: + return q.new_zeros((*q.shape[:-1], v.shape[-1])) + if self.enable_flash and q.device.type == "cuda": + # use torch 2.0 scaled_dot_product_attention with flash + if self.has_sdp: + args = [x.half().contiguous() for x in [q, k, v]] + v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype) + return v if mask is None else v.nan_to_num() + else: + assert mask is None + q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]] + m = self.flash_(q.half(), torch.stack([k, v], 2).half()) + return m.transpose(-2, -3).to(q.dtype).clone() + elif self.has_sdp: + args = [x.contiguous() for x in [q, k, v]] + v = F.scaled_dot_product_attention(*args, attn_mask=mask) + return v if mask is None else v.nan_to_num() + else: + s = q.shape[-1] ** -0.5 + sim = torch.einsum("...id,...jd->...ij", q, k) * s + if mask is not None: + sim.masked_fill(~mask, -float("inf")) + attn = F.softmax(sim, -1) + return torch.einsum("...ij,...jd->...id", attn, v) + + +class SelfBlock(nn.Module): + def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0 + self.head_dim = self.embed_dim // num_heads + self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) + self.inner_attn = Attention(flash) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.ffn = nn.Sequential( + nn.Linear(2 * embed_dim, 2 * embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(2 * embed_dim, embed_dim), + ) + + def forward( + self, + x: torch.Tensor, + encoding: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv = self.Wqkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + context = self.inner_attn(q, k, v, mask=mask) + message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2)) + return x + self.ffn(torch.cat([x, message], -1)) + + +class CrossBlock(nn.Module): + def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: + super().__init__() + self.heads = num_heads + dim_head = embed_dim // num_heads + self.scale = dim_head**-0.5 + inner_dim = dim_head * num_heads + self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias) + self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias) + self.ffn = nn.Sequential( + nn.Linear(2 * embed_dim, 2 * embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(2 * embed_dim, embed_dim), + ) + if flash and FLASH_AVAILABLE: + self.flash = Attention(True) + else: + self.flash = None + + def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): + return func(x0), func(x1) + + def forward( + self, x0: torch.Tensor, 
x1: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> List[torch.Tensor]: + qk0, qk1 = self.map_(self.to_qk, x0, x1) + v0, v1 = self.map_(self.to_v, x0, x1) + qk0, qk1, v0, v1 = map( + lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2), + (qk0, qk1, v0, v1), + ) + if self.flash is not None and qk0.device.type == "cuda": + m0 = self.flash(qk0, qk1, v1, mask) + m1 = self.flash( + qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None + ) + else: + qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5 + sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1) + if mask is not None: + sim = sim.masked_fill(~mask, -float("inf")) + attn01 = F.softmax(sim, dim=-1) + attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) + m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1) + m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0) + if mask is not None: + m0, m1 = m0.nan_to_num(), m1.nan_to_num() + m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1) + m0, m1 = self.map_(self.to_out, m0, m1) + x0 = x0 + self.ffn(torch.cat([x0, m0], -1)) + x1 = x1 + self.ffn(torch.cat([x1, m1], -1)) + return x0, x1 + + +class TransformerLayer(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.self_attn = SelfBlock(*args, **kwargs) + self.cross_attn = CrossBlock(*args, **kwargs) + + def forward( + self, + desc0, + desc1, + encoding0, + encoding1, + mask0: Optional[torch.Tensor] = None, + mask1: Optional[torch.Tensor] = None, + ): + if mask0 is not None and mask1 is not None: + return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1) + else: + desc0 = self.self_attn(desc0, encoding0) + desc1 = self.self_attn(desc1, encoding1) + return self.cross_attn(desc0, desc1) + + # This part is compiled and allows padding inputs + def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1): + mask = mask0 & mask1.transpose(-1, -2) + mask0 = mask0 & mask0.transpose(-1, -2) + mask1 = mask1 & mask1.transpose(-1, -2) + desc0 = self.self_attn(desc0, encoding0, mask0) + desc1 = self.self_attn(desc1, encoding1, mask1) + return self.cross_attn(desc0, desc1, mask) + + +def sigmoid_log_double_softmax( + sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + b, m, n = sim.shape + certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2) + scores0 = F.log_softmax(sim, 2) + scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = sim.new_full((b, m + 1, n + 1), 0) + scores[:, :m, :n] = scores0 + scores1 + certainties + scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1)) + scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1)) + return scores + + +class MatchAssignment(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + self.matchability = nn.Linear(dim, 1, bias=True) + self.final_proj = nn.Linear(dim, dim, bias=True) + + def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): + """build assignment matrix from descriptors""" + mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) + _, _, d = mdesc0.shape + mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25 + sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1) + z0 = self.matchability(desc0) + z1 = self.matchability(desc1) + scores = sigmoid_log_double_softmax(sim, z0, z1) + return scores, sim + + def get_matchability(self, desc: torch.Tensor): + return 
torch.sigmoid(self.matchability(desc)).squeeze(-1) + + +def filter_matches(scores: torch.Tensor, th: float): + """obtain matches from a log assignment matrix [Bx M+1 x N+1]""" + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + m0, m1 = max0.indices, max1.indices + indices0 = torch.arange(m0.shape[1], device=m0.device)[None] + indices1 = torch.arange(m1.shape[1], device=m1.device)[None] + mutual0 = indices0 == m1.gather(1, m0) + mutual1 = indices1 == m0.gather(1, m1) + max0_exp = max0.values.exp() + zero = max0_exp.new_tensor(0) + mscores0 = torch.where(mutual0, max0_exp, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero) + valid0 = mutual0 & (mscores0 > th) + valid1 = mutual1 & valid0.gather(1, m1) + m0 = torch.where(valid0, m0, -1) + m1 = torch.where(valid1, m1, -1) + return m0, m1, mscores0, mscores1 + + +class LightGlue(nn.Module): + default_conf = { + "name": "lightglue", # just for interfacing + "input_dim": 256, # input descriptor dimension (autoselected from weights) + "descriptor_dim": 256, + "add_scale_ori": False, + "n_layers": 9, + "num_heads": 4, + "flash": True, # enable FlashAttention if available. + "mp": False, # enable mixed precision + "depth_confidence": 0.95, # early stopping, disable with -1 + "width_confidence": 0.99, # point pruning, disable with -1 + "filter_threshold": 0.1, # match threshold + "weights": None, + } + + # Point pruning involves an overhead (gather). + # Therefore, we only activate it if there are enough keypoints. + pruning_keypoint_thresholds = { + "cpu": -1, + "mps": -1, + "cuda": 1024, + "flash": 1536, + } + + required_data_keys = ["image0", "image1"] + + version = "v0.1_arxiv" + url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth" + + features = { + "superpoint": { + "weights": "superpoint_lightglue", + "input_dim": 256, + }, + "disk": { + "weights": "disk_lightglue", + "input_dim": 128, + }, + "aliked": { + "weights": "aliked_lightglue", + "input_dim": 128, + }, + "sift": { + "weights": "sift_lightglue", + "input_dim": 128, + "add_scale_ori": True, + }, + "doghardnet": { + "weights": "doghardnet_lightglue", + "input_dim": 128, + "add_scale_ori": True, + }, + } + + def __init__(self, features="superpoint", **conf) -> None: + super().__init__() + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + if features is not None: + if features not in self.features: + raise ValueError( + f"Unsupported features: {features} not in " + f"{{{','.join(self.features)}}}" + ) + for k, v in self.features[features].items(): + setattr(conf, k, v) + + if conf.input_dim != conf.descriptor_dim: + self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True) + else: + self.input_proj = nn.Identity() + + head_dim = conf.descriptor_dim // conf.num_heads + self.posenc = LearnableFourierPositionalEncoding( + 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim + ) + + h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim + + self.transformers = nn.ModuleList( + [TransformerLayer(d, h, conf.flash) for _ in range(n)] + ) + + self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)]) + self.token_confidence = nn.ModuleList( + [TokenConfidence(d) for _ in range(n - 1)] + ) + self.register_buffer( + "confidence_thresholds", + torch.Tensor( + [self.confidence_threshold(i) for i in range(self.conf.n_layers)] + ), + ) + + state_dict = None + if features is not None: + fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth" + state_dict = 
torch.hub.load_state_dict_from_url( + self.url.format(self.version, features), file_name=fname + ) + self.load_state_dict(state_dict, strict=False) + elif conf.weights is not None: + path = Path(__file__).parent + path = path / "weights/{}.pth".format(self.conf.weights) + state_dict = torch.load(str(path), map_location="cpu") + + if state_dict: + # rename old state dict entries + for i in range(self.conf.n_layers): + pattern = f"self_attn.{i}", f"transformers.{i}.self_attn" + state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn" + state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + self.load_state_dict(state_dict, strict=False) + + # static lengths LightGlue is compiled for (only used with torch.compile) + self.static_lengths = None + + def compile( + self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536] + ): + if self.conf.width_confidence != -1: + warnings.warn( + "Point pruning is partially disabled for compiled forward.", + stacklevel=2, + ) + + torch._inductor.cudagraph_mark_step_begin() + for i in range(self.conf.n_layers): + self.transformers[i].masked_forward = torch.compile( + self.transformers[i].masked_forward, mode=mode, fullgraph=True + ) + + self.static_lengths = static_lengths + + def forward(self, data: dict) -> dict: + """ + Match keypoints and descriptors between two images + + Input (dict): + image0: dict + keypoints: [B x M x 2] + descriptors: [B x M x D] + image: [B x C x H x W] or image_size: [B x 2] + image1: dict + keypoints: [B x N x 2] + descriptors: [B x N x D] + image: [B x C x H x W] or image_size: [B x 2] + Output (dict): + matches0: [B x M] + matching_scores0: [B x M] + matches1: [B x N] + matching_scores1: [B x N] + matches: List[[Si x 2]] + scores: List[[Si]] + stop: int + prune0: [B x M] + prune1: [B x N] + """ + with torch.autocast(enabled=self.conf.mp, device_type="cuda"): + return self._forward(data) + + def _forward(self, data: dict) -> dict: + for key in self.required_data_keys: + assert key in data, f"Missing key {key} in data" + data0, data1 = data["image0"], data["image1"] + kpts0, kpts1 = data0["keypoints"], data1["keypoints"] + b, m, _ = kpts0.shape + b, n, _ = kpts1.shape + device = kpts0.device + size0, size1 = data0.get("image_size"), data1.get("image_size") + kpts0 = normalize_keypoints(kpts0, size0).clone() + kpts1 = normalize_keypoints(kpts1, size1).clone() + + if self.conf.add_scale_ori: + kpts0 = torch.cat( + [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1 + ) + kpts1 = torch.cat( + [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1 + ) + desc0 = data0["descriptors"].detach().contiguous() + desc1 = data1["descriptors"].detach().contiguous() + + assert desc0.shape[-1] == self.conf.input_dim + assert desc1.shape[-1] == self.conf.input_dim + + if torch.is_autocast_enabled(): + desc0 = desc0.half() + desc1 = desc1.half() + + mask0, mask1 = None, None + c = max(m, n) + do_compile = self.static_lengths and c <= max(self.static_lengths) + if do_compile: + kn = min([k for k in self.static_lengths if k >= c]) + desc0, mask0 = pad_to_length(desc0, kn) + desc1, mask1 = pad_to_length(desc1, kn) + kpts0, _ = pad_to_length(kpts0, kn) + kpts1, _ = pad_to_length(kpts1, kn) + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + # cache positional embeddings + encoding0 = self.posenc(kpts0) + encoding1 = self.posenc(kpts1) + + # GNN + final_proj + assignment + do_early_stop = 
self.conf.depth_confidence > 0 + do_point_pruning = self.conf.width_confidence > 0 and not do_compile + pruning_th = self.pruning_min_kpts(device) + if do_point_pruning: + ind0 = torch.arange(0, m, device=device)[None] + ind1 = torch.arange(0, n, device=device)[None] + # We store the index of the layer at which pruning is detected. + prune0 = torch.ones_like(ind0) + prune1 = torch.ones_like(ind1) + token0, token1 = None, None + for i in range(self.conf.n_layers): + if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints + break + desc0, desc1 = self.transformers[i]( + desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1 + ) + if i == self.conf.n_layers - 1: + continue # no early stopping or adaptive width at last layer + + if do_early_stop: + token0, token1 = self.token_confidence[i](desc0, desc1) + if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n): + break + if do_point_pruning and desc0.shape[-2] > pruning_th: + scores0 = self.log_assignment[i].get_matchability(desc0) + prunemask0 = self.get_pruning_mask(token0, scores0, i) + keep0 = torch.where(prunemask0)[1] + ind0 = ind0.index_select(1, keep0) + desc0 = desc0.index_select(1, keep0) + encoding0 = encoding0.index_select(-2, keep0) + prune0[:, ind0] += 1 + if do_point_pruning and desc1.shape[-2] > pruning_th: + scores1 = self.log_assignment[i].get_matchability(desc1) + prunemask1 = self.get_pruning_mask(token1, scores1, i) + keep1 = torch.where(prunemask1)[1] + ind1 = ind1.index_select(1, keep1) + desc1 = desc1.index_select(1, keep1) + encoding1 = encoding1.index_select(-2, keep1) + prune1[:, ind1] += 1 + + if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints + m0 = desc0.new_full((b, m), -1, dtype=torch.long) + m1 = desc1.new_full((b, n), -1, dtype=torch.long) + mscores0 = desc0.new_zeros((b, m)) + mscores1 = desc1.new_zeros((b, n)) + matches = desc0.new_empty((b, 0, 2), dtype=torch.long) + mscores = desc0.new_empty((b, 0)) + if not do_point_pruning: + prune0 = torch.ones_like(mscores0) * self.conf.n_layers + prune1 = torch.ones_like(mscores1) * self.conf.n_layers + return { + "matches0": m0, + "matches1": m1, + "matching_scores0": mscores0, + "matching_scores1": mscores1, + "stop": i + 1, + "matches": matches, + "scores": mscores, + "prune0": prune0, + "prune1": prune1, + } + + desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding + scores, _ = self.log_assignment[i](desc0, desc1) + m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold) + matches, mscores = [], [] + for k in range(b): + valid = m0[k] > -1 + m_indices_0 = torch.where(valid)[0] + m_indices_1 = m0[k][valid] + if do_point_pruning: + m_indices_0 = ind0[k, m_indices_0] + m_indices_1 = ind1[k, m_indices_1] + matches.append(torch.stack([m_indices_0, m_indices_1], -1)) + mscores.append(mscores0[k][valid]) + + # TODO: Remove when hloc switches to the compact format. 
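+        # Scatter the pruned-subset matches and scores back into dense B x M / B x N tensors so matches0/1 and matching_scores0/1 index the original keypoints.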
+ if do_point_pruning: + m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype) + m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype) + m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0))) + m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0))) + mscores0_ = torch.zeros((b, m), device=mscores0.device) + mscores1_ = torch.zeros((b, n), device=mscores1.device) + mscores0_[:, ind0] = mscores0 + mscores1_[:, ind1] = mscores1 + m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_ + else: + prune0 = torch.ones_like(mscores0) * self.conf.n_layers + prune1 = torch.ones_like(mscores1) * self.conf.n_layers + + return { + "matches0": m0, + "matches1": m1, + "matching_scores0": mscores0, + "matching_scores1": mscores1, + "stop": i + 1, + "matches": matches, + "scores": mscores, + "prune0": prune0, + "prune1": prune1, + } + + def confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers) + return np.clip(threshold, 0, 1) + + def get_pruning_mask( + self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int + ) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.conf.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. + keep |= confidences <= self.confidence_thresholds[layer_index] + return keep + + def check_if_stop( + self, + confidences0: torch.Tensor, + confidences1: torch.Tensor, + layer_index: int, + num_points: int, + ) -> torch.Tensor: + """evaluate stopping condition""" + confidences = torch.cat([confidences0, confidences1], -1) + threshold = self.confidence_thresholds[layer_index] + ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points + return ratio_confident > self.conf.depth_confidence + + def pruning_min_kpts(self, device: torch.device): + if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda": + return self.pruning_keypoint_thresholds["flash"] + else: + return self.pruning_keypoint_thresholds[device.type] diff --git a/imcui/third_party/LightGlue/lightglue/sift.py b/imcui/third_party/LightGlue/lightglue/sift.py new file mode 100644 index 0000000000000000000000000000000000000000..802fc1c2eb9ee852691e0e4dd67455f822f8405f --- /dev/null +++ b/imcui/third_party/LightGlue/lightglue/sift.py @@ -0,0 +1,216 @@ +import warnings + +import cv2 +import numpy as np +import torch +from kornia.color import rgb_to_grayscale +from packaging import version + +try: + import pycolmap +except ImportError: + pycolmap = None + +from .utils import Extractor + + +def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None): + h, w = image_shape + ij = np.round(points - 0.5).astype(int).T[::-1] + + # Remove duplicate points (identical coordinates). + # Pick highest scale or score + s = scales if scores is None else scores + buffer = np.zeros((h, w)) + np.maximum.at(buffer, tuple(ij), s) + keep = np.where(buffer[tuple(ij)] == s)[0] + + # Pick lowest angle (arbitrary). 
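+    # Among duplicates at the same pixel, keep the detection whose absolute orientation is smallest.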
+ ij = ij[:, keep] + buffer[:] = np.inf + o_abs = np.abs(angles[keep]) + np.minimum.at(buffer, tuple(ij), o_abs) + mask = buffer[tuple(ij)] == o_abs + ij = ij[:, mask] + keep = keep[mask] + + if nms_radius > 0: + # Apply NMS on the remaining points + buffer[:] = 0 + buffer[tuple(ij)] = s[keep] # scores or scale + + local_max = torch.nn.functional.max_pool2d( + torch.from_numpy(buffer).unsqueeze(0), + kernel_size=nms_radius * 2 + 1, + stride=1, + padding=nms_radius, + ).squeeze(0) + is_local_max = buffer == local_max.numpy() + keep = keep[is_local_max[tuple(ij)]] + return keep + + +def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor: + x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps) + x.clip_(min=eps).sqrt_() + return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps) + + +def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray: + """ + Detect keypoints using OpenCV Detector. + Optionally, perform description. + Args: + features: OpenCV based keypoints detector and descriptor + image: Grayscale image of uint8 data type + Returns: + keypoints: 1D array of detected cv2.KeyPoint + scores: 1D array of responses + descriptors: 1D array of descriptors + """ + detections, descriptors = features.detectAndCompute(image, None) + points = np.array([k.pt for k in detections], dtype=np.float32) + scores = np.array([k.response for k in detections], dtype=np.float32) + scales = np.array([k.size for k in detections], dtype=np.float32) + angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32)) + return points, scores, scales, angles, descriptors + + +class SIFT(Extractor): + default_conf = { + "rootsift": True, + "nms_radius": 0, # None to disable filtering entirely. + "max_num_keypoints": 4096, + "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda} + "detection_threshold": 0.0066667, # from COLMAP + "edge_threshold": 10, + "first_octave": -1, # only used by pycolmap, the default of COLMAP + "num_octaves": 4, + } + + preprocess_conf = { + "resize": 1024, + } + + required_data_keys = ["image"] + + def __init__(self, **conf): + super().__init__(**conf) # Update with default configuration. + backend = self.conf.backend + if backend.startswith("pycolmap"): + if pycolmap is None: + raise ImportError( + "Cannot find module pycolmap: install it with pip" + "or use backend=opencv." + ) + options = { + "peak_threshold": self.conf.detection_threshold, + "edge_threshold": self.conf.edge_threshold, + "first_octave": self.conf.first_octave, + "num_octaves": self.conf.num_octaves, + "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy. + } + device = ( + "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "") + ) + if ( + backend == "pycolmap_cpu" or not pycolmap.has_cuda + ) and pycolmap.__version__ < "0.5.0": + warnings.warn( + "The pycolmap CPU SIFT is buggy in version < 0.5.0, " + "consider upgrading pycolmap or use the CUDA version.", + stacklevel=1, + ) + else: + options["max_num_features"] = self.conf.max_num_keypoints + self.sift = pycolmap.Sift(options=options, device=device) + elif backend == "opencv": + self.sift = cv2.SIFT_create( + contrastThreshold=self.conf.detection_threshold, + nfeatures=self.conf.max_num_keypoints, + edgeThreshold=self.conf.edge_threshold, + nOctaveLayers=self.conf.num_octaves, + ) + else: + backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"} + raise ValueError( + f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}." 
+ ) + + def extract_single_image(self, image: torch.Tensor): + image_np = image.cpu().numpy().squeeze(0) + + if self.conf.backend.startswith("pycolmap"): + if version.parse(pycolmap.__version__) >= version.parse("0.5.0"): + detections, descriptors = self.sift.extract(image_np) + scores = None # Scores are not exposed by COLMAP anymore. + else: + detections, scores, descriptors = self.sift.extract(image_np) + keypoints = detections[:, :2] # Keep only (x, y). + scales, angles = detections[:, -2:].T + if scores is not None and ( + self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda + ): + # Set the scores as a combination of abs. response and scale. + scores = np.abs(scores) * scales + elif self.conf.backend == "opencv": + # TODO: Check if opencv keypoints are already in corner convention + keypoints, scores, scales, angles, descriptors = run_opencv_sift( + self.sift, (image_np * 255.0).astype(np.uint8) + ) + pred = { + "keypoints": keypoints, + "scales": scales, + "oris": angles, + "descriptors": descriptors, + } + if scores is not None: + pred["keypoint_scores"] = scores + + # sometimes pycolmap returns points outside the image. We remove them + if self.conf.backend.startswith("pycolmap"): + is_inside = ( + pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]]) + ).all(-1) + pred = {k: v[is_inside] for k, v in pred.items()} + + if self.conf.nms_radius is not None: + keep = filter_dog_point( + pred["keypoints"], + pred["scales"], + pred["oris"], + image_np.shape, + self.conf.nms_radius, + scores=pred.get("keypoint_scores"), + ) + pred = {k: v[keep] for k, v in pred.items()} + + pred = {k: torch.from_numpy(v) for k, v in pred.items()} + if scores is not None: + # Keep the k keypoints with highest score + num_points = self.conf.max_num_keypoints + if num_points is not None and len(pred["keypoints"]) > num_points: + indices = torch.topk(pred["keypoint_scores"], num_points).indices + pred = {k: v[indices] for k, v in pred.items()} + + return pred + + def forward(self, data: dict) -> dict: + image = data["image"] + if image.shape[1] == 3: + image = rgb_to_grayscale(image) + device = image.device + image = image.cpu() + pred = [] + for k in range(len(image)): + img = image[k] + if "image_size" in data.keys(): + # avoid extracting points in padded areas + w, h = data["image_size"][k] + img = img[:, :h, :w] + p = self.extract_single_image(img) + pred.append(p) + pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]} + if self.conf.rootsift: + pred["descriptors"] = sift_to_rootsift(pred["descriptors"]) + return pred diff --git a/imcui/third_party/LightGlue/lightglue/superpoint.py b/imcui/third_party/LightGlue/lightglue/superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d380eb5737b92ce417e3ed0e5db1ee9d4a1d04 --- /dev/null +++ b/imcui/third_party/LightGlue/lightglue/superpoint.py @@ -0,0 +1,227 @@ +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. 
Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. +# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +# Adapted by Remi Pautrat, Philipp Lindenberger + +import torch +from kornia.color import rgb_to_grayscale +from torch import nn + +from .utils import Extractor + + +def simple_nms(scores, nms_radius: int): + """Fast Non-maximum suppression to remove nearby points""" + assert nms_radius >= 0 + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius + ) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def top_k_keypoints(keypoints, scores, k): + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0, sorted=True) + return keypoints[indices], scores + + +def sample_descriptors(keypoints, descriptors, s: int = 8): + """Interpolate descriptors at keypoint locations""" + b, c, h, w = descriptors.shape + keypoints = keypoints - s / 2 + 0.5 + keypoints /= torch.tensor( + [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to( + keypoints + )[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + args = {"align_corners": True} if torch.__version__ >= "1.3" else {} + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args + ) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1 + ) + return descriptors + + +class SuperPoint(Extractor): + """SuperPoint Convolutional Detector and Descriptor + + SuperPoint: Self-Supervised Interest Point Detection and + Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew + Rabinovich. In CVPRW, 2019. 
https://arxiv.org/abs/1712.07629 + + """ + + default_conf = { + "descriptor_dim": 256, + "nms_radius": 4, + "max_num_keypoints": None, + "detection_threshold": 0.0005, + "remove_borders": 4, + } + + preprocess_conf = { + "resize": 1024, + } + + required_data_keys = ["image"] + + def __init__(self, **conf): + super().__init__(**conf) # Update with default configuration. + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 + + self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) + self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) + + self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) + + self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convDb = nn.Conv2d( + c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0 + ) + + url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa + self.load_state_dict(torch.hub.load_state_dict_from_url(url)) + + if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0: + raise ValueError("max_num_keypoints must be positive or None") + + def forward(self, data: dict) -> dict: + """Compute keypoints, scores, descriptors for image""" + for key in self.required_data_keys: + assert key in data, f"Missing key {key} in data" + image = data["image"] + if image.shape[1] == 3: + image = rgb_to_grayscale(image) + + # Shared Encoder + x = self.relu(self.conv1a(image)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + b, _, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + scores = simple_nms(scores, self.conf.nms_radius) + + # Discard keypoints near the image borders + if self.conf.remove_borders: + pad = self.conf.remove_borders + scores[:, :pad] = -1 + scores[:, :, :pad] = -1 + scores[:, -pad:] = -1 + scores[:, :, -pad:] = -1 + + # Extract keypoints + best_kp = torch.where(scores > self.conf.detection_threshold) + scores = scores[best_kp] + + # Separate into batches + keypoints = [ + torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b) + ] + scores = [scores[best_kp[0] == i] for i in range(b)] + + # Keep the k keypoints with highest score + if self.conf.max_num_keypoints is not None: + keypoints, scores = list( + zip( + *[ + top_k_keypoints(k, s, self.conf.max_num_keypoints) + for k, s in zip(keypoints, scores) + ] + ) + ) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + # Compute the dense descriptors + cDa = 
self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) + + # Extract descriptors + descriptors = [ + sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, descriptors) + ] + + return { + "keypoints": torch.stack(keypoints, 0), + "keypoint_scores": torch.stack(scores, 0), + "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(), + } diff --git a/imcui/third_party/LightGlue/lightglue/utils.py b/imcui/third_party/LightGlue/lightglue/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1c1ab2e94716b1c54191a6ed5d01023036836c1 --- /dev/null +++ b/imcui/third_party/LightGlue/lightglue/utils.py @@ -0,0 +1,165 @@ +import collections.abc as collections +from pathlib import Path +from types import SimpleNamespace +from typing import Callable, List, Optional, Tuple, Union + +import cv2 +import kornia +import numpy as np +import torch + + +class ImagePreprocessor: + default_conf = { + "resize": None, # target edge length, None for no resizing + "side": "long", + "interpolation": "bilinear", + "align_corners": None, + "antialias": True, + } + + def __init__(self, **conf) -> None: + super().__init__() + self.conf = {**self.default_conf, **conf} + self.conf = SimpleNamespace(**self.conf) + + def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Resize and preprocess an image, return image and resize scale""" + h, w = img.shape[-2:] + if self.conf.resize is not None: + img = kornia.geometry.transform.resize( + img, + self.conf.resize, + side=self.conf.side, + antialias=self.conf.antialias, + align_corners=self.conf.align_corners, + ) + scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img) + return img, scale + + +def map_tensor(input_, func: Callable): + string_classes = (str, bytes) + if isinstance(input_, string_classes): + return input_ + elif isinstance(input_, collections.Mapping): + return {k: map_tensor(sample, func) for k, sample in input_.items()} + elif isinstance(input_, collections.Sequence): + return [map_tensor(sample, func) for sample in input_] + elif isinstance(input_, torch.Tensor): + return func(input_) + else: + return input_ + + +def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True): + """Move batch (dict) to device""" + + def _func(tensor): + return tensor.to(device=device, non_blocking=non_blocking).detach() + + return map_tensor(batch, _func) + + +def rbd(data: dict) -> dict: + """Remove batch dimension from elements in data""" + return { + k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v + for k, v in data.items() + } + + +def read_image(path: Path, grayscale: bool = False) -> np.ndarray: + """Read an image from path as RGB or grayscale""" + if not Path(path).exists(): + raise FileNotFoundError(f"No image at path {path}.") + mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise IOError(f"Could not read image at {path}.") + if not grayscale: + image = image[..., ::-1] + return image + + +def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor: + """Normalize the image tensor and reorder the dimensions.""" + if image.ndim == 3: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + elif image.ndim == 2: + image = image[None] # add channel axis + else: + raise ValueError(f"Not an image: {image.shape}") + return torch.tensor(image / 255.0, dtype=torch.float) + + +def resize_image( + 
image: np.ndarray, + size: Union[List[int], int], + fn: str = "max", + interp: Optional[str] = "area", +) -> np.ndarray: + """Resize an image to a fixed size, or according to max or min edge.""" + h, w = image.shape[:2] + + fn = {"max": max, "min": min}[fn] + if isinstance(size, int): + scale = size / fn(h, w) + h_new, w_new = int(round(h * scale)), int(round(w * scale)) + scale = (w_new / w, h_new / h) + elif isinstance(size, (tuple, list)): + h_new, w_new = size + scale = (w_new / w, h_new / h) + else: + raise ValueError(f"Incorrect new size: {size}") + mode = { + "linear": cv2.INTER_LINEAR, + "cubic": cv2.INTER_CUBIC, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + }[interp] + return cv2.resize(image, (w_new, h_new), interpolation=mode), scale + + +def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor: + image = read_image(path) + if resize is not None: + image, _ = resize_image(image, resize, **kwargs) + return numpy_image_to_torch(image) + + +class Extractor(torch.nn.Module): + def __init__(self, **conf): + super().__init__() + self.conf = SimpleNamespace(**{**self.default_conf, **conf}) + + @torch.no_grad() + def extract(self, img: torch.Tensor, **conf) -> dict: + """Perform extraction with online resizing""" + if img.dim() == 3: + img = img[None] # add batch dim + assert img.dim() == 4 and img.shape[0] == 1 + shape = img.shape[-2:][::-1] + img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img) + feats = self.forward({"image": img}) + feats["image_size"] = torch.tensor(shape)[None].to(img).float() + feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5 + return feats + + +def match_pair( + extractor, + matcher, + image0: torch.Tensor, + image1: torch.Tensor, + device: str = "cpu", + **preprocess, +): + """Match a pair of images (image0, image1) with an extractor and matcher""" + feats0 = extractor.extract(image0, **preprocess) + feats1 = extractor.extract(image1, **preprocess) + matches01 = matcher({"image0": feats0, "image1": feats1}) + data = [feats0, feats1, matches01] + # remove batch dim and move to target device + feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data] + return feats0, feats1, matches01 diff --git a/imcui/third_party/LightGlue/lightglue/viz2d.py b/imcui/third_party/LightGlue/lightglue/viz2d.py new file mode 100644 index 0000000000000000000000000000000000000000..22dc3f65662181666b1ff57403af4ad18bfdc271 --- /dev/null +++ b/imcui/third_party/LightGlue/lightglue/viz2d.py @@ -0,0 +1,184 @@ +""" +2D visualization primitives based on Matplotlib. +1) Plot images with `plot_images`. +2) Call `plot_keypoints` or `plot_matches` any number of times. +3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. 
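+
+Example (a minimal usage sketch; image0, image1, kpts0, kpts1 are assumed to be two
+images and their matched keypoints produced elsewhere, e.g. by an extractor/matcher):
+
+    plot_images([image0, image1])
+    plot_matches(kpts0, kpts1, color="lime", lw=0.2)
+    save_plot("matches.png")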
+""" + +import matplotlib +import matplotlib.patheffects as path_effects +import matplotlib.pyplot as plt +import numpy as np +import torch + + +def cm_RdGn(x): + """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" + x = np.clip(x, 0, 1)[..., None] * 2 + c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]]) + return np.clip(c, 0, 1) + + +def cm_BlRdGn(x_): + """Custom colormap: blue (-1) -> red (0.0) -> green (1).""" + x = np.clip(x_, 0, 1)[..., None] * 2 + c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]]) + + xn = -np.clip(x_, -1, 0)[..., None] * 2 + cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]]) + out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1) + return out + + +def cm_prune(x_): + """Custom colormap to visualize pruning""" + if isinstance(x_, torch.Tensor): + x_ = x_.cpu().numpy() + max_i = max(x_) + norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9) + return cm_BlRdGn(norm_x) + + +def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True): + """Plot a set of images horizontally. + Args: + imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. + adaptive: whether the figure size should fit the image aspect ratios. + """ + # conversion to (H, W, 3) for torch.Tensor + imgs = [ + img.permute(1, 2, 0).cpu().numpy() + if (isinstance(img, torch.Tensor) and img.dim() == 3) + else img + for img in imgs + ] + + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + if adaptive: + ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H + else: + ratios = [4 / 3] * n + figsize = [sum(ratios) * 4.5, 4.5] + fig, ax = plt.subplots( + 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} + ) + if n == 1: + ax = [ax] + for i in range(n): + ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) + ax[i].get_yaxis().set_ticks([]) + ax[i].get_xaxis().set_ticks([]) + ax[i].set_axis_off() + for spine in ax[i].spines.values(): # remove frame + spine.set_visible(False) + if titles: + ax[i].set_title(titles[i]) + fig.tight_layout(pad=pad) + + +def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0): + """Plot keypoints for existing images. + Args: + kpts: list of ndarrays of size (N, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float. + """ + if not isinstance(colors, list): + colors = [colors] * len(kpts) + if not isinstance(a, list): + a = [a] * len(kpts) + if axes is None: + axes = plt.gcf().axes + for ax, k, c, alpha in zip(axes, kpts, colors, a): + if isinstance(k, torch.Tensor): + k = k.cpu().numpy() + ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha) + + +def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None): + """Plot matches for a pair of existing images. + Args: + kpts0, kpts1: corresponding keypoints of size (N, 2). + color: color of each match, string or RGB tuple. Random if not given. + lw: width of the lines. + ps: size of the end points (no endpoint if ps=0) + indices: indices of the images to draw the matches on. + a: alpha opacity of the match lines. 
+ """ + fig = plt.gcf() + if axes is None: + ax = fig.axes + ax0, ax1 = ax[0], ax[1] + else: + ax0, ax1 = axes + if isinstance(kpts0, torch.Tensor): + kpts0 = kpts0.cpu().numpy() + if isinstance(kpts1, torch.Tensor): + kpts1 = kpts1.cpu().numpy() + assert len(kpts0) == len(kpts1) + if color is None: + color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() + elif len(color) > 0 and not isinstance(color[0], (tuple, list)): + color = [color] * len(kpts0) + + if lw > 0: + for i in range(len(kpts0)): + line = matplotlib.patches.ConnectionPatch( + xyA=(kpts0[i, 0], kpts0[i, 1]), + xyB=(kpts1[i, 0], kpts1[i, 1]), + coordsA=ax0.transData, + coordsB=ax1.transData, + axesA=ax0, + axesB=ax1, + zorder=1, + color=color[i], + linewidth=lw, + clip_on=True, + alpha=a, + label=None if labels is None else labels[i], + picker=5.0, + ) + line.set_annotation_clip(True) + fig.add_artist(line) + + # freeze the axes to prevent the transform to change + ax0.autoscale(enable=False) + ax1.autoscale(enable=False) + + if ps > 0: + ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) + ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) + + +def add_text( + idx, + text, + pos=(0.01, 0.99), + fs=15, + color="w", + lcolor="k", + lwidth=2, + ha="left", + va="top", +): + ax = plt.gcf().axes[idx] + t = ax.text( + *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes + ) + if lcolor is not None: + t.set_path_effects( + [ + path_effects.Stroke(linewidth=lwidth, foreground=lcolor), + path_effects.Normal(), + ] + ) + + +def save_plot(path, **kw): + """Save the current figure without any white margin.""" + plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) diff --git a/imcui/third_party/RoMa/demo/demo_3D_effect.py b/imcui/third_party/RoMa/demo/demo_3D_effect.py new file mode 100644 index 0000000000000000000000000000000000000000..ae26caaf92deb884dfabb6eca96aec3406325c3f --- /dev/null +++ b/imcui/third_party/RoMa/demo/demo_3D_effect.py @@ -0,0 +1,47 @@ +from PIL import Image +import torch +import torch.nn.functional as F +import numpy as np +from romatch.utils.utils import tensor_to_pil + +from romatch import roma_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.backends.mps.is_available(): + device = torch.device('mps') + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str) + parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + save_path = args.save_path + + # Create model + roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152)) + roma_model.symmetric = False + + H, W = roma_model.get_output_resolution() + + im1 = Image.open(im1_path).resize((W, H)) + im2 = Image.open(im2_path).resize((W, H)) + + # Match + warp, certainty = roma_model.match(im1_path, im2_path, device=device) + # Sampling not needed, but can be done with model.sample(warp, certainty) + x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1) + x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1) + + coords_A, coords_B = warp[...,:2], warp[...,2:] + for i, x in enumerate(np.linspace(0,2*np.pi,200)): + t = (1 + np.cos(x))/2 + interp_warp = (1-t)*coords_A + t*coords_B + im2_transfer_rgb = F.grid_sample( + 
x2[None], interp_warp[None], mode="bilinear", align_corners=False + )[0] + tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg") \ No newline at end of file diff --git a/imcui/third_party/RoMa/demo/demo_fundamental.py b/imcui/third_party/RoMa/demo/demo_fundamental.py new file mode 100644 index 0000000000000000000000000000000000000000..65ea9ccb76525da3e88e4f426bdebdc4fe742161 --- /dev/null +++ b/imcui/third_party/RoMa/demo/demo_fundamental.py @@ -0,0 +1,34 @@ +from PIL import Image +import torch +import cv2 +from romatch import roma_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.backends.mps.is_available(): + device = torch.device('mps') + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + + # Create model + roma_model = roma_outdoor(device=device) + + + W_A, H_A = Image.open(im1_path).size + W_B, H_B = Image.open(im2_path).size + + # Match + warp, certainty = roma_model.match(im1_path, im2_path, device=device) + # Sample matches for estimation + matches, certainty = roma_model.sample(warp, certainty) + kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) + F, mask = cv2.findFundamentalMat( + kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000 + ) \ No newline at end of file diff --git a/imcui/third_party/RoMa/demo/demo_match.py b/imcui/third_party/RoMa/demo/demo_match.py new file mode 100644 index 0000000000000000000000000000000000000000..582767e19d8b50c6c241ea32f81cabb38f52fce2 --- /dev/null +++ b/imcui/third_party/RoMa/demo/demo_match.py @@ -0,0 +1,50 @@ +import os +os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' +import torch +from PIL import Image +import torch.nn.functional as F +import numpy as np +from romatch.utils.utils import tensor_to_pil + +from romatch import roma_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.backends.mps.is_available(): + device = torch.device('mps') + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str) + parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + save_path = args.save_path + + # Create model + roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152)) + + H, W = roma_model.get_output_resolution() + + im1 = Image.open(im1_path).resize((W, H)) + im2 = Image.open(im2_path).resize((W, H)) + + # Match + warp, certainty = roma_model.match(im1_path, im2_path, device=device) + # Sampling not needed, but can be done with model.sample(warp, certainty) + x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1) + x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1) + + im2_transfer_rgb = F.grid_sample( + x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False + )[0] + im1_transfer_rgb = F.grid_sample( + x1[None], warp[:, W:, :2][None], mode="bilinear", 
align_corners=False + )[0] + warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2) + white_im = torch.ones((H,2*W),device=device) + vis_im = certainty * warp_im + (1 - certainty) * white_im + tensor_to_pil(vis_im, unnormalize=False).save(save_path) \ No newline at end of file diff --git a/imcui/third_party/RoMa/demo/demo_match_opencv_sift.py b/imcui/third_party/RoMa/demo/demo_match_opencv_sift.py new file mode 100644 index 0000000000000000000000000000000000000000..3196fcfaab248f6c4c6247a0afb4db745206aee8 --- /dev/null +++ b/imcui/third_party/RoMa/demo/demo_match_opencv_sift.py @@ -0,0 +1,43 @@ +from PIL import Image +import numpy as np + +import numpy as np +import cv2 as cv +import matplotlib.pyplot as plt + + + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str) + parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + save_path = args.save_path + + img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE) # queryImage + img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage + # Initiate SIFT detector + sift = cv.SIFT_create() + # find the keypoints and descriptors with SIFT + kp1, des1 = sift.detectAndCompute(img1,None) + kp2, des2 = sift.detectAndCompute(img2,None) + # BFMatcher with default params + bf = cv.BFMatcher() + matches = bf.knnMatch(des1,des2,k=2) + # Apply ratio test + good = [] + for m,n in matches: + if m.distance < 0.75*n.distance: + good.append([m]) + # cv.drawMatchesKnn expects list of lists as matches. + draw_params = dict(matchColor = (255,0,0), # draw matches in red color + singlePointColor = None, + flags = 2) + + img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params) + Image.fromarray(img3).save("demo/sift_matches.png") diff --git a/imcui/third_party/RoMa/demo/demo_match_tiny.py b/imcui/third_party/RoMa/demo/demo_match_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e66a4b80a2361e22673ddc59632f48ad653b69 --- /dev/null +++ b/imcui/third_party/RoMa/demo/demo_match_tiny.py @@ -0,0 +1,77 @@ +import os +os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' +import torch +from PIL import Image +import torch.nn.functional as F +import numpy as np +from romatch.utils.utils import tensor_to_pil + +from romatch import tiny_roma_v1_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +if torch.backends.mps.is_available(): + device = torch.device('mps') + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str) + parser.add_argument("--save_A_path", default="demo/tiny_roma_warp_A.jpg", type=str) + parser.add_argument("--save_B_path", default="demo/tiny_roma_warp_B.jpg", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + + # Create model + roma_model = tiny_roma_v1_outdoor(device=device) + + # Match + warp, certainty1 = roma_model.match(im1_path, im2_path) + + h1, w1 = warp.shape[:2] + + # maybe im1.size != im2.size + im1 = Image.open(im1_path).resize((w1, h1)) + im2 = Image.open(im2_path) + x1 = (torch.tensor(np.array(im1)) / 
255).to(device).permute(2, 0, 1) + x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1) + + h2, w2 = x2.shape[1:] + g1_p2x = w2 / 2 * (warp[..., 2] + 1) + g1_p2y = h2 / 2 * (warp[..., 3] + 1) + g2_p1x = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2 + g2_p1y = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2 + + x, y = torch.meshgrid( + torch.arange(w1, device=device), + torch.arange(h1, device=device), + indexing="xy", + ) + g2x = torch.round(g1_p2x[y, x]).long() + g2y = torch.round(g1_p2y[y, x]).long() + idx_x = torch.bitwise_and(0 <= g2x, g2x < w2) + idx_y = torch.bitwise_and(0 <= g2y, g2y < h2) + idx = torch.bitwise_and(idx_x, idx_y) + g2_p1x[g2y[idx], g2x[idx]] = x[idx].float() * 2 / w1 - 1 + g2_p1y[g2y[idx], g2x[idx]] = y[idx].float() * 2 / h1 - 1 + + certainty2 = F.grid_sample( + certainty1[None][None], + torch.stack([g2_p1x, g2_p1y], dim=2)[None], + mode="bilinear", + align_corners=False, + )[0] + + white_im1 = torch.ones((h1, w1), device = device) + white_im2 = torch.ones((h2, w2), device = device) + + certainty1 = F.avg_pool2d(certainty1[None], kernel_size=5, stride=1, padding=2)[0] + certainty2 = F.avg_pool2d(certainty2[None], kernel_size=5, stride=1, padding=2)[0] + + vis_im1 = certainty1 * x1 + (1 - certainty1) * white_im1 + vis_im2 = certainty2 * x2 + (1 - certainty2) * white_im2 + + tensor_to_pil(vis_im1, unnormalize=False).save(args.save_A_path) + tensor_to_pil(vis_im2, unnormalize=False).save(args.save_B_path) \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/__init__.py b/imcui/third_party/RoMa/romatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45c2aff7bb510b1d8dfcdc9c4d91ffc3aa5021f6 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/__init__.py @@ -0,0 +1,8 @@ +import os +from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor + +DEBUG_MODE = False +RANK = int(os.environ.get('RANK', default = 0)) +GLOBAL_STEP = 0 +STEP_SIZE = 1 +LOCAL_RANK = -1 \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/benchmarks/__init__.py b/imcui/third_party/RoMa/romatch/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af32a46ba4a48d719e3ad38f9b2355a13fe6cc44 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/benchmarks/__init__.py @@ -0,0 +1,6 @@ +from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark +from .scannet_benchmark import ScanNetBenchmark +from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark +from .megadepth_dense_benchmark import MegadepthDenseBenchmark +from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark +#from .scannet_benchmark_poselib import ScanNetPoselibBenchmark \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py b/imcui/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..5972361f80d4f4e5cafd8fd359c87c0433a0a5a5 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py @@ -0,0 +1,113 @@ +from PIL import Image +import numpy as np + +import os + +from tqdm import tqdm +from romatch.utils import pose_auc +import cv2 + + +class HpatchesHomogBenchmark: + """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]""" + + def __init__(self, dataset_path) -> None: + seqs_dir = "hpatches-sequences-release" + self.seqs_path = 
os.path.join(dataset_path, seqs_dir) + self.seq_names = sorted(os.listdir(self.seqs_path)) + # Ignore seqs is same as LoFTR. + self.ignore_seqs = set( + [ + "i_contruction", + "i_crownnight", + "i_dc", + "i_pencils", + "i_whitebuilding", + "v_artisans", + "v_astronautis", + "v_talent", + ] + ) + + def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup): + offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think) + im_A_coords = ( + np.stack( + ( + wq * (im_A_coords[..., 0] + 1) / 2, + hq * (im_A_coords[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + im_A_to_im_B = ( + np.stack( + ( + wsup * (im_A_to_im_B[..., 0] + 1) / 2, + hsup * (im_A_to_im_B[..., 1] + 1) / 2, + ), + axis=-1, + ) + - offset + ) + return im_A_coords, im_A_to_im_B + + def benchmark(self, model, model_name = None): + n_matches = [] + homog_dists = [] + for seq_idx, seq_name in tqdm( + enumerate(self.seq_names), total=len(self.seq_names) + ): + im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm") + im_A = Image.open(im_A_path) + w1, h1 = im_A.size + for im_idx in range(2, 7): + im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm") + im_B = Image.open(im_B_path) + w2, h2 = im_B.size + H = np.loadtxt( + os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx)) + ) + dense_matches, dense_certainty = model.match( + im_A_path, im_B_path + ) + good_matches, _ = model.sample(dense_matches, dense_certainty, 5000) + pos_a, pos_b = self.convert_coordinates( + good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2 + ) + try: + H_pred, inliers = cv2.findHomography( + pos_a, + pos_b, + method = cv2.RANSAC, + confidence = 0.99999, + ransacReprojThreshold = 3 * min(w2, h2) / 480, + ) + except: + H_pred = None + if H_pred is None: + H_pred = np.zeros((3, 3)) + H_pred[2, 2] = 1.0 + corners = np.array( + [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]] + ) + real_warped_corners = np.dot(corners, np.transpose(H)) + real_warped_corners = ( + real_warped_corners[:, :2] / real_warped_corners[:, 2:] + ) + warped_corners = np.dot(corners, np.transpose(H_pred)) + warped_corners = warped_corners[:, :2] / warped_corners[:, 2:] + mean_dist = np.mean( + np.linalg.norm(real_warped_corners - warped_corners, axis=1) + ) / (min(w2, h2) / 480.0) + homog_dists.append(mean_dist) + + n_matches = np.array(n_matches) + thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + auc = pose_auc(np.array(homog_dists), thresholds) + return { + "hpatches_homog_auc_3": auc[2], + "hpatches_homog_auc_5": auc[4], + "hpatches_homog_auc_10": auc[9], + } diff --git a/imcui/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..09d0a297f2937afc609eed3a74aa0c3c4c7ccebc --- /dev/null +++ b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py @@ -0,0 +1,106 @@ +import torch +import numpy as np +import tqdm +from romatch.datasets import MegadepthBuilder +from romatch.utils import warp_kpts +from torch.utils.data import ConcatDataset +import romatch + +class MegadepthDenseBenchmark: + def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None: + mega = MegadepthBuilder(data_root=data_root) + self.dataset = ConcatDataset( + mega.build_scenes(split="test_loftr", ht=h, wt=w) + ) # fixed resolution of 384,512 + self.num_samples = num_samples + + def geometric_dist(self, depth1, depth2, T_1to2, 
K1, K2, dense_matches): + b, h1, w1, d = dense_matches.shape + with torch.no_grad(): + x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2) + mask, x2 = warp_kpts( + x1.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + ) + x2 = torch.stack( + (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1 + ) + prob = mask.float().reshape(b, h1, w1) + x2_hat = dense_matches[..., 2:] + x2_hat = torch.stack( + (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1 + ) + gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1) + gd = gd[prob == 1] + pck_1 = (gd < 1.0).float().mean() + pck_3 = (gd < 3.0).float().mean() + pck_5 = (gd < 5.0).float().mean() + return gd, pck_1, pck_3, pck_5, prob + + def benchmark(self, model, batch_size=8): + model.train(False) + with torch.no_grad(): + gd_tot = 0.0 + pck_1_tot = 0.0 + pck_3_tot = 0.0 + pck_5_tot = 0.0 + sampler = torch.utils.data.WeightedRandomSampler( + torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples + ) + B = batch_size + dataloader = torch.utils.data.DataLoader( + self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler + ) + for idx, data in tqdm.tqdm(enumerate(dataloader), disable = romatch.RANK > 0): + im_A, im_B, depth1, depth2, T_1to2, K1, K2 = ( + data["im_A"].cuda(), + data["im_B"].cuda(), + data["im_A_depth"].cuda(), + data["im_B_depth"].cuda(), + data["T_1to2"].cuda(), + data["K1"].cuda(), + data["K2"].cuda(), + ) + matches, certainty = model.match(im_A, im_B, batched=True) + gd, pck_1, pck_3, pck_5, prob = self.geometric_dist( + depth1, depth2, T_1to2, K1, K2, matches + ) + if romatch.DEBUG_MODE: + from romatch.utils.utils import tensor_to_pil + import torch.nn.functional as F + path = "vis" + H, W = model.get_output_resolution() + white_im = torch.ones((B,1,H,W),device="cuda") + im_B_transfer_rgb = F.grid_sample( + im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False + ) + warp_im = im_B_transfer_rgb + c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None] + vis_im = c_b * warp_im + (1 - c_b) * white_im + for b in range(B): + import os + os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True) + tensor_to_pil(vis_im[b], unnormalize=True).save( + f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg") + tensor_to_pil(im_A[b].cuda(), unnormalize=True).save( + f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg") + tensor_to_pil(im_B[b].cuda(), unnormalize=True).save( + f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg") + + + gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = ( + gd_tot + gd.mean(), + pck_1_tot + pck_1, + pck_3_tot + pck_3, + pck_5_tot + pck_5, + ) + return { + "epe": gd_tot.item() / len(dataloader), + "mega_pck_1": pck_1_tot.item() / len(dataloader), + "mega_pck_3": pck_3_tot.item() / len(dataloader), + "mega_pck_5": pck_5_tot.item() / len(dataloader), + } diff --git a/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..36f293f9556d919643f6f39156314a5b402d9082 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py @@ -0,0 +1,118 @@ +import numpy as np +import torch +from romatch.utils import * +from PIL import Image +from tqdm import tqdm +import torch.nn.functional as F +import romatch +import kornia.geometry.epipolar as kepi + +class 
MegaDepthPoseEstimationBenchmark: + def __init__(self, data_root="data/megadepth", scene_names = None) -> None: + if scene_names is None: + self.scene_names = [ + "0015_0.1_0.3.npz", + "0015_0.3_0.5.npz", + "0022_0.1_0.3.npz", + "0022_0.3_0.5.npz", + "0022_0.5_0.7.npz", + ] + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + + def benchmark(self, model, model_name = None): + with torch.no_grad(): + data_root = self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + thresholds = [5, 10, 20] + for scene_ind in range(len(self.scenes)): + import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] + scene = self.scenes[scene_ind] + pairs = scene["pair_infos"] + intrinsics = scene["intrinsics"] + poses = scene["poses"] + im_paths = scene["image_paths"] + pair_inds = range(len(pairs)) + for pairind in tqdm(pair_inds): + idx1, idx2 = pairs[pairind][0] + K1 = intrinsics[idx1].copy() + T1 = poses[idx1].copy() + R1, t1 = T1[:3, :3], T1[:3, 3] + K2 = intrinsics[idx2].copy() + T2 = poses[idx2].copy() + R2, t2 = T2[:3, :3], T2[:3, 3] + R, t = compute_relative_pose(R1, t1, R2, t2) + T1_to_2 = np.concatenate((R,t[:,None]), axis=-1) + im_A_path = f"{data_root}/{im_paths[idx1]}" + im_B_path = f"{data_root}/{im_paths[idx2]}" + dense_matches, dense_certainty = model.match( + im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy() + ) + sparse_matches,_ = model.sample( + dense_matches, dense_certainty, 5_000 + ) + + im_A = Image.open(im_A_path) + w1, h1 = im_A.size + im_B = Image.open(im_B_path) + w2, h2 = im_B.size + if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False. 
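+ # Rescale both views to a common longest side of 1200 px and scale the first two rows of K
+ # (focal lengths and principal point, in pixels) by the same factor, presumably so that the
+ # fixed pixel threshold used below behaves consistently across image resolutions.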
+ scale1 = 1200 / max(w1, h1) + scale2 = 1200 / max(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1, K2 = K1.copy(), K2.copy() + K1[:2] = K1[:2] * scale1 + K2[:2] = K2[:2] * scale2 + + kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2) + kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy() + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + threshold = 0.5 + norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1, + kpts2, + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + print(f"{model_name} auc: {auc}") + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py new file mode 100644 index 0000000000000000000000000000000000000000..4732ccf2af5b50e6db60831d7c63c5bf70ec727c --- /dev/null +++ b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py @@ -0,0 +1,119 @@ +import numpy as np +import torch +from romatch.utils import * +from PIL import Image +from tqdm import tqdm +import torch.nn.functional as F +import romatch +import kornia.geometry.epipolar as kepi + +# wrap cause pyposelib is still in dev +# will add in deps later +import poselib + +class Mega1500PoseLibBenchmark: + def __init__(self, data_root="data/megadepth", scene_names = None, num_ransac_iter = 5, test_every = 1) -> None: + if scene_names is None: + self.scene_names = [ + "0015_0.1_0.3.npz", + "0015_0.3_0.5.npz", + "0022_0.1_0.3.npz", + "0022_0.3_0.5.npz", + "0022_0.5_0.7.npz", + ] + else: + self.scene_names = scene_names + self.scenes = [ + np.load(f"{data_root}/{scene}", allow_pickle=True) + for scene in self.scene_names + ] + self.data_root = data_root + self.num_ransac_iter = num_ransac_iter + self.test_every = test_every + + def benchmark(self, model, model_name = None): + with torch.no_grad(): + data_root = self.data_root + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + thresholds = [5, 10, 20] + for scene_ind in range(len(self.scenes)): + import os + scene_name = os.path.splitext(self.scene_names[scene_ind])[0] + scene = self.scenes[scene_ind] + pairs = scene["pair_infos"] + intrinsics = scene["intrinsics"] + poses = scene["poses"] + im_paths = scene["image_paths"] + pair_inds = range(len(pairs))[::self.test_every] + for pairind in (pbar := tqdm(pair_inds, desc = "Current AUC: ?")): + idx1, idx2 = pairs[pairind][0] + K1 = intrinsics[idx1].copy() + T1 = poses[idx1].copy() + R1, t1 = T1[:3, :3], T1[:3, 3] + K2 = intrinsics[idx2].copy() + T2 = poses[idx2].copy() + R2, t2 = T2[:3, :3], T2[:3, 3] 
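+ # Relative pose cam1 -> cam2 from the two world-to-camera extrinsics; with this
+ # convention it is R = R2 @ R1.T and t = t2 - R @ t1.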
+ R, t = compute_relative_pose(R1, t1, R2, t2) + T1_to_2 = np.concatenate((R,t[:,None]), axis=-1) + im_A_path = f"{data_root}/{im_paths[idx1]}" + im_B_path = f"{data_root}/{im_paths[idx2]}" + dense_matches, dense_certainty = model.match( + im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy() + ) + sparse_matches,_ = model.sample( + dense_matches, dense_certainty, 5_000 + ) + + im_A = Image.open(im_A_path) + w1, h1 = im_A.size + im_B = Image.open(im_B_path) + w2, h2 = im_B.size + kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2) + kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy() + for _ in range(self.num_ransac_iter): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + threshold = 1 + camera1 = {'model': 'PINHOLE', 'width': w1, 'height': h1, 'params': K1[[0,1,0,1], [0,1,2,2]]} + camera2 = {'model': 'PINHOLE', 'width': w2, 'height': h2, 'params': K2[[0,1,0,1], [0,1,2,2]]} + relpose, res = poselib.estimate_relative_pose( + kpts1, + kpts2, + camera1, + camera2, + ransac_opt = {"max_reproj_error": 2*threshold, "max_epipolar_error": threshold, "min_inliers": 8, "max_iterations": 10_000}, + ) + Rt_est = relpose.Rt + R_est, t_est = Rt_est[:3,:3], Rt_est[:3,3:] + mask = np.array(res['inliers']).astype(np.float32) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + pbar.set_description(f"Current AUC: {pose_auc(tot_e_pose, thresholds)}") + tot_e_pose = np.array(tot_e_pose) + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + print(f"{model_name} auc: {auc}") + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/imcui/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py b/imcui/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6b3c0eeb1c211d224edde84974529c66b52460 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py @@ -0,0 +1,143 @@ +import os.path as osp +import numpy as np +import torch +from romatch.utils import * +from PIL import Image +from tqdm import tqdm + + +class ScanNetBenchmark: + def __init__(self, data_root="data/scannet") -> None: + self.data_root = data_root + + def benchmark(self, model, model_name = None): + model.train(False) + with torch.no_grad(): + data_root = self.data_root + tmp = np.load(osp.join(data_root, "test.npz")) + pairs, rel_pose = tmp["name"], tmp["rel_pose"] + tot_e_t, tot_e_R, tot_e_pose = [], [], [] + pair_inds = np.random.choice( + range(len(pairs)), size=len(pairs), replace=False + ) + for pairind in tqdm(pair_inds, smoothing=0.9): + scene = pairs[pairind] + scene_name = f"scene0{scene[0]}_00" + im_A_path = osp.join( + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[2]}.jpg", + ) + im_A = Image.open(im_A_path) + im_B_path = osp.join( + self.data_root, + "scans_test", + scene_name, + "color", + f"{scene[3]}.jpg", + ) + im_B = Image.open(im_B_path) + T_gt = 
rel_pose[pairind].reshape(3, 4) + R, t = T_gt[:3, :3], T_gt[:3, 3] + K = np.stack( + [ + np.array([float(i) for i in r.split()]) + for r in open( + osp.join( + self.data_root, + "scans_test", + scene_name, + "intrinsic", + "intrinsic_color.txt", + ), + "r", + ) + .read() + .split("\n") + if r + ] + ) + w1, h1 = im_A.size + w2, h2 = im_B.size + K1 = K.copy() + K2 = K.copy() + dense_matches, dense_certainty = model.match(im_A_path, im_B_path) + sparse_matches, sparse_certainty = model.sample( + dense_matches, dense_certainty, 5000 + ) + scale1 = 480 / min(w1, h1) + scale2 = 480 / min(w2, h2) + w1, h1 = scale1 * w1, scale1 * h1 + w2, h2 = scale2 * w2, scale2 * h2 + K1 = K1 * scale1 + K2 = K2 * scale2 + + offset = 0.5 + kpts1 = sparse_matches[:, :2] + kpts1 = ( + np.stack( + ( + w1 * (kpts1[:, 0] + 1) / 2 - offset, + h1 * (kpts1[:, 1] + 1) / 2 - offset, + ), + axis=-1, + ) + ) + kpts2 = sparse_matches[:, 2:] + kpts2 = ( + np.stack( + ( + w2 * (kpts2[:, 0] + 1) / 2 - offset, + h2 * (kpts2[:, 1] + 1) / 2 - offset, + ), + axis=-1, + ) + ) + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(kpts1))) + kpts1 = kpts1[shuffling] + kpts2 = kpts2[shuffling] + try: + norm_threshold = 0.5 / ( + np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))) + R_est, t_est, mask = estimate_pose( + kpts1, + kpts2, + K1, + K2, + norm_threshold, + conf=0.99999, + ) + T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) # + e_t, e_R = compute_pose_error(T1_to_2_est, R, t) + e_pose = max(e_t, e_R) + except Exception as e: + print(repr(e)) + e_t, e_R = 90, 90 + e_pose = max(e_t, e_R) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_t.append(e_t) + tot_e_R.append(e_R) + tot_e_pose.append(e_pose) + tot_e_pose = np.array(tot_e_pose) + thresholds = [5, 10, 20] + auc = pose_auc(tot_e_pose, thresholds) + acc_5 = (tot_e_pose < 5).mean() + acc_10 = (tot_e_pose < 10).mean() + acc_15 = (tot_e_pose < 15).mean() + acc_20 = (tot_e_pose < 20).mean() + map_5 = acc_5 + map_10 = np.mean([acc_5, acc_10]) + map_20 = np.mean([acc_5, acc_10, acc_15, acc_20]) + return { + "auc_5": auc[0], + "auc_10": auc[1], + "auc_20": auc[2], + "map_5": map_5, + "map_10": map_10, + "map_20": map_20, + } diff --git a/imcui/third_party/RoMa/romatch/checkpointing/__init__.py b/imcui/third_party/RoMa/romatch/checkpointing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22f5afe727aa6f6e8fffa9ecf5be69cbff686577 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/checkpointing/__init__.py @@ -0,0 +1 @@ +from .checkpoint import CheckPoint diff --git a/imcui/third_party/RoMa/romatch/checkpointing/checkpoint.py b/imcui/third_party/RoMa/romatch/checkpointing/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5a5131322241475323f1b992d9d3f7b21dbdac --- /dev/null +++ b/imcui/third_party/RoMa/romatch/checkpointing/checkpoint.py @@ -0,0 +1,60 @@ +import os +import torch +from torch.nn.parallel.data_parallel import DataParallel +from torch.nn.parallel.distributed import DistributedDataParallel +from loguru import logger +import gc + +import romatch + +class CheckPoint: + def __init__(self, dir=None, name="tmp"): + self.name = name + self.dir = dir + os.makedirs(self.dir, exist_ok=True) + + def save( + self, + model, + optimizer, + lr_scheduler, + n, + ): + if romatch.RANK == 0: + assert model is not None + if isinstance(model, (DataParallel, DistributedDataParallel)): + model = model.module + states = { + "model": model.state_dict(), + "n": n, + 
"optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + } + torch.save(states, self.dir + self.name + f"_latest.pth") + logger.info(f"Saved states {list(states.keys())}, at step {n}") + + def load( + self, + model, + optimizer, + lr_scheduler, + n, + ): + if os.path.exists(self.dir + self.name + f"_latest.pth") and romatch.RANK == 0: + states = torch.load(self.dir + self.name + f"_latest.pth") + if "model" in states: + model.load_state_dict(states["model"]) + if "n" in states: + n = states["n"] if states["n"] else n + if "optimizer" in states: + try: + optimizer.load_state_dict(states["optimizer"]) + except Exception as e: + print(f"Failed to load states for optimizer, with error {e}") + if "lr_scheduler" in states: + lr_scheduler.load_state_dict(states["lr_scheduler"]) + print(f"Loaded states {list(states.keys())}, at step {n}") + del states + gc.collect() + torch.cuda.empty_cache() + return model, optimizer, lr_scheduler, n \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/datasets/__init__.py b/imcui/third_party/RoMa/romatch/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b60c709926a4a7bd019b73eac10879063a996c90 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/datasets/__init__.py @@ -0,0 +1,2 @@ +from .megadepth import MegadepthBuilder +from .scannet import ScanNetBuilder \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/datasets/megadepth.py b/imcui/third_party/RoMa/romatch/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..88f775ef412f7bd062cc9a1d67d95a030e7a15dd --- /dev/null +++ b/imcui/third_party/RoMa/romatch/datasets/megadepth.py @@ -0,0 +1,232 @@ +import os +from PIL import Image +import h5py +import numpy as np +import torch +import torchvision.transforms.functional as tvf +import kornia.augmentation as K +from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops +import romatch +from romatch.utils import * +import math + +class MegadepthScene: + def __init__( + self, + data_root, + scene_info, + ht=384, + wt=512, + min_overlap=0.0, + max_overlap=1.0, + shake_t=0, + rot_prob=0.0, + normalize=True, + max_num_pairs = 100_000, + scene_name = None, + use_horizontal_flip_aug = False, + use_single_horizontal_flip_aug = False, + colorjiggle_params = None, + random_eraser = None, + use_randaug = False, + randaug_params = None, + randomize_size = False, + ) -> None: + self.data_root = data_root + self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}" + self.image_paths = scene_info["image_paths"] + self.depth_paths = scene_info["depth_paths"] + self.intrinsics = scene_info["intrinsics"] + self.poses = scene_info["poses"] + self.pairs = scene_info["pairs"] + self.overlaps = scene_info["overlaps"] + threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap) + self.pairs = self.pairs[threshold] + self.overlaps = self.overlaps[threshold] + if len(self.pairs) > max_num_pairs: + pairinds = np.random.choice( + np.arange(0, len(self.pairs)), max_num_pairs, replace=False + ) + self.pairs = self.pairs[pairinds] + self.overlaps = self.overlaps[pairinds] + if randomize_size: + area = ht * wt + s = int(16 * (math.sqrt(area)//16)) + sizes = ((ht,wt), (s,s), (wt,ht)) + choice = romatch.RANK % 3 + ht, wt = sizes[choice] + # counts, bins = np.histogram(self.overlaps,20) + # print(counts) + self.im_transform_ops = get_tuple_transform_ops( + resize=(ht, wt), normalize=normalize, 
colorjiggle_params = colorjiggle_params, + ) + self.depth_transform_ops = get_depth_tuple_transform_ops( + resize=(ht, wt) + ) + self.wt, self.ht = wt, ht + self.shake_t = shake_t + self.random_eraser = random_eraser + if use_horizontal_flip_aug and use_single_horizontal_flip_aug: + raise ValueError("Can't both flip both images and only flip one") + self.use_horizontal_flip_aug = use_horizontal_flip_aug + self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug + self.use_randaug = use_randaug + + def load_im(self, im_path): + im = Image.open(im_path) + return im + + def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): + im_A = im_A.flip(-1) + im_B = im_B.flip(-1) + depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) + flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device) + K_A = flip_mat@K_A + K_B = flip_mat@K_B + + return im_A, im_B, depth_A, depth_B, K_A, K_B + + def load_depth(self, depth_ref, crop=None): + depth = np.array(h5py.File(depth_ref, "r")["depth"]) + return torch.from_numpy(depth) + + def __len__(self): + return len(self.pairs) + + def scale_intrinsic(self, K, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]]) + return sK @ K + + def rand_shake(self, *things): + t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2) + return [ + tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0]) + for thing in things + ], t + + def __getitem__(self, pair_idx): + # read intrinsics of original size + idx1, idx2 = self.pairs[pair_idx] + K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3) + K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T1 = self.poses[idx1] + T2 = self.poses[idx2] + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[ + :4, :4 + ] # (4, 4) + + # Load positive pair data + im_A, im_B = self.image_paths[idx1], self.image_paths[idx2] + depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2] + im_A_ref = os.path.join(self.data_root, im_A) + im_B_ref = os.path.join(self.data_root, im_B) + depth_A_ref = os.path.join(self.data_root, depth1) + depth_B_ref = os.path.join(self.data_root, depth2) + im_A = self.load_im(im_A_ref) + im_B = self.load_im(im_B_ref) + K1 = self.scale_intrinsic(K1, im_A.width, im_A.height) + K2 = self.scale_intrinsic(K2, im_B.width, im_B.height) + + if self.use_randaug: + im_A, im_B = self.rand_augment(im_A, im_B) + + depth_A = self.load_depth(depth_A_ref) + depth_B = self.load_depth(depth_B_ref) + # Process images + im_A, im_B = self.im_transform_ops((im_A, im_B)) + depth_A, depth_B = self.depth_transform_ops( + (depth_A[None, None], depth_B[None, None]) + ) + + [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B) + K1[:2, 2] += t + K2[:2, 2] += t + + im_A, im_B = im_A[None], im_B[None] + if self.random_eraser is not None: + im_A, depth_A = self.random_eraser(im_A, depth_A) + im_B, depth_B = self.random_eraser(im_B, depth_B) + + if self.use_horizontal_flip_aug: + if np.random.rand() > 0.5: + im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2) + if self.use_single_horizontal_flip_aug: + if np.random.rand() > 0.5: + im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2) + + if romatch.DEBUG_MODE: + tensor_to_pil(im_A[0], unnormalize=True).save( + f"vis/im_A.jpg") + tensor_to_pil(im_B[0], 
unnormalize=True).save( + f"vis/im_B.jpg") + + data_dict = { + "im_A": im_A[0], + "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0], + "im_B": im_B[0], + "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0], + "im_A_depth": depth_A[0, 0], + "im_B_depth": depth_B[0, 0], + "K1": K1, + "K2": K2, + "T_1to2": T_1to2, + "im_A_path": im_A_ref, + "im_B_path": im_B_ref, + + } + return data_dict + + +class MegadepthBuilder: + def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None: + self.data_root = data_root + self.scene_info_root = os.path.join(data_root, "prep_scene_info") + self.all_scenes = os.listdir(self.scene_info_root) + self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"] + # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those + self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy']) + self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy']) + self.test_scenes_loftr = ["0015.npy", "0022.npy"] + self.loftr_ignore = loftr_ignore + self.imc21_ignore = imc21_ignore + + def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs): + if split == "train": + scene_names = set(self.all_scenes) - set(self.test_scenes) + elif split == "train_loftr": + scene_names = set(self.all_scenes) - set(self.test_scenes_loftr) + elif split == "test": + scene_names = self.test_scenes + elif split == "test_loftr": + scene_names = self.test_scenes_loftr + elif split == "custom": + scene_names = scene_names + else: + raise ValueError(f"Split {split} not available") + scenes = [] + for scene_name in scene_names: + if self.loftr_ignore and scene_name in self.loftr_ignore_scenes: + continue + if self.imc21_ignore and scene_name in self.imc21_scenes: + continue + if ".npy" not in scene_name: + continue + scene_info = np.load( + os.path.join(self.scene_info_root, scene_name), allow_pickle=True + ).item() + scenes.append( + MegadepthScene( + self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs + ) + ) + return scenes + + def weight_scenes(self, concat_dataset, alpha=0.5): + ns = [] + for d in concat_dataset.datasets: + ns.append(len(d)) + ws = torch.cat([torch.ones(n) / n**alpha for n in ns]) + return ws diff --git a/imcui/third_party/RoMa/romatch/datasets/scannet.py b/imcui/third_party/RoMa/romatch/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..e03261557147cd3449c76576a5e5e22c0ae288e9 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/datasets/scannet.py @@ -0,0 +1,160 @@ +import os +import random +from PIL import Image +import cv2 +import h5py +import numpy as np +import torch +from torch.utils.data import ( + Dataset, + DataLoader, + ConcatDataset) + +import torchvision.transforms.functional as tvf +import kornia.augmentation as K +import os.path as osp +import matplotlib.pyplot as plt +import romatch +from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops +from romatch.utils.transforms import GeometricSequential +from tqdm import tqdm + +class ScanNetScene: + def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False, +) -> None: + self.scene_root = 
osp.join(data_root,"scans","scans_train") + self.data_names = scene_info['name'] + self.overlaps = scene_info['score'] + # Only sample 10s + valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0 + self.overlaps = self.overlaps[valid] + self.data_names = self.data_names[valid] + if len(self.data_names) > 10000: + pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False) + self.data_names = self.data_names[pairinds] + self.overlaps = self.overlaps[pairinds] + self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True) + self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False) + self.wt, self.ht = wt, ht + self.shake_t = shake_t + self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob)) + self.use_horizontal_flip_aug = use_horizontal_flip_aug + + def load_im(self, im_B, crop=None): + im = Image.open(im_B) + return im + + def load_depth(self, depth_ref, crop=None): + depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED) + depth = depth / 1000 + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + def __len__(self): + return len(self.data_names) + + def scale_intrinsic(self, K, wi, hi): + sx, sy = self.wt / wi, self.ht / hi + sK = torch.tensor([[sx, 0, 0], + [0, sy, 0], + [0, 0, 1]]) + return sK@K + + def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B): + im_A = im_A.flip(-1) + im_B = im_B.flip(-1) + depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) + flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device) + K_A = flip_mat@K_A + K_B = flip_mat@K_B + + return im_A, im_B, depth_A, depth_B, K_A, K_B + def read_scannet_pose(self,path): + """ Read ScanNet's Camera2World pose and transform it to World2Camera. + + Returns: + pose_w2c (np.ndarray): (4, 4) + """ + cam2world = np.loadtxt(path, delimiter=' ') + world2cam = np.linalg.inv(cam2world) + return world2cam + + + def read_scannet_intrinsic(self,path): + """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 
+ """ + intrinsic = np.loadtxt(path, delimiter=' ') + return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float) + + def __getitem__(self, pair_idx): + # read intrinsics of original size + data_name = self.data_names[pair_idx] + scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name + scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + + # read the intrinsic of depthmap + K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root, + scene_name, + 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter + # read and compute relative poses + T1 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_1}.txt')) + T2 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_2}.txt')) + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4) + + # Load positive pair data + im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg') + im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg') + depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png') + depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png') + + im_A = self.load_im(im_A_ref) + im_B = self.load_im(im_B_ref) + depth_A = self.load_depth(depth_A_ref) + depth_B = self.load_depth(depth_B_ref) + + # Recompute camera intrinsic matrix due to the resize + K1 = self.scale_intrinsic(K1, im_A.width, im_A.height) + K2 = self.scale_intrinsic(K2, im_B.width, im_B.height) + # Process images + im_A, im_B = self.im_transform_ops((im_A, im_B)) + depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None])) + if self.use_horizontal_flip_aug: + if np.random.rand() > 0.5: + im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2) + + data_dict = {'im_A': im_A, + 'im_B': im_B, + 'im_A_depth': depth_A[0,0], + 'im_B_depth': depth_B[0,0], + 'K1': K1, + 'K2': K2, + 'T_1to2':T_1to2, + } + return data_dict + + +class ScanNetBuilder: + def __init__(self, data_root = 'data/scannet') -> None: + self.data_root = data_root + self.scene_info_root = os.path.join(data_root,'scannet_indices') + self.all_scenes = os.listdir(self.scene_info_root) + + def build_scenes(self, split = 'train', min_overlap=0., **kwargs): + # Note: split doesn't matter here as we always use same scannet_train scenes + scene_names = self.all_scenes + scenes = [] + for scene_name in tqdm(scene_names, disable = romatch.RANK > 0): + scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True) + scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs)) + return scenes + + def weight_scenes(self, concat_dataset, alpha=.5): + ns = [] + for d in concat_dataset.datasets: + ns.append(len(d)) + ws = torch.cat([torch.ones(n)/n**alpha for n in ns]) + return ws diff --git a/imcui/third_party/RoMa/romatch/losses/__init__.py b/imcui/third_party/RoMa/romatch/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e08abacfc0f83d7de0f2ddc0583766a80bf53cf --- /dev/null +++ b/imcui/third_party/RoMa/romatch/losses/__init__.py @@ -0,0 +1 @@ +from .robust_loss import RobustLosses \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/losses/robust_loss.py b/imcui/third_party/RoMa/romatch/losses/robust_loss.py new file mode 100644 index 
0000000000000000000000000000000000000000..80d430069666fabe2471ec7eda2fa6e9c996f041 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/losses/robust_loss.py @@ -0,0 +1,161 @@ +from einops.einops import rearrange +import torch +import torch.nn as nn +import torch.nn.functional as F +from romatch.utils.utils import get_gt_warp +import wandb +import romatch +import math + +class RobustLosses(nn.Module): + def __init__( + self, + robust=False, + center_coords=False, + scale_normalize=False, + ce_weight=0.01, + local_loss=True, + local_dist=4.0, + local_largest_scale=8, + smooth_mask = False, + depth_interpolation_mode = "bilinear", + mask_depth_loss = False, + relative_depth_error_threshold = 0.05, + alpha = 1., + c = 1e-3, + ): + super().__init__() + self.robust = robust # measured in pixels + self.center_coords = center_coords + self.scale_normalize = scale_normalize + self.ce_weight = ce_weight + self.local_loss = local_loss + self.local_dist = local_dist + self.local_largest_scale = local_largest_scale + self.smooth_mask = smooth_mask + self.depth_interpolation_mode = depth_interpolation_mode + self.mask_depth_loss = mask_depth_loss + self.relative_depth_error_threshold = relative_depth_error_threshold + self.avg_overlap = dict() + self.alpha = alpha + self.c = c + + def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale): + with torch.no_grad(): + B, C, H, W = scale_gm_cls.shape + device = x2.device + cls_res = round(math.sqrt(C)) + G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)], indexing='ij') + G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) + GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices + cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99] + certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob) + if not torch.any(cls_loss): + cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere + + losses = { + f"gm_certainty_loss_{scale}": certainty_loss.mean(), + f"gm_cls_loss_{scale}": cls_loss.mean(), + } + wandb.log(losses, step = romatch.GLOBAL_STEP) + return losses + + def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale): + with torch.no_grad(): + B, C, H, W = delta_cls.shape + device = x2.device + cls_res = round(math.sqrt(C)) + G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)]) + G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale + GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices + cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99] + certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob) + if not torch.any(cls_loss): + cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere + losses = { + f"delta_certainty_loss_{scale}": certainty_loss.mean(), + f"delta_cls_loss_{scale}": cls_loss.mean(), + } + wandb.log(losses, step = romatch.GLOBAL_STEP) + return losses + + def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"): + epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1) + if scale == 1: + pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean() + wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP) + + ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob) + a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha 
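+ # Scale-dependent robust penalty on the end-point error x, a generalized Charbonnier-style
+ # form: loss = cs**a * ((x / cs)**2 + 1)**(a / 2); with a < 2, large residuals are
+ # down-weighted, which gives robustness to outlier correspondences.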
+ cs = self.c * scale + x = epe[prob > 0.99] + reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2) + if not torch.any(reg_loss): + reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere + losses = { + f"{mode}_certainty_loss_{scale}": ce_loss.mean(), + f"{mode}_regression_loss_{scale}": reg_loss.mean(), + } + wandb.log(losses, step = romatch.GLOBAL_STEP) + return losses + + def forward(self, corresps, batch): + scales = list(corresps.keys()) + tot_loss = 0.0 + # scale_weights due to differences in scale for regression gradients and classification gradients + scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1} + for scale in scales: + scale_corresps = corresps[scale] + scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = ( + scale_corresps["certainty"], + scale_corresps.get("flow_pre_delta"), + scale_corresps.get("delta_cls"), + scale_corresps.get("offset_scale"), + scale_corresps.get("gm_cls"), + scale_corresps.get("gm_certainty"), + scale_corresps["flow"], + scale_corresps.get("gm_flow"), + + ) + if flow_pre_delta is not None: + flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d") + b, h, w, d = flow_pre_delta.shape + else: + # _ = 1 + b, _, h, w = scale_certainty.shape + gt_warp, gt_prob = get_gt_warp( + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + H=h, + W=w, + ) + x2 = gt_warp.float() + prob = gt_prob + + if self.local_largest_scale >= scale: + prob = prob * ( + F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0] + < (2 / 512) * (self.local_dist[scale] * scale)) + + if scale_gm_cls is not None: + gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale) + gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"] + tot_loss = tot_loss + scale_weights[scale] * gm_loss + elif scale_gm_flow is not None: + gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm") + gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"] + tot_loss = tot_loss + scale_weights[scale] * gm_loss + + if delta_cls is not None: + delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale) + delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"] + tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss + else: + delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale) + reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"] + tot_loss = tot_loss + scale_weights[scale] * reg_loss + prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach() + return tot_loss diff --git a/imcui/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py b/imcui/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py new file mode 100644 index 0000000000000000000000000000000000000000..a17c24678b093ca843d16c1a17ea16f19fa594d5 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py @@ -0,0 +1,160 @@ +from einops.einops import rearrange +import torch +import torch.nn as nn +import torch.nn.functional as F +from romatch.utils.utils import get_gt_warp +import wandb +import romatch +import math + +# This is slightly different than regular 
romatch due to significantly worse corresps +# The confidence loss is quite tricky here //Johan + +class RobustLosses(nn.Module): + def __init__( + self, + robust=False, + center_coords=False, + scale_normalize=False, + ce_weight=0.01, + local_loss=True, + local_dist=None, + smooth_mask = False, + depth_interpolation_mode = "bilinear", + mask_depth_loss = False, + relative_depth_error_threshold = 0.05, + alpha = 1., + c = 1e-3, + epe_mask_prob_th = None, + cert_only_on_consistent_depth = False, + ): + super().__init__() + if local_dist is None: + local_dist = {} + self.robust = robust # measured in pixels + self.center_coords = center_coords + self.scale_normalize = scale_normalize + self.ce_weight = ce_weight + self.local_loss = local_loss + self.local_dist = local_dist + self.smooth_mask = smooth_mask + self.depth_interpolation_mode = depth_interpolation_mode + self.mask_depth_loss = mask_depth_loss + self.relative_depth_error_threshold = relative_depth_error_threshold + self.avg_overlap = dict() + self.alpha = alpha + self.c = c + self.epe_mask_prob_th = epe_mask_prob_th + self.cert_only_on_consistent_depth = cert_only_on_consistent_depth + + def corr_volume_loss(self, mnn:torch.Tensor, corr_volume:torch.Tensor, scale): + b, h,w, h,w = corr_volume.shape + inv_temp = 10 + corr_volume = corr_volume.reshape(-1, h*w, h*w) + nll = -(inv_temp*corr_volume).log_softmax(dim = 1) - (inv_temp*corr_volume).log_softmax(dim = 2) + corr_volume_loss = nll[mnn[:,0], mnn[:,1], mnn[:,2]].mean() + + losses = { + f"gm_corr_volume_loss_{scale}": corr_volume_loss.mean(), + } + wandb.log(losses, step = romatch.GLOBAL_STEP) + return losses + + + + def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"): + epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1) + if scale in self.local_dist: + prob = prob * (epe < (2 / 512) * (self.local_dist[scale] * scale)).float() + if scale == 1: + pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean() + wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP) + if self.epe_mask_prob_th is not None: + # if too far away from gt, certainty should be 0 + gt_cert = prob * (epe < scale * self.epe_mask_prob_th) + else: + gt_cert = prob + if self.cert_only_on_consistent_depth: + ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0][prob > 0], gt_cert[prob > 0]) + else: + ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], gt_cert) + a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha + cs = self.c * scale + x = epe[prob > 0.99] + reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2) + if not torch.any(reg_loss): + reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere + losses = { + f"{mode}_certainty_loss_{scale}": ce_loss.mean(), + f"{mode}_regression_loss_{scale}": reg_loss.mean(), + } + wandb.log(losses, step = romatch.GLOBAL_STEP) + return losses + + def forward(self, corresps, batch): + scales = list(corresps.keys()) + tot_loss = 0.0 + # scale_weights due to differences in scale for regression gradients and classification gradients + for scale in scales: + scale_corresps = corresps[scale] + scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_corr_volume, scale_gm_certainty, flow, scale_gm_flow = ( + scale_corresps["certainty"], + scale_corresps.get("flow_pre_delta"), + scale_corresps.get("delta_cls"), + scale_corresps.get("offset_scale"), + scale_corresps.get("corr_volume"), + scale_corresps.get("gm_certainty"), + scale_corresps["flow"], + scale_corresps.get("gm_flow"), + + ) + if 
flow_pre_delta is not None: + flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d") + b, h, w, d = flow_pre_delta.shape + else: + # _ = 1 + b, _, h, w = scale_certainty.shape + gt_warp, gt_prob = get_gt_warp( + batch["im_A_depth"], + batch["im_B_depth"], + batch["T_1to2"], + batch["K1"], + batch["K2"], + H=h, + W=w, + ) + x2 = gt_warp.float() + prob = gt_prob + + if scale_gm_corr_volume is not None: + gt_warp_back, _ = get_gt_warp( + batch["im_B_depth"], + batch["im_A_depth"], + batch["T_1to2"].inverse(), + batch["K2"], + batch["K1"], + H=h, + W=w, + ) + grid = torch.stack(torch.meshgrid(torch.linspace(-1+1/w, 1-1/w, w), torch.linspace(-1+1/h, 1-1/h, h), indexing='xy'), dim =-1).to(gt_warp.device) + #fwd_bck = F.grid_sample(gt_warp_back.permute(0,3,1,2), gt_warp, align_corners=False, mode = 'bilinear').permute(0,2,3,1) + #diff = (fwd_bck - grid).norm(dim = -1) + with torch.no_grad(): + D_B = torch.cdist(gt_warp.float().reshape(-1,h*w,2), grid.reshape(-1,h*w,2)) + D_A = torch.cdist(grid.reshape(-1,h*w,2), gt_warp_back.float().reshape(-1,h*w,2)) + inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values) + * (D_A == D_A.min(dim=-2, keepdim = True).values) + * (D_B < 0.01) + * (D_A < 0.01)) + + gm_cls_losses = self.corr_volume_loss(inds, scale_gm_corr_volume, scale) + gm_loss = gm_cls_losses[f"gm_corr_volume_loss_{scale}"] + tot_loss = tot_loss + gm_loss + elif scale_gm_flow is not None: + gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm") + gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"] + tot_loss = tot_loss + gm_loss + delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale) + reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"] + tot_loss = tot_loss + reg_loss + return tot_loss diff --git a/imcui/third_party/RoMa/romatch/models/__init__.py b/imcui/third_party/RoMa/romatch/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7650c9c7480920905e27578f175fcb5f995cc8ba --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/__init__.py @@ -0,0 +1 @@ +from .model_zoo import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/models/encoders.py b/imcui/third_party/RoMa/romatch/models/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..84fb54395139a2ca21860ce2c18d033ad0afb19f --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/encoders.py @@ -0,0 +1,122 @@ +from typing import Optional, Union +import torch +from torch import device +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as tvm +import gc +from romatch.utils.utils import get_autocast_params + + +class ResNet50(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None, + dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None: + super().__init__() + if dilation is None: + dilation = [False,False,False] + if anti_aliased: + pass + else: + if weights is not None: + self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation) + else: + self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation) + + self.high_res = high_res + self.freeze_bn = freeze_bn + self.early_exit = 
early_exit + self.amp = amp + self.amp_dtype = amp_dtype + + def forward(self, x, **kwargs): + autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(x.device, self.amp, self.amp_dtype) + with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype): + net = self.net + feats = {1:x} + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x + x = net.layer2(x) + feats[8] = x + if self.early_exit: + return feats + x = net.layer3(x) + feats[16] = x + x = net.layer4(x) + feats[32] = x + return feats + + def train(self, mode=True): + super().train(mode) + if self.freeze_bn: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + +class VGG19(nn.Module): + def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None: + super().__init__() + self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) + self.amp = amp + self.amp_dtype = amp_dtype + + def forward(self, x, **kwargs): + autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(x.device, self.amp, self.amp_dtype) + with torch.autocast(device_type=autocast_device, enabled=autocast_enabled, dtype = autocast_dtype): + feats = {} + scale = 1 + for layer in self.layers: + if isinstance(layer, nn.MaxPool2d): + feats[scale] = x + scale = scale*2 + x = layer(x) + return feats + +class CNNandDinov2(nn.Module): + def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None, amp_dtype = torch.float16): + super().__init__() + if dinov2_weights is None: + dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu") + from .transformer import vit_large + vit_kwargs = dict(img_size= 518, + patch_size= 14, + init_values = 1.0, + ffn_layer = "mlp", + block_chunks = 0, + ) + + dinov2_vitl14 = vit_large(**vit_kwargs).eval() + dinov2_vitl14.load_state_dict(dinov2_weights) + cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {} + if not use_vgg: + self.cnn = ResNet50(**cnn_kwargs) + else: + self.cnn = VGG19(**cnn_kwargs) + self.amp = amp + self.amp_dtype = amp_dtype + if self.amp: + dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype) + self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP + + + def train(self, mode: bool = True): + return self.cnn.train(mode) + + def forward(self, x, upsample = False): + B,C,H,W = x.shape + feature_pyramid = self.cnn(x) + + if not upsample: + with torch.no_grad(): + if self.dinov2_vitl14[0].device != x.device: + self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype) + dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype)) + features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14) + del dinov2_features_16 + feature_pyramid[16] = features_16 + return feature_pyramid \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/models/matcher.py b/imcui/third_party/RoMa/romatch/models/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..8108a92393f4657e2c87d75c4072a46e7d61cdd6 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/matcher.py @@ -0,0 +1,748 @@ +import os +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +import warnings +from warnings import warn +from PIL import 
Image + +from romatch.utils import get_tuple_transform_ops +from romatch.utils.local_correlation import local_correlation +from romatch.utils.utils import cls_to_flow_refine, get_autocast_params +from romatch.utils.kde import kde + +class ConvRefiner(nn.Module): + def __init__( + self, + in_dim=6, + hidden_dim=16, + out_dim=2, + dw=False, + kernel_size=5, + hidden_blocks=3, + displacement_emb = None, + displacement_emb_dim = None, + local_corr_radius = None, + corr_in_other = None, + no_im_B_fm = False, + amp = False, + concat_logits = False, + use_bias_block_1 = True, + use_cosine_corr = False, + disable_local_corr_grad = False, + is_classifier = False, + sample_mode = "bilinear", + norm_type = nn.BatchNorm2d, + bn_momentum = 0.1, + amp_dtype = torch.float16, + ): + super().__init__() + self.bn_momentum = bn_momentum + self.block1 = self.create_block( + in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1, + ) + self.hidden_blocks = nn.Sequential( + *[ + self.create_block( + hidden_dim, + hidden_dim, + dw=dw, + kernel_size=kernel_size, + norm_type=norm_type, + ) + for hb in range(hidden_blocks) + ] + ) + self.hidden_blocks = self.hidden_blocks + self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0) + if displacement_emb: + self.has_displacement_emb = True + self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0) + else: + self.has_displacement_emb = False + self.local_corr_radius = local_corr_radius + self.corr_in_other = corr_in_other + self.no_im_B_fm = no_im_B_fm + self.amp = amp + self.concat_logits = concat_logits + self.use_cosine_corr = use_cosine_corr + self.disable_local_corr_grad = disable_local_corr_grad + self.is_classifier = is_classifier + self.sample_mode = sample_mode + self.amp_dtype = amp_dtype + + def create_block( + self, + in_dim, + out_dim, + dw=False, + kernel_size=5, + bias = True, + norm_type = nn.BatchNorm2d, + ): + num_groups = 1 if not dw else in_dim + if dw: + assert ( + out_dim % in_dim == 0 + ), "outdim must be divisible by indim for depthwise" + conv1 = nn.Conv2d( + in_dim, + out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=num_groups, + bias=bias, + ) + norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim) + relu = nn.ReLU(inplace=True) + conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0) + return nn.Sequential(conv1, norm, relu, conv2) + + def forward(self, x, y, flow, scale_factor = 1, logits = None): + b,c,hs,ws = x.shape + autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(x.device, enabled=self.amp, dtype=self.amp_dtype) + with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype): + x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode) + if self.has_displacement_emb: + im_A_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device), + ), indexing='ij' + ) + im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0])) + im_A_coords = im_A_coords[None].expand(b, 2, hs, ws) + in_displacement = flow-im_A_coords + emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement) + if self.local_corr_radius: + if self.corr_in_other: + # Corr in other means take a kxk grid around the predicted coordinate in other image + local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow, + sample_mode = self.sample_mode) + 
else: + raise NotImplementedError("Local corr in own frame should not be used.") + if self.no_im_B_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1) + else: + d = torch.cat((x, x_hat, emb_in_displacement), dim=1) + else: + if self.no_im_B_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat), dim=1) + if self.concat_logits: + d = torch.cat((d, logits), dim=1) + d = self.block1(d) + d = self.hidden_blocks(d) + d = self.out_conv(d.float()) + displacement, certainty = d[:, :-1], d[:, -1:] + return displacement, certainty + +class CosKernel(nn.Module): # similar to softmax kernel + def __init__(self, T, learn_temperature=False): + super().__init__() + self.learn_temperature = learn_temperature + if self.learn_temperature: + self.T = nn.Parameter(torch.tensor(T)) + else: + self.T = T + + def __call__(self, x, y, eps=1e-6): + c = torch.einsum("bnd,bmd->bnm", x, y) / ( + x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps + ) + if self.learn_temperature: + T = self.T.abs() + 0.01 + else: + T = torch.tensor(self.T, device=c.device) + K = ((c - 1.0) / T).exp() + return K + +class GP(nn.Module): + def __init__( + self, + kernel, + T=1, + learn_temperature=False, + only_attention=False, + gp_dim=64, + basis="fourier", + covar_size=5, + only_nearest_neighbour=False, + sigma_noise=0.1, + no_cov=False, + predict_features = False, + ): + super().__init__() + self.K = kernel(T=T, learn_temperature=learn_temperature) + self.sigma_noise = sigma_noise + self.covar_size = covar_size + self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1) + self.only_attention = only_attention + self.only_nearest_neighbour = only_nearest_neighbour + self.basis = basis + self.no_cov = no_cov + self.dim = gp_dim + self.predict_features = predict_features + + def get_local_cov(self, cov): + K = self.covar_size + b, h, w, h, w = cov.shape + hw = h * w + cov = F.pad(cov, 4 * (K // 2,)) # pad v_q + delta = torch.stack( + torch.meshgrid( + torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1), + indexing = 'ij'), + dim=-1, + ) + positions = torch.stack( + torch.meshgrid( + torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2), + indexing = 'ij'), + dim=-1, + ) + neighbours = positions[:, :, None, None, :] + delta[None, :, :] + points = torch.arange(hw)[:, None].expand(hw, K**2) + local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[ + :, + points.flatten(), + neighbours[..., 0].flatten(), + neighbours[..., 1].flatten(), + ].reshape(b, h, w, K**2) + return local_cov + + def reshape(self, x): + return rearrange(x, "b d h w -> b (h w) d") + + def project_to_basis(self, x): + if self.basis == "fourier": + return torch.cos(8 * math.pi * self.pos_conv(x)) + elif self.basis == "linear": + return self.pos_conv(x) + else: + raise ValueError( + "No other bases other than fourier and linear currently im_Bed in public release" + ) + + def get_pos_enc(self, y): + b, c, h, w = y.shape + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device), + ), + indexing = 'ij' + ) + + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + coarse_embedded_coords = self.project_to_basis(coarse_coords) + return coarse_embedded_coords + + def forward(self, x, y, **kwargs): + b, c, h1, w1 = x.shape + b, c, h2, w2 = y.shape + f = self.get_pos_enc(y) + b, d, h2, 
w2 = f.shape + x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f) + K_xx = self.K(x, x) + K_yy = self.K(y, y) + K_xy = self.K(x, y) + K_yx = K_xy.permute(0, 2, 1) + sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :] + with warnings.catch_warnings(): + K_yy_inv = torch.linalg.inv(K_yy + sigma_noise) + + mu_x = K_xy.matmul(K_yy_inv.matmul(f)) + mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1) + if not self.no_cov: + cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx)) + cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1) + local_cov_x = self.get_local_cov(cov_x) + local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w") + gp_feats = torch.cat((mu_x, local_cov_x), dim=1) + else: + gp_feats = mu_x + return gp_feats + +class Decoder(nn.Module): + def __init__( + self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None, + num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0, + flow_upsample_mode = "bilinear", amp_dtype = torch.float16, + ): + super().__init__() + self.embedding_decoder = embedding_decoder + self.num_refinement_steps_per_scale = num_refinement_steps_per_scale + self.gps = gps + self.proj = proj + self.conv_refiner = conv_refiner + self.detach = detach + if pos_embeddings is None: + self.pos_embeddings = {} + else: + self.pos_embeddings = pos_embeddings + if scales == "all": + self.scales = ["32", "16", "8", "4", "2", "1"] + else: + self.scales = scales + self.warp_noise_std = warp_noise_std + self.refine_init = 4 + self.displacement_dropout_p = displacement_dropout_p + self.gm_warp_dropout_p = gm_warp_dropout_p + self.flow_upsample_mode = flow_upsample_mode + self.amp_dtype = amp_dtype + + def get_placeholder_flow(self, b, h, w, device): + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + ), + indexing = 'ij' + ) + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + return coarse_coords + + def get_positional_embedding(self, b, h ,w, device): + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + ), + indexing = 'ij' + ) + + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + coarse_embedded_coords = self.pos_embedding(coarse_coords) + return coarse_embedded_coords + + def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1): + coarse_scales = self.embedding_decoder.scales() + all_scales = self.scales if not upsample else ["8", "4", "2", "1"] + sizes = {scale: f1[scale].shape[-2:] for scale in f1} + h, w = sizes[1] + b = f1[1].shape[0] + device = f1[1].device + coarsest_scale = int(all_scales[0]) + old_stuff = torch.zeros( + b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device + ) + corresps = {} + if not upsample: + flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device) + certainty = 0.0 + else: + flow = F.interpolate( + flow, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + certainty = 
F.interpolate( + certainty, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + displacement = 0.0 + for new_scale in all_scales: + ins = int(new_scale) + corresps[ins] = {} + f1_s, f2_s = f1[ins], f2[ins] + if new_scale in self.proj: + autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(f1_s.device, str(f1_s)=='cuda', self.amp_dtype) + with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype): + if not autocast_enabled: + f1_s, f2_s = f1_s.to(torch.float32), f2_s.to(torch.float32) + f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s) + + if ins in coarse_scales: + old_stuff = F.interpolate( + old_stuff, size=sizes[ins], mode="bilinear", align_corners=False + ) + gp_posterior = self.gps[new_scale](f1_s, f2_s) + gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder( + gp_posterior, f1_s, old_stuff, new_scale + ) + + if self.embedding_decoder.is_classifier: + flow = cls_to_flow_refine( + gm_warp_or_cls, + ).permute(0,3,1,2) + corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None + else: + corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None + flow = gm_warp_or_cls.detach() + + if new_scale in self.conv_refiner: + corresps[ins].update({"flow_pre_delta": flow}) if self.training else None + delta_flow, delta_certainty = self.conv_refiner[new_scale]( + f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty, + ) + corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None + displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w), + delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,) + flow = flow + displacement + certainty = ( + certainty + delta_certainty + ) # predict both certainty and displacement + corresps[ins].update({ + "certainty": certainty, + "flow": flow, + }) + if new_scale != "1": + flow = F.interpolate( + flow, + size=sizes[ins // 2], + mode=self.flow_upsample_mode, + ) + certainty = F.interpolate( + certainty, + size=sizes[ins // 2], + mode=self.flow_upsample_mode, + ) + if self.detach: + flow = flow.detach() + certainty = certainty.detach() + #torch.cuda.empty_cache() + return corresps + + +class RegressionMatcher(nn.Module): + def __init__( + self, + encoder, + decoder, + h=448, + w=448, + sample_mode = "threshold_balanced", + upsample_preds = False, + symmetric = False, + name = None, + attenuate_cert = None, + ): + super().__init__() + self.attenuate_cert = attenuate_cert + self.encoder = encoder + self.decoder = decoder + self.name = name + self.w_resized = w + self.h_resized = h + self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True) + self.sample_mode = sample_mode + self.upsample_preds = upsample_preds + self.upsample_res = (14*16*6, 14*16*6) + self.symmetric = symmetric + self.sample_thresh = 0.05 + + def get_output_resolution(self): + if not self.upsample_preds: + return self.h_resized, self.w_resized + else: + return self.upsample_res + + def extract_backbone_features(self, batch, batched = True, upsample = False): + x_q = batch["im_A"] + x_s = batch["im_B"] + if batched: + X = torch.cat((x_q, x_s), dim = 0) + feature_pyramid = self.encoder(X, upsample = upsample) + else: + feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample) + return feature_pyramid + + def sample( + self, + matches, + certainty, + num=10000, + ): + if "threshold" in 
self.sample_mode: + upper_thresh = self.sample_thresh + certainty = certainty.clone() + certainty[certainty > upper_thresh] = 1 + matches, certainty = ( + matches.reshape(-1, 4), + certainty.reshape(-1), + ) + expansion_factor = 4 if "balanced" in self.sample_mode else 1 + good_samples = torch.multinomial(certainty, + num_samples = min(expansion_factor*num, len(certainty)), + replacement=False) + good_matches, good_certainty = matches[good_samples], certainty[good_samples] + if "balanced" not in self.sample_mode: + return good_matches, good_certainty + density = kde(good_matches, std=0.1) + p = 1 / (density+1) + p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones + balanced_samples = torch.multinomial(p, + num_samples = min(num,len(good_certainty)), + replacement=False) + return good_matches[balanced_samples], good_certainty[balanced_samples] + + def forward(self, batch, batched = True, upsample = False, scale_factor = 1): + feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample) + if batched: + f_q_pyramid = { + scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items() + } + f_s_pyramid = { + scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items() + } + else: + f_q_pyramid, f_s_pyramid = feature_pyramid + corresps = self.decoder(f_q_pyramid, + f_s_pyramid, + upsample = upsample, + **(batch["corresps"] if "corresps" in batch else {}), + scale_factor=scale_factor) + + return corresps + + def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1): + feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample) + f_q_pyramid = feature_pyramid + f_s_pyramid = { + scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0) + for scale, f_scale in feature_pyramid.items() + } + corresps = self.decoder(f_q_pyramid, + f_s_pyramid, + upsample = upsample, + **(batch["corresps"] if "corresps" in batch else {}), + scale_factor=scale_factor) + return corresps + + def conf_from_fb_consistency(self, flow_forward, flow_backward, th = 2): + # assumes that flow forward is of shape (..., H, W, 2) + has_batch = False + if len(flow_forward.shape) == 3: + flow_forward, flow_backward = flow_forward[None], flow_backward[None] + else: + has_batch = True + H,W = flow_forward.shape[-3:-1] + th_n = 2 * th / max(H,W) + coords = torch.stack(torch.meshgrid( + torch.linspace(-1 + 1 / W, 1 - 1 / W, W), + torch.linspace(-1 + 1 / H, 1 - 1 / H, H), indexing = "xy"), + dim = -1).to(flow_forward.device) + coords_fb = F.grid_sample( + flow_backward.permute(0, 3, 1, 2), + flow_forward, + align_corners=False, mode="bilinear").permute(0, 2, 3, 1) + diff = (coords - coords_fb).norm(dim=-1) + in_th = (diff < th_n).float() + if not has_batch: + in_th = in_th[0] + return in_th + + def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None): + if coords.shape[-1] == 2: + return self._to_pixel_coordinates(coords, H_A, W_A) + + if isinstance(coords, (list, tuple)): + kpts_A, kpts_B = coords[0], coords[1] + else: + kpts_A, kpts_B = coords[...,:2], coords[...,2:] + return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B) + + def _to_pixel_coordinates(self, coords, H, W): + kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1) + return kpts + + def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B): + if isinstance(coords, (list, tuple)): + kpts_A, kpts_B = coords[0], 
coords[1] + else: + kpts_A, kpts_B = coords[...,:2], coords[...,2:] + kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1) + kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1) + return kpts_A, kpts_B + + def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False): + x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT + cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0] + D = torch.cdist(x_A_to_B, x_B) + inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True) + + if return_tuple: + if return_inds: + return inds_A, inds_B + else: + return x_A[inds_A], x_B[inds_B] + else: + if return_inds: + return torch.cat((inds_A, inds_B),dim=-1) + else: + return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1) + + @torch.inference_mode() + def match( + self, + im_A_path, + im_B_path, + *args, + batched=False, + device = None, + ): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if isinstance(im_A_path, (str, os.PathLike)): + im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB") + else: + im_A, im_B = im_A_path, im_B_path + + symmetric = self.symmetric + self.train(False) + with torch.no_grad(): + if not batched: + b = 1 + w, h = im_A.size + w2, h2 = im_B.size + # Get images in good format + ws = self.w_resized + hs = self.h_resized + + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True, clahe = False + ) + im_A, im_B = test_transform((im_A, im_B)) + batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)} + else: + b, c, h, w = im_A.shape + b, c, h2, w2 = im_B.shape + assert w == w2 and h == h2, "For batched images we assume same size" + batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)} + if h != self.h_resized or self.w_resized != w: + warn("Model resolution and batch resolution differ, may produce unexpected results") + hs, ws = h, w + finest_scale = 1 + # Run matcher + if symmetric: + corresps = self.forward_symmetric(batch) + else: + corresps = self.forward(batch, batched = True) + + if self.upsample_preds: + hs, ws = self.upsample_res + + if self.attenuate_cert: + low_res_certainty = F.interpolate( + corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + ) + cert_clamp = 0 + factor = 0.5 + low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp) + + if self.upsample_preds: + finest_corresps = corresps[finest_scale] + torch.cuda.empty_cache() + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True + ) + im_A, im_B = test_transform((Image.open(im_A_path).convert('RGB'), Image.open(im_B_path).convert('RGB'))) + im_A, im_B = im_A[None].to(device), im_B[None].to(device) + scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized)) + batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps} + if symmetric: + corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor) + else: + corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor) + + im_A_to_im_B = corresps[finest_scale]["flow"] + certainty = 
corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0) + if finest_scale != 1: + im_A_to_im_B = F.interpolate( + im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear" + ) + certainty = F.interpolate( + certainty, size=(hs, ws), align_corners=False, mode="bilinear" + ) + im_A_to_im_B = im_A_to_im_B.permute( + 0, 2, 3, 1 + ) + # Create im_A meshgrid + im_A_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), + ), + indexing = 'ij' + ) + im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0])) + im_A_coords = im_A_coords[None].expand(b, 2, hs, ws) + certainty = certainty.sigmoid() # logits -> probs + im_A_coords = im_A_coords.permute(0, 2, 3, 1) + if (im_A_to_im_B.abs() > 1).any() and True: + wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0 + certainty[wrong[:,None]] = 0 + im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1) + if symmetric: + A_to_B, B_to_A = im_A_to_im_B.chunk(2) + q_warp = torch.cat((im_A_coords, A_to_B), dim=-1) + im_B_coords = im_A_coords + s_warp = torch.cat((B_to_A, im_B_coords), dim=-1) + warp = torch.cat((q_warp, s_warp),dim=2) + certainty = torch.cat(certainty.chunk(2), dim=3) + else: + warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1) + if batched: + return ( + warp, + certainty[:, 0] + ) + else: + return ( + warp[0], + certainty[0, 0], + ) + + def visualize_warp(self, warp, certainty, im_A = None, im_B = None, + im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None, unnormalize = False): + #assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)" + H,W2,_ = warp.shape + W = W2//2 if symmetric else W2 + if im_A is None: + from PIL import Image + im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB") + if not isinstance(im_A, torch.Tensor): + im_A = im_A.resize((W,H)) + im_B = im_B.resize((W,H)) + x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1) + if symmetric: + x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1) + else: + if symmetric: + x_A = im_A + x_B = im_B + im_A_transfer_rgb = F.grid_sample( + x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False + )[0] + if symmetric: + im_B_transfer_rgb = F.grid_sample( + x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False + )[0] + warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2) + white_im = torch.ones((H,2*W),device=device) + else: + warp_im = im_A_transfer_rgb + white_im = torch.ones((H, W), device = device) + vis_im = certainty * warp_im + (1 - certainty) * white_im + if save_path is not None: + from romatch.utils import tensor_to_pil + tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path) + return vis_im diff --git a/imcui/third_party/RoMa/romatch/models/model_zoo/__init__.py b/imcui/third_party/RoMa/romatch/models/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0470ca3f0c3b8064b1b2f01663dfb13742d7a10 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/model_zoo/__init__.py @@ -0,0 +1,73 @@ +from typing import Union +import torch +from .roma_models import roma_model, tiny_roma_v1_model + +weight_urls = { + "romatch": { + "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth", + "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth", + }, + 
"tiny_roma_v1": { + "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/tiny_roma_v1_outdoor.pth", + }, + "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D +} + +def tiny_roma_v1_outdoor(device, weights = None, xfeat = None): + if weights is None: + weights = torch.hub.load_state_dict_from_url( + weight_urls["tiny_roma_v1"]["outdoor"], + map_location=device) + if xfeat is None: + xfeat = torch.hub.load( + 'verlab/accelerated_features', + 'XFeat', + pretrained = True, + top_k = 4096).net + + return tiny_roma_v1_model(weights = weights, xfeat = xfeat).to(device) + +def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16): + if isinstance(coarse_res, int): + coarse_res = (coarse_res, coarse_res) + if isinstance(upsample_res, int): + upsample_res = (upsample_res, upsample_res) + + if str(device) == 'cpu': + amp_dtype = torch.float32 + + assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone" + assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone" + + if weights is None: + weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["outdoor"], + map_location=device) + if dinov2_weights is None: + dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"], + map_location=device) + model = roma_model(resolution=coarse_res, upsample_preds=True, + weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype) + model.upsample_res = upsample_res + print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}") + return model + +def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16): + if isinstance(coarse_res, int): + coarse_res = (coarse_res, coarse_res) + if isinstance(upsample_res, int): + upsample_res = (upsample_res, upsample_res) + + assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone" + assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone" + + if weights is None: + weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["indoor"], + map_location=device) + if dinov2_weights is None: + dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"], + map_location=device) + model = roma_model(resolution=coarse_res, upsample_preds=True, + weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype) + model.upsample_res = upsample_res + print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}") + return model diff --git a/imcui/third_party/RoMa/romatch/models/model_zoo/roma_models.py b/imcui/third_party/RoMa/romatch/models/model_zoo/roma_models.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8a08fc26d760a09f048bdd98bf8a8fffc8202c --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/model_zoo/roma_models.py @@ -0,0 +1,170 @@ +import warnings +import torch.nn as nn +import torch +from romatch.models.matcher import * +from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention +from romatch.models.encoders import * +from romatch.models.tiny import TinyRoMa + +def tiny_roma_v1_model(weights = None, freeze_xfeat=False, exact_softmax=False, xfeat = None): + model = TinyRoMa( + 
xfeat = xfeat, + freeze_xfeat=freeze_xfeat, + exact_softmax=exact_softmax) + if weights is not None: + model.load_state_dict(weights) + return model + +def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype=torch.float16, **kwargs): + # romatch weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters + #torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul TODO: these probably ruin stuff, should be careful + #torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') + gp_dim = 512 + feat_dim = 512 + decoder_dim = gp_dim + feat_dim + cls_to_coord_res = 64 + coordinate_decoder = TransformerDecoder( + nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), + decoder_dim, + cls_to_coord_res**2 + 1, + is_classifier=True, + amp = True, + pos_enc = False,) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + disable_local_corr_grad = True + + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512+128+(2*7+1)**2, + 2 * 512+128+(2*7+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + local_corr_radius = 7, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "8": ConvRefiner( + 2 * 512+64+(2*3+1)**2, + 2 * 512+64+(2*3+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + local_corr_radius = 3, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "4": ConvRefiner( + 2 * 256+32+(2*2+1)**2, + 2 * 256+32+(2*2+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + local_corr_radius = 2, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "2": ConvRefiner( + 2 * 64+16, + 128+16, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=16, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "1": ConvRefiner( + 2 * 9 + 6, + 24, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks = hidden_blocks, + displacement_emb = displacement_emb, + displacement_emb_dim = 6, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"16": gp16}) + proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512)) + proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512)) + proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256)) + proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64)) + proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9)) + proj = nn.ModuleDict({ + "16": proj16, + "8": proj8, + "4": proj4, + "2": proj2, 
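+ # finest-scale projection below squeezes the 64-channel stride-1 features to 9 channels, matching the last ConvRefiner's 2*9+6 input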
+ "1": proj1, + }) + displacement_dropout_p = 0.0 + gm_warp_dropout_p = 0.0 + decoder = Decoder(coordinate_decoder, + gps, + proj, + conv_refiner, + detach=True, + scales=["16", "8", "4", "2", "1"], + displacement_dropout_p = displacement_dropout_p, + gm_warp_dropout_p = gm_warp_dropout_p) + + encoder = CNNandDinov2( + cnn_kwargs = dict( + pretrained=False, + amp = True), + amp = True, + use_vgg = True, + dinov2_weights = dinov2_weights, + amp_dtype=amp_dtype, + ) + h,w = resolution + symmetric = True + attenuate_cert = True + sample_mode = "threshold_balanced" + matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds, + symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device) + matcher.load_state_dict(weights) + return matcher diff --git a/imcui/third_party/RoMa/romatch/models/tiny.py b/imcui/third_party/RoMa/romatch/models/tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..88f44af6c9a6255831734d096167724f89d040ce --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/tiny.py @@ -0,0 +1,304 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +import torch +from pathlib import Path +import math +import numpy as np + +from torch import nn +from PIL import Image +from torchvision.transforms import ToTensor +from romatch.utils.kde import kde + +class BasicLayer(nn.Module): + """ + Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True): + super().__init__() + self.layer = nn.Sequential( + nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias), + nn.BatchNorm2d(out_channels, affine=False), + nn.ReLU(inplace = True) if relu else nn.Identity() + ) + + def forward(self, x): + return self.layer(x) + +class TinyRoMa(nn.Module): + """ + Implementation of architecture described in + "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024." + """ + + def __init__(self, xfeat = None, + freeze_xfeat = True, + sample_mode = "threshold_balanced", + symmetric = False, + exact_softmax = False): + super().__init__() + del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher + if freeze_xfeat: + xfeat.train(False) + self.xfeat = [xfeat]# hide params from ddp + else: + self.xfeat = nn.ModuleList([xfeat]) + self.freeze_xfeat = freeze_xfeat + match_dim = 256 + self.coarse_matcher = nn.Sequential( + BasicLayer(64+64+2, match_dim,), + BasicLayer(match_dim, match_dim,), + BasicLayer(match_dim, match_dim,), + BasicLayer(match_dim, match_dim,), + nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0)) + fine_match_dim = 64 + self.fine_matcher = nn.Sequential( + BasicLayer(24+24+2, fine_match_dim,), + BasicLayer(fine_match_dim, fine_match_dim,), + BasicLayer(fine_match_dim, fine_match_dim,), + BasicLayer(fine_match_dim, fine_match_dim,), + nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),) + self.sample_mode = sample_mode + self.sample_thresh = 0.05 + self.symmetric = symmetric + self.exact_softmax = exact_softmax + + @property + def device(self): + return self.fine_matcher[-1].weight.device + + def preprocess_tensor(self, x): + """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. 
""" + H, W = x.shape[-2:] + _H, _W = (H//32) * 32, (W//32) * 32 + rh, rw = H/_H, W/_W + + x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False) + return x, rh, rw + + def forward_single(self, x): + with torch.inference_mode(self.freeze_xfeat or not self.training): + xfeat = self.xfeat[0] + with torch.no_grad(): + x = x.mean(dim=1, keepdim = True) + x = xfeat.norm(x) + + #main backbone + x1 = xfeat.block1(x) + x2 = xfeat.block2(x1 + xfeat.skip1(x)) + x3 = xfeat.block3(x2) + x4 = xfeat.block4(x3) + x5 = xfeat.block5(x4) + x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear') + x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear') + feats = xfeat.block_fusion( x3 + x4 + x5 ) + if self.freeze_xfeat: + return x2.clone(), feats.clone() + return x2, feats + + def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None): + if coords.shape[-1] == 2: + return self._to_pixel_coordinates(coords, H_A, W_A) + + if isinstance(coords, (list, tuple)): + kpts_A, kpts_B = coords[0], coords[1] + else: + kpts_A, kpts_B = coords[...,:2], coords[...,2:] + return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B) + + def _to_pixel_coordinates(self, coords, H, W): + kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1) + return kpts + + def pos_embed(self, corr_volume: torch.Tensor): + B, H1, W1, H0, W0 = corr_volume.shape + grid = torch.stack( + torch.meshgrid( + torch.linspace(-1+1/W1,1-1/W1, W1), + torch.linspace(-1+1/H1,1-1/H1, H1), + indexing = "xy"), + dim = -1).float().to(corr_volume).reshape(H1*W1, 2) + down = 4 + if not self.training and not self.exact_softmax: + grid_lr = torch.stack( + torch.meshgrid( + torch.linspace(-1+down/W1,1-down/W1, W1//down), + torch.linspace(-1+down/H1,1-down/H1, H1//down), + indexing = "xy"), + dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2) + cv = corr_volume + best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1) # B, HW, H, W + P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1) + pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr) + pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2) + #print("hej") + else: + P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W + pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid) + return pos_embeddings + + def visualize_warp(self, warp, certainty, im_A = None, im_B = None, + im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False): + device = warp.device + H,W2,_ = warp.shape + W = W2//2 if symmetric else W2 + if im_A is None: + from PIL import Image + im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB") + if not isinstance(im_A, torch.Tensor): + im_A = im_A.resize((W,H)) + im_B = im_B.resize((W,H)) + x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1) + if symmetric: + x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1) + else: + if symmetric: + x_A = im_A + x_B = im_B + im_A_transfer_rgb = F.grid_sample( + x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False + )[0] + if symmetric: + im_B_transfer_rgb = F.grid_sample( + x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False + )[0] + warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2) + white_im = torch.ones((H,2*W),device=device) + else: + warp_im = im_A_transfer_rgb + white_im 
= torch.ones((H, W), device = device) + vis_im = certainty * warp_im + (1 - certainty) * white_im + if save_path is not None: + from romatch.utils import tensor_to_pil + tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path) + return vis_im + + def corr_volume(self, feat0, feat1): + """ + input: + feat0 -> torch.Tensor(B, C, H, W) + feat1 -> torch.Tensor(B, C, H, W) + return: + corr_volume -> torch.Tensor(B, H, W, H, W) + """ + B, C, H0, W0 = feat0.shape + B, C, H1, W1 = feat1.shape + feat0 = feat0.view(B, C, H0*W0) + feat1 = feat1.view(B, C, H1*W1) + corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16 + return corr_volume + + @torch.inference_mode() + def match_from_path(self, im0_path, im1_path): + device = self.device + im0 = ToTensor()(Image.open(im0_path))[None].to(device) + im1 = ToTensor()(Image.open(im1_path))[None].to(device) + return self.match(im0, im1, batched = False) + + @torch.inference_mode() + def match(self, im0, im1, *args, batched = True): + # stupid + if isinstance(im0, (str, Path)): + return self.match_from_path(im0, im1) + elif isinstance(im0, Image.Image): + batched = False + device = self.device + im0 = ToTensor()(im0)[None].to(device) + im1 = ToTensor()(im1)[None].to(device) + + B,C,H0,W0 = im0.shape + B,C,H1,W1 = im1.shape + self.train(False) + corresps = self.forward({"im_A":im0, "im_B":im1}) + #return 1,1 + flow = F.interpolate( + corresps[4]["flow"], + size = (H0, W0), + mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2) + grid = torch.stack( + torch.meshgrid( + torch.linspace(-1+1/W0,1-1/W0, W0), + torch.linspace(-1+1/H0,1-1/H0, H0), + indexing = "xy"), + dim = -1).float().to(flow.device).expand(B, H0, W0, 2) + + certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False) + warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid() + if batched: + return warp, cert + else: + return warp[0], cert[0] + + def sample( + self, + matches, + certainty, + num=5_000, + ): + H,W,_ = matches.shape + if "threshold" in self.sample_mode: + upper_thresh = self.sample_thresh + certainty = certainty.clone() + certainty[certainty > upper_thresh] = 1 + matches, certainty = ( + matches.reshape(-1, 4), + certainty.reshape(-1), + ) + expansion_factor = 4 if "balanced" in self.sample_mode else 1 + good_samples = torch.multinomial(certainty, + num_samples = min(expansion_factor*num, len(certainty)), + replacement=False) + good_matches, good_certainty = matches[good_samples], certainty[good_samples] + if "balanced" not in self.sample_mode: + return good_matches, good_certainty + use_half = True if matches.device.type == "cuda" else False + down = 1 if matches.device.type == "cuda" else 8 + density = kde(good_matches, std=0.1, half = use_half, down = down) + p = 1 / (density+1) + p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones + balanced_samples = torch.multinomial(p, + num_samples = min(num,len(good_certainty)), + replacement=False) + return good_matches[balanced_samples], good_certainty[balanced_samples] + + + def forward(self, batch): + """ + input: + x -> torch.Tensor(B, C, H, W) grayscale or rgb images + return: + + """ + im0 = batch["im_A"] + im1 = batch["im_B"] + corresps = {} + im0, rh0, rw0 = self.preprocess_tensor(im0) + im1, rh1, rw1 = self.preprocess_tensor(im1) + B, C, H0, W0 = im0.shape + B, C, H1, W1 = im1.shape + to_normalized = torch.tensor((2/W1, 2/H1, 
1)).to(im0.device)[None,:,None,None] + + if im0.shape[-2:] == im1.shape[-2:]: + x = torch.cat([im0, im1], dim=0) + x = self.forward_single(x) + feats_x0_c, feats_x1_c = x[1].chunk(2) + feats_x0_f, feats_x1_f = x[0].chunk(2) + else: + feats_x0_f, feats_x0_c = self.forward_single(im0) + feats_x1_f, feats_x1_c = self.forward_single(im1) + corr_volume = self.corr_volume(feats_x0_c, feats_x1_c) + coarse_warp = self.pos_embed(corr_volume) + coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1) + feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False) + coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1)) + coarse_matches = coarse_matches + coarse_matches_delta * to_normalized + corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]} + coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False) + coarse_matches_up_detach = coarse_matches_up.detach()#note the detach + feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False) + fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1)) + fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized + corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]} + return corresps \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/models/transformer/__init__.py b/imcui/third_party/RoMa/romatch/models/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..983f03ccc51cdbcef6166a160fe50652a81418d7 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/__init__.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from romatch.utils.utils import get_grid, get_autocast_params +from .layers.block import Block +from .layers.attention import MemEffAttention +from .dinov2 import vit_large + +class TransformerDecoder(nn.Module): + def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args, + amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.blocks = blocks + self.to_out = nn.Linear(hidden_dim, out_dim) + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self._scales = [16] + self.is_classifier = is_classifier + self.amp = amp + self.amp_dtype = amp_dtype + self.pos_enc = pos_enc + self.learned_embeddings = learned_embeddings + if self.learned_embeddings: + self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim)))) + + def scales(self): + return self._scales.copy() + + def forward(self, gp_posterior, features, old_stuff, new_scale): + autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(gp_posterior.device, enabled=self.amp, dtype=self.amp_dtype) + with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype): + B,C,H,W = gp_posterior.shape + x = torch.cat((gp_posterior, features), dim = 1) + B,C,H,W = x.shape + grid = get_grid(B, H, W, x.device).reshape(B,H*W,2) + if self.learned_embeddings: + pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', 
align_corners = False).permute(0,2,3,1).reshape(1,H*W,C) + else: + pos_enc = 0 + tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc + z = self.blocks(tokens) + out = self.to_out(z) + out = out.permute(0,2,1).reshape(B, self.out_dim, H, W) + warp, certainty = out[:, :-1], out[:, -1:] + return warp, certainty, None + + diff --git a/imcui/third_party/RoMa/romatch/models/transformer/dinov2.py b/imcui/third_party/RoMa/romatch/models/transformer/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..b556c63096d17239c8603d5fe626c331963099fd --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/dinov2.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with 
other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + for param in self.parameters(): + param.requires_grad = False + + @property + def device(self): + return self.cls_token.device + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = 
[self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + 
mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/__init__.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9b0c94b40967dfdff4f261c127cbd21328c905 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
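For readers skimming this vendored file, a minimal usage sketch of the `vit_large` factory defined above (illustrative, not from the RoMa sources): it builds the frozen DINOv2 backbone and exposes dense patch features through `get_intermediate_layers`. The resolution, `patch_size=14`, and `block_chunks=0` below are assumptions for the example, and the weights are random here since pretrained DINOv2 checkpoints are loaded separately; when xFormers is installed the memory-efficient attention path targets a GPU, otherwise the plain attention fallback is used.

```python
import torch
from romatch.models.transformer.dinov2 import vit_large

device = "cuda" if torch.cuda.is_available() else "cpu"
vit = vit_large(patch_size=14, block_chunks=0).to(device)  # parameters are frozen in __init__

# H and W must be exact multiples of patch_size (PatchEmbed asserts this).
img = torch.randn(1, 3, 14 * 37, 14 * 37, device=device)   # e.g. 518 x 518
with torch.no_grad():
    # one (B, C, h, w) feature map per requested block; here only the last block
    feats = vit.get_intermediate_layers(img, n=1, reshape=True)[0]
print(feats.shape)  # torch.Size([1, 1024, 37, 37]): one 1024-d token per 14x14 patch
```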
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/block.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
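A small shape-contract sketch for the attention modules above (illustrative, not from the RoMa sources): both `Attention` and `MemEffAttention` map `(B, N, C)` token sequences to `(B, N, C)`. `MemEffAttention` is an interface-compatible subclass that routes through xFormers' `memory_efficient_attention` when the library is installed and otherwise falls back to the plain implementation, in which case `attn_bias` must stay `None`. The dimensions below are arbitrary, and the import path assumes the `romatch` package from this diff is importable.

```python
import torch
from romatch.models.transformer.layers.attention import Attention, MemEffAttention

tokens = torch.randn(2, 196, 384)          # B=2 sequences of N=196 tokens with C=384 channels
attn = Attention(dim=384, num_heads=6)     # pure-PyTorch reference path
out = attn(tokens)
assert out.shape == tokens.shape           # (2, 196, 384): shape is preserved

# Same constructor and call signature; uses xFormers kernels when the library is installed.
mem_attn = MemEffAttention(dim=384, num_heads=6)
```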
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = 
torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, 
LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/dino_head.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7212db92a4fd8d4c7230e284e551a0234e9d8623 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/dino_head.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/drop_path.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/mlp.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
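To show how the two small modules above fit together (illustrative sketch, not from the RoMa sources): `Block` in `block.py` wraps each residual branch roughly as `x + drop_path(layer_scale(branch(norm(x))))`, so `LayerScale` learns a per-channel gate initialised near zero and `DropPath` stochastically skips the whole branch per sample during training. The linear layer and sizes below are stand-ins.

```python
import torch
from torch import nn
from romatch.models.transformer.layers.drop_path import DropPath
from romatch.models.transformer.layers.layer_scale import LayerScale

dim = 64
branch = nn.Linear(dim, dim)            # stand-in for the attention or MLP branch
ls = LayerScale(dim, init_values=1e-5)  # learnable per-channel scale, starts close to zero
dp = DropPath(drop_prob=0.1)            # drops the whole branch per sample while training

x = torch.randn(8, 16, dim)             # (B, N, C) tokens
y = x + dp(ls(branch(x)))               # residual update, mirroring Block.forward

dp.eval()                               # in eval mode DropPath is the identity
assert torch.equal(dp(x), x)
```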
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
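A quick shape sketch for `PatchEmbed` above (illustrative, not from the RoMa sources): it is the `(B, C, H, W) -> (B, N, D)` tokenizer used by the DINOv2 backbone, with `N = (H / patch) * (W / patch)`; inputs whose height or width is not an exact multiple of the patch size trip the asserts in `forward`. The sizes below are arbitrary.

```python
import torch
from romatch.models.transformer.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = torch.randn(2, 3, 224, 224)
tokens = embed(img)
print(tokens.shape)  # torch.Size([2, 196, 768]): 14 x 14 patches, one 768-d embedding each
```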
+ +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/imcui/third_party/RoMa/romatch/train/__init__.py b/imcui/third_party/RoMa/romatch/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90269dc0f345a575e0ba21f5afa34202c7e6b433 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/train/__init__.py @@ -0,0 +1 @@ +from .train import train_k_epochs diff --git a/imcui/third_party/RoMa/romatch/train/train.py b/imcui/third_party/RoMa/romatch/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb02ed1e816fd39f174f76ec15bce49ae2a3da8 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/train/train.py @@ -0,0 +1,102 @@ +from tqdm import tqdm +from romatch.utils.utils import to_cuda +import romatch +import torch +import wandb + +def log_param_statistics(named_parameters, norm_type = 2): + named_parameters = list(named_parameters) + grads = [p.grad for n, p in named_parameters if p.grad is not None] + weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None] + names = [n for n,p in named_parameters if p.grad is not None] + param_norm = torch.stack(weight_norms).norm(p=norm_type) + device = grads[0].device + grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]) + nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms) + nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf] + total_grad_norm = torch.norm(grad_norms, norm_type) + if torch.any(nans_or_infs): + print(f"These params have nan or inf grads: {nan_inf_names}") + wandb.log({"grad_norm": total_grad_norm.item()}, step = romatch.GLOBAL_STEP) + wandb.log({"param_norm": param_norm.item()}, step = romatch.GLOBAL_STEP) + +def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs): + optimizer.zero_grad() + out = model(train_batch) + l = objective(out, train_batch) + grad_scaler.scale(l).backward() + grad_scaler.unscale_(optimizer) + log_param_statistics(model.named_parameters()) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm 
be? + grad_scaler.step(optimizer) + grad_scaler.update() + wandb.log({"grad_scale": grad_scaler._scale.item()}, step = romatch.GLOBAL_STEP) + if grad_scaler._scale < 1.: + grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale) + romatch.GLOBAL_STEP = romatch.GLOBAL_STEP + romatch.STEP_SIZE # increment global step + return {"train_out": out, "train_loss": l.item()} + + +def train_k_steps( + n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None, pbar_n_seconds = 1, +): + for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or romatch.RANK > 0, mininterval=pbar_n_seconds): + batch = next(dataloader) + model.train(True) + batch = to_cuda(batch) + train_step( + train_batch=batch, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + grad_scaler=grad_scaler, + n=n, + grad_clip_norm = grad_clip_norm, + ) + if ema_model is not None: + ema_model.update() + if warmup is not None: + with warmup.dampening(): + lr_scheduler.step() + else: + lr_scheduler.step() + [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())] + + +def train_epoch( + dataloader=None, + model=None, + objective=None, + optimizer=None, + lr_scheduler=None, + epoch=None, +): + model.train(True) + print(f"At epoch {epoch}") + for batch in tqdm(dataloader, mininterval=5.0): + batch = to_cuda(batch) + train_step( + train_batch=batch, model=model, objective=objective, optimizer=optimizer + ) + lr_scheduler.step() + return { + "model": model, + "optimizer": optimizer, + "lr_scheduler": lr_scheduler, + "epoch": epoch, + } + + +def train_k_epochs( + start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler +): + for epoch in range(start_epoch, end_epoch + 1): + train_epoch( + dataloader=dataloader, + model=model, + objective=objective, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + ) diff --git a/imcui/third_party/RoMa/romatch/utils/__init__.py b/imcui/third_party/RoMa/romatch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2709f5e586150289085a4e2cbd458bc443fab7f3 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/utils/__init__.py @@ -0,0 +1,16 @@ +from .utils import ( + pose_auc, + get_pose, + compute_relative_pose, + compute_pose_error, + estimate_pose, + estimate_pose_uncalibrated, + rotate_intrinsic, + get_tuple_transform_ops, + get_depth_tuple_transform_ops, + warp_kpts, + numpy_to_pil, + tensor_to_pil, + recover_pose, + signed_left_to_right_epipolar_distance, +) diff --git a/imcui/third_party/RoMa/romatch/utils/kde.py b/imcui/third_party/RoMa/romatch/utils/kde.py new file mode 100644 index 0000000000000000000000000000000000000000..46ed2e5e106bbca93e703f39f3ad3af350666e34 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/utils/kde.py @@ -0,0 +1,13 @@ +import torch + + +def kde(x, std = 0.1, half = True, down = None): + # use a gaussian kernel to estimate density + if half: + x = x.half() # Do it in half precision TODO: remove hardcoding + if down is not None: + scores = (-torch.cdist(x,x[::down])**2/(2*std**2)).exp() + else: + scores = (-torch.cdist(x,x)**2/(2*std**2)).exp() + density = scores.sum(dim=-1) + return density \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/utils/local_correlation.py b/imcui/third_party/RoMa/romatch/utils/local_correlation.py new file mode 100644 index 
0000000000000000000000000000000000000000..fe1322a20bf82d0331159f958241cb87f75f4e21 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/utils/local_correlation.py @@ -0,0 +1,48 @@ +import torch +import torch.nn.functional as F + +def local_correlation( + feature0, + feature1, + local_radius, + padding_mode="zeros", + flow = None, + sample_mode = "bilinear", +): + r = local_radius + K = (2*r+1)**2 + B, c, h, w = feature0.size() + corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype) + if flow is None: + # If flow is None, assume feature0 and feature1 are aligned + coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device), + ), + indexing = 'ij' + ) + coords = torch.stack((coords[1], coords[0]), dim=-1)[ + None + ].expand(B, h, w, 2) + else: + coords = flow.permute(0,2,3,1) # If using flow, sample around flow target. + local_window = torch.meshgrid( + ( + torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device), + torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device), + ), + indexing = 'ij' + ) + local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[ + None + ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2) + for _ in range(B): + with torch.no_grad(): + local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2) + window_feature = F.grid_sample( + feature1[_:_+1], local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, # + ) + window_feature = window_feature.reshape(c,h,w,(2*r+1)**2) + corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1) + return corr diff --git a/imcui/third_party/RoMa/romatch/utils/transforms.py b/imcui/third_party/RoMa/romatch/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ea6476bd816a31df36f7d1b5417853637b65474b --- /dev/null +++ b/imcui/third_party/RoMa/romatch/utils/transforms.py @@ -0,0 +1,118 @@ +from typing import Dict +import numpy as np +import torch +import kornia.augmentation as K +from kornia.geometry.transform import warp_perspective + +# Adapted from Kornia +class GeometricSequential: + def __init__(self, *transforms, align_corners=True) -> None: + self.transforms = transforms + self.align_corners = align_corners + + def __call__(self, x, mode="bilinear"): + b, c, h, w = x.shape + M = torch.eye(3, device=x.device)[None].expand(b, 3, 3) + for t in self.transforms: + if np.random.rand() < t.p: + M = M.matmul( + t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None) + ) + return ( + warp_perspective( + x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners + ), + M, + ) + + def apply_transform(self, x, M, mode="bilinear"): + b, c, h, w = x.shape + return warp_perspective( + x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode + ) + + +class RandomPerspective(K.RandomPerspective): + def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]: + distortion_scale = torch.as_tensor( + self.distortion_scale, device=self._device, dtype=self._dtype + ) + return self.random_perspective_generator( + batch_shape[0], + batch_shape[-2], + batch_shape[-1], + distortion_scale, + self.same_on_batch, + self.device, + self.dtype, + ) + + def random_perspective_generator( + self, + batch_size: int, + height: int, + width: int, + distortion_scale: torch.Tensor, + same_on_batch: bool 
= False, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ) -> Dict[str, torch.Tensor]: + r"""Get parameters for ``perspective`` for a random perspective transform. + + Args: + batch_size (int): the tensor batch size. + height (int) : height of the image. + width (int): width of the image. + distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1. + same_on_batch (bool): apply the same transformation across the batch. Default: False. + device (torch.device): the device on which the random numbers will be generated. Default: cpu. + dtype (torch.dtype): the data type of the generated random numbers. Default: float32. + + Returns: + params Dict[str, torch.Tensor]: parameters to be passed for transformation. + - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2). + - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2). + + Note: + The generated random numbers are not reproducible across different devices and dtypes. + """ + if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1): + raise AssertionError( + f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}." + ) + if not ( + type(height) is int and height > 0 and type(width) is int and width > 0 + ): + raise AssertionError( + f"'height' and 'width' must be integers. Got {height}, {width}." + ) + + start_points: torch.Tensor = torch.tensor( + [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]], + device=distortion_scale.device, + dtype=distortion_scale.dtype, + ).expand(batch_size, -1, -1) + + # generate random offset not larger than half of the image + fx = distortion_scale * width / 2 + fy = distortion_scale * height / 2 + + factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2) + offset = (torch.rand_like(start_points) - 0.5) * 2 + end_points = start_points + factor * offset + + return dict(start_points=start_points, end_points=end_points) + + + +class RandomErasing: + def __init__(self, p = 0., scale = 0.) 
-> None: + self.p = p + self.scale = scale + self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p) + def __call__(self, image, depth): + if self.p > 0: + image = self.random_eraser(image) + depth = self.random_eraser(depth, params=self.random_eraser._params) + return image, depth + \ No newline at end of file diff --git a/imcui/third_party/RoMa/romatch/utils/utils.py b/imcui/third_party/RoMa/romatch/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c522b16b7020779a0ea58b28973f2f609145838 --- /dev/null +++ b/imcui/third_party/RoMa/romatch/utils/utils.py @@ -0,0 +1,654 @@ +import warnings +import numpy as np +import cv2 +import math +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import torch.nn.functional as F +from PIL import Image +import kornia + +def recover_pose(E, kpts0, kpts1, K0, K1, mask): + best_num_inliers = 0 + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + + + +# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py +# --- GEOMETRY --- +def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): + if len(kpts0) < 5: + return None + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf + ) + + ret = None + if E is not None: + best_num_inliers = 0 + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + +def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): + if len(kpts0) < 5: + return None + method = cv2.USAC_ACCURATE + F, mask = cv2.findFundamentalMat( + kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000 + ) + E = K1.T@F@K0 + ret = None + if E is not None: + best_num_inliers = 0 + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + +def unnormalize_coords(x_n,h,w): + x = torch.stack( + (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + return x + + +def rotate_intrinsic(K, n): + base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + rot = np.linalg.matrix_power(base_rot, n) + return rot @ K + + +def rotate_pose_inplane(i_T_w, rot): + rotation_matrices = [ + np.array( + [ + [np.cos(r), -np.sin(r), 0.0, 0.0], + [np.sin(r), np.cos(r), 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] + ] + return np.dot(rotation_matrices[rot], i_T_w) + + +def 
scale_intrinsics(K, scales): + scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0]) + return np.dot(scales, K) + + +def to_homogeneous(points): + return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) + + +def angle_error_mat(R1, R2): + cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 + cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds + return np.rad2deg(np.abs(np.arccos(cos))) + + +def angle_error_vec(v1, v2): + n = np.linalg.norm(v1) * np.linalg.norm(v2) + return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) + + +def compute_pose_error(T_0to1, R, t): + R_gt = T_0to1[:3, :3] + t_gt = T_0to1[:3, 3] + error_t = angle_error_vec(t.squeeze(), t_gt) + error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation + error_R = angle_error_mat(R, R_gt) + return error_t, error_R + + +def pose_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 1) / len(errors) + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.trapz(r, x=e) / t) + return aucs + + +# From Patch2Pix https://github.com/GrumpyZhou/patch2pix +def get_depth_tuple_transform_ops_nearest_exact(resize=None): + ops = [] + if resize: + ops.append(TupleResizeNearestExact(resize)) + return TupleCompose(ops) + +def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR)) + return TupleCompose(ops) + + +def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None): + ops = [] + if resize: + ops.append(TupleResize(resize)) + ops.append(TupleToTensorScaled()) + if normalize: + ops.append( + TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) # Imagenet mean/std + return TupleCompose(ops) + +class ToTensorScaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" + + def __call__(self, im): + if not isinstance(im, torch.Tensor): + im = np.array(im, dtype=np.float32).transpose((2, 0, 1)) + im /= 255.0 + return torch.from_numpy(im) + else: + return im + + def __repr__(self): + return "ToTensorScaled(./255)" + + +class TupleToTensorScaled(object): + def __init__(self): + self.to_tensor = ToTensorScaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorScaled(./255)" + + +class ToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __call__(self, im): + return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1))) + + def __repr__(self): + return "ToTensorUnscaled()" + + +class TupleToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __init__(self): + self.to_tensor = ToTensorUnscaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorUnscaled()" + +class TupleResizeNearestExact: + def __init__(self, size): + self.size = size + def __call__(self, im_tuple): + return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple] + + def __repr__(self): + return "TupleResizeNearestExact(size={})".format(self.size) + + +class 
TupleResize(object): + def __init__(self, size, mode=InterpolationMode.BICUBIC): + self.size = size + self.resize = transforms.Resize(size, mode) + def __call__(self, im_tuple): + return [self.resize(im) for im in im_tuple] + + def __repr__(self): + return "TupleResize(size={})".format(self.size) + +class Normalize: + def __call__(self,im): + mean = im.mean(dim=(1,2), keepdims=True) + std = im.std(dim=(1,2), keepdims=True) + return (im-mean)/std + + +class TupleNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + self.normalize = transforms.Normalize(mean=mean, std=std) + + def __call__(self, im_tuple): + c,h,w = im_tuple[0].shape + if c > 3: + warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb") + return [self.normalize(im[:3]) for im in im_tuple] + + def __repr__(self): + return "TupleNormalize(mean={}, std={})".format(self.mean, self.std) + + +class TupleCompose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, im_tuple): + for t in self.transforms: + im_tuple = t(im_tuple) + return im_tuple + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + +@torch.no_grad() +def cls_to_flow(cls, deterministic_sampling = True): + B,C,H,W = cls.shape + device = cls.device + res = round(math.sqrt(C)) + G = torch.meshgrid( + *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)], + indexing = 'ij' + ) + G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) + if deterministic_sampling: + sampled_cls = cls.max(dim=1).indices + else: + sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W) + flow = G[sampled_cls] + return flow + +@torch.no_grad() +def cls_to_flow_refine(cls): + B,C,H,W = cls.shape + device = cls.device + res = round(math.sqrt(C)) + G = torch.meshgrid( + *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)], + indexing = 'ij' + ) + G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) + # FIXME: below softmax line causes mps to bug, don't know why. 
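+ # Note: exp(log_softmax(x)) equals softmax(x), so the MPS branch below is a numerically equivalent workaround for the backend issue flagged in the FIXME above.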
+ if device.type == 'mps': + cls = cls.log_softmax(dim=1).exp() + else: + cls = cls.softmax(dim=1) + mode = cls.max(dim=1).indices + + index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long() + neighbours = torch.gather(cls, dim = 1, index = index)[...,None] + flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]] + tot_prob = neighbours.sum(dim=1) + flow = flow / tot_prob + return flow + + +def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None): + + if H is None: + B,H,W = depth1.shape + else: + B = depth1.shape[0] + with torch.no_grad(): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=depth1.device + ) + for n in (B, H, W) + ], + indexing = 'ij' + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) + mask, x2 = warp_kpts( + x1_n.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + depth_interpolation_mode = depth_interpolation_mode, + relative_depth_error_threshold = relative_depth_error_threshold, + ) + prob = mask.float().reshape(B, H, W) + x2 = x2.reshape(B, H, W, 2) + return x2, prob + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05): + """Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here + Args: + kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1) + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + ( + n, + h, + w, + ) = depth0.shape + if depth_interpolation_mode == "combined": + # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation + if smooth_mask: + raise NotImplementedError("Combined bilinear and NN warp not implemented") + valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "bilinear", + relative_depth_error_threshold = relative_depth_error_threshold) + valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "nearest-exact", + relative_depth_error_threshold = relative_depth_error_threshold) + nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) + warp = warp_bilinear.clone() + warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid] + valid = valid_bilinear | valid_nearest + return valid, warp + + + kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[ + :, 0, :, 0 + ] + kpts0 = torch.stack( + (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 
2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + # Sample depth, get calculable_mask on depth != 0 + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) + kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + kpts0_cam = kpts0_n + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) + w_kpts0 = torch.stack( + (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 + ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] + # w_kpts0[~covisible_mask, :] = -5 # xd + + w_kpts0_depth = F.grid_sample( + depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False + )[:, 0, :, 0] + + relative_depth_error = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() + if not smooth_mask: + consistent_mask = relative_depth_error < relative_depth_error_threshold + else: + consistent_mask = (-relative_depth_error/smooth_mask).exp() + valid_mask = nonzero_mask * covisible_mask * consistent_mask + if return_relative_depth_error: + return relative_depth_error, w_kpts0 + else: + return valid_mask, w_kpts0 + +imagenet_mean = torch.tensor([0.485, 0.456, 0.406]) +imagenet_std = torch.tensor([0.229, 0.224, 0.225]) + + +def numpy_to_pil(x: np.ndarray): + """ + Args: + x: Assumed to be of shape (h,w,c) + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if x.max() <= 1.01: + x *= 255 + x = x.astype(np.uint8) + return Image.fromarray(x) + + +def tensor_to_pil(x, unnormalize=False): + if unnormalize: + x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device)) + x = x.detach().permute(1, 2, 0).cpu().numpy() + x = np.clip(x, 0.0, 1.0) + return numpy_to_pil(x) + + +def to_cuda(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cuda() + return batch + + +def to_cpu(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cpu() + return batch + + +def get_pose(calib): + w, h = np.array(calib["imsize"])[0] + return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w + + +def compute_relative_pose(R1, t1, R2, t2): + rots = R2 @ (R1.T) + trans = -rots @ t1 + t2 + return rots, trans + +@torch.no_grad() +def reset_opt(opt): + for group in opt.param_groups: + for p in group['params']: + if p.requires_grad: + state = opt.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + +def flow_to_pixel_coords(flow, h1, w1): + flow = ( + torch.stack( + ( + w1 * (flow[..., 0] + 1) / 2, + h1 * (flow[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + return flow + +to_pixel_coords = flow_to_pixel_coords # just an alias + +def flow_to_normalized_coords(flow, h1, w1): + flow = ( + 
torch.stack( + ( + 2 * (flow[..., 0]) / w1 - 1, + 2 * (flow[..., 1]) / h1 - 1, + ), + axis=-1, + ) + ) + return flow + +to_normalized_coords = flow_to_normalized_coords # just an alias + +def warp_to_pixel_coords(warp, h1, w1, h2, w2): + warp1 = warp[..., :2] + warp1 = ( + torch.stack( + ( + w1 * (warp1[..., 0] + 1) / 2, + h1 * (warp1[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + warp2 = warp[..., 2:] + warp2 = ( + torch.stack( + ( + w2 * (warp2[..., 0] + 1) / 2, + h2 * (warp2[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + return torch.cat((warp1,warp2), dim=-1) + + + +def signed_point_line_distance(point, line, eps: float = 1e-9): + r"""Return the distance from points to lines. + + Args: + point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`. + line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`. + eps: Small constant for safe sqrt. + + Returns: + the computed distance with shape :math:`(*, N)`. + """ + + if not point.shape[-1] in (2, 3): + raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}") + + if not line.shape[-1] == 3: + raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}") + + numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2]) + denominator = line[..., :2].norm(dim=-1) + + return numerator / (denominator + eps) + + +def signed_left_to_right_epipolar_distance(pts1, pts2, Fm): + r"""Return one-sided epipolar distance for correspondences given the fundamental matrix. + + This method measures the distance from points in the right images to the epilines + of the corresponding points in the left images as they reflect in the right images. + + Args: + pts1: correspondences from the left images with shape + :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically. + pts2: correspondences from the right images with shape + :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically. + Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to + avoid ambiguity with torch.nn.functional. + + Returns: + the computed Symmetrical distance with shape :math:`(*, N)`. + """ + import kornia + if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3): + raise ValueError(f"Fm must be a (*, 3, 3) tensor. 
Got {Fm.shape}") + + if pts1.shape[-1] == 2: + pts1 = kornia.geometry.convert_points_to_homogeneous(pts1) + + F_t = Fm.transpose(dim0=-2, dim1=-1) + line1_in_2 = pts1 @ F_t + + return signed_point_line_distance(pts2, line1_in_2) + +def get_grid(b, h, w, device): + grid = torch.meshgrid( + *[ + torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) + for n in (b, h, w) + ], + indexing = 'ij' + ) + grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2) + return grid + + +def get_autocast_params(device=None, enabled=False, dtype=None): + if device is None: + autocast_device = "cuda" if torch.cuda.is_available() else "cpu" + else: + #strip :X from device + autocast_device = str(device).split(":")[0] + if 'cuda' in str(device): + out_dtype = dtype + enabled = True + else: + out_dtype = torch.bfloat16 + enabled = False + # mps is not supported + autocast_device = "cpu" + return autocast_device, enabled, out_dtype \ No newline at end of file diff --git a/imcui/third_party/RoMa/setup.py b/imcui/third_party/RoMa/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec18f3bbb71b85d943fdfeed3ed5c47033aebbc --- /dev/null +++ b/imcui/third_party/RoMa/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup, find_packages + +setup( + name="romatch", + packages=find_packages(include=("romatch*",)), + version="0.0.1", + author="Johan Edstedt", + install_requires=open("requirements.txt", "r").read().split("\n"), +) diff --git a/imcui/third_party/RoRD/demo/__init__.py b/imcui/third_party/RoRD/demo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/RoRD/demo/register.py b/imcui/third_party/RoRD/demo/register.py new file mode 100644 index 0000000000000000000000000000000000000000..ba626920887639c6c95f869231d8080de64c2ee8 --- /dev/null +++ b/imcui/third_party/RoRD/demo/register.py @@ -0,0 +1,265 @@ +import numpy as np +import copy +import argparse +import os, sys +import open3d as o3d +from sys import argv +from PIL import Image +import math +import cv2 +import torch + +sys.path.append("../") +from lib.extractMatchTop import getPerspKeypoints, getPerspKeypointsEnsemble, siftMatching +from lib.model_test import D2Net + +#### Cuda #### +use_cuda = torch.cuda.is_available() +device = torch.device('cuda:0' if use_cuda else 'cpu') + +#### Argument Parsing #### +parser = argparse.ArgumentParser(description='RoRD ICP evaluation') + +parser.add_argument( + '--rgb1', type=str, default = 'rgb/rgb2_1.jpg', + help='path to the rgb image1' +) +parser.add_argument( + '--rgb2', type=str, default = 'rgb/rgb2_2.jpg', + help='path to the rgb image2' +) + +parser.add_argument( + '--depth1', type=str, default = 'depth/depth2_1.png', + help='path to the depth image1' +) + +parser.add_argument( + '--depth2', type=str, default = 'depth/depth2_2.png', + help='path to the depth image2' +) + +parser.add_argument( + '--model_rord', type=str, default = '../models/rord.pth', + help='path to the RoRD model for evaluation' +) + +parser.add_argument( + '--model_d2', type=str, + help='path to the vanilla D2-Net model for evaluation' +) + +parser.add_argument( + '--model_ens', action='store_true', + help='ensemble model of RoRD + D2-Net' +) + +parser.add_argument( + '--sift', action='store_true', + help='Sift' +) + +parser.add_argument( + '--camera_file', type=str, default='../configs/camera.txt', + help='path to the camera intrinsics file. In order: focal_x, focal_y, center_x, center_y, scaling_factor.' 
+) + +parser.add_argument( + '--viz3d', action='store_true', + help='visualize the pointcloud registrations' +) + +args = parser.parse_args() + +if args.model_ens: # Change default paths accordingly for ensemble + model1_ens = '../../models/rord.pth' + model2_ens = '../../models/d2net.pth' + +def draw_registration_result(source, target, transformation): + source_temp = copy.deepcopy(source) + target_temp = copy.deepcopy(target) + source_temp.transform(transformation) + + target_temp += source_temp + # print("Saved registered PointCloud.") + # o3d.io.write_point_cloud("registered.pcd", target_temp) + + trgSph.append(source_temp); trgSph.append(target_temp) + axis1 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + axis2 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + axis2.transform(transformation) + trgSph.append(axis1); trgSph.append(axis2) + print("Showing registered PointCloud.") + o3d.visualization.draw_geometries(trgSph) + + +def readDepth(depthFile): + depth = Image.open(depthFile) + if depth.mode != "I": + raise Exception("Depth image is not in intensity format") + + return np.asarray(depth) + +def readCamera(camera): + with open (camera, "rt") as file: + contents = file.read().split() + + focalX = float(contents[0]) + focalY = float(contents[1]) + centerX = float(contents[2]) + centerY = float(contents[3]) + scalingFactor = float(contents[4]) + + return focalX, focalY, centerX, centerY, scalingFactor + +def getPointCloud(rgbFile, depthFile, pts): + thresh = 15.0 + + depth = readDepth(depthFile) + rgb = Image.open(rgbFile) + + points = [] + colors = [] + + corIdx = [-1]*len(pts) + corPts = [None]*len(pts) + ptIdx = 0 + + for v in range(depth.shape[0]): + for u in range(depth.shape[1]): + Z = depth[v, u] / scalingFactor + if Z==0: continue + if (Z > thresh): continue + + X = (u - centerX) * Z / focalX + Y = (v - centerY) * Z / focalY + + points.append((X, Y, Z)) + colors.append(rgb.getpixel((u, v))) + + if((u, v) in pts): + # print("Point found.") + index = pts.index((u, v)) + corIdx[index] = ptIdx + corPts[index] = (X, Y, Z) + + ptIdx = ptIdx+1 + + points = np.asarray(points) + colors = np.asarray(colors) + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + pcd.colors = o3d.utility.Vector3dVector(colors/255) + + return pcd, corIdx, corPts + + +def convertPts(A): + X = A[0]; Y = A[1] + + x = []; y = [] + + for i in range(len(X)): + x.append(int(float(X[i]))) + + for i in range(len(Y)): + y.append(int(float(Y[i]))) + + pts = [] + for i in range(len(x)): + pts.append((x[i], y[i])) + + return pts + + +def getSphere(pts): + sphs = [] + + for ele in pts: + if(ele is not None): + sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.03) + sphere.paint_uniform_color([0.9, 0.2, 0]) + + trans = np.identity(4) + trans[0, 3] = ele[0] + trans[1, 3] = ele[1] + trans[2, 3] = ele[2] + + sphere.transform(trans) + sphs.append(sphere) + + return sphs + + +def get3dCor(src, trg): + corr = [] + + for sId, tId in zip(src, trg): + if(sId != -1 and tId != -1): + corr.append((sId, tId)) + + corr = np.asarray(corr) + + return corr + +if __name__ == "__main__": + + focalX, focalY, centerX, centerY, scalingFactor = readCamera(args.camera_file) + + rgb_name_src = os.path.basename(args.rgb1) + H_name_src = os.path.splitext(rgb_name_src)[0] + '.npy' + srcH = os.path.join(os.path.dirname(args.rgb1), H_name_src) + rgb_name_trg = os.path.basename(args.rgb2) + H_name_trg = os.path.splitext(rgb_name_trg)[0] + '.npy' 
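+    # Note: the homography file for each view is expected next to its RGB image, with
+    # the same basename and a .npy extension (e.g. rgb/rgb2_2.jpg -> rgb/rgb2_2.npy).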
+ trgH = os.path.join(os.path.dirname(args.rgb2), H_name_trg) + + use_cuda = torch.cuda.is_available() + device = torch.device('cuda:0' if use_cuda else 'cpu') + model1 = D2Net(model_file=args.model_d2) + model1 = model1.to(device) + model2 = D2Net(model_file=args.model_rord) + model2 = model2.to(device) + + if args.model_rord: + srcPts, trgPts, matchImg, matchImgOrtho = getPerspKeypoints(args.rgb1, args.rgb2, srcH, trgH, model2, device) + elif args.model_d2: + srcPts, trgPts, matchImg, matchImgOrtho = getPerspKeypoints(args.rgb1, args.rgb2, srcH, trgH, model1, device) + elif args.model_ens: + model1 = D2Net(model_file=model1_ens) + model1 = model1.to(device) + model2 = D2Net(model_file=model2_ens) + model2 = model2.to(device) + srcPts, trgPts, matchImg, matchImgOrtho = getPerspKeypointsEnsemble(model1, model2, args.rgb1, args.rgb2, srcH, trgH, device) + elif args.sift: + srcPts, trgPts, matchImg, matchImgOrtho = siftMatching(args.rgb1, args.rgb2, srcH, trgH, device) + + #### Visualization #### + print("\nShowing matches in perspective and orthographic view. Press q\n") + cv2.imshow('Orthographic view', matchImgOrtho) + cv2.imshow('Perspective view', matchImg) + cv2.waitKey() + + srcPts = convertPts(srcPts) + trgPts = convertPts(trgPts) + + srcCld, srcIdx, srcCor = getPointCloud(args.rgb1, args.depth1, srcPts) + trgCld, trgIdx, trgCor = getPointCloud(args.rgb2, args.depth2, trgPts) + + srcSph = getSphere(srcCor) + trgSph = getSphere(trgCor) + axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + srcSph.append(srcCld); srcSph.append(axis) + trgSph.append(trgCld); trgSph.append(axis) + + corr = get3dCor(srcIdx, trgIdx) + + p2p = o3d.registration.TransformationEstimationPointToPoint() + trans_init = p2p.compute_transformation(srcCld, trgCld, o3d.utility.Vector2iVector(corr)) + print("Transformation matrix: \n", trans_init) + + if args.viz3d: + # o3d.visualization.draw_geometries(srcSph) + # o3d.visualization.draw_geometries(trgSph) + + draw_registration_result(srcCld, trgCld, trans_init) diff --git a/imcui/third_party/RoRD/evaluation/DiverseView/evalRT.py b/imcui/third_party/RoRD/evaluation/DiverseView/evalRT.py new file mode 100644 index 0000000000000000000000000000000000000000..d0be9aef58e408668112e0587a03b2b33012a342 --- /dev/null +++ b/imcui/third_party/RoRD/evaluation/DiverseView/evalRT.py @@ -0,0 +1,307 @@ +import numpy as np +import argparse +import copy +import os, sys +import open3d as o3d +from sys import argv, exit +from PIL import Image +import math +from tqdm import tqdm +import cv2 + + +sys.path.append("../../") + +from lib.extractMatchTop import getPerspKeypoints, getPerspKeypointsEnsemble, siftMatching +import pandas as pd + + +import torch +from lib.model_test import D2Net + +#### Cuda #### +use_cuda = torch.cuda.is_available() +device = torch.device('cuda:0' if use_cuda else 'cpu') + +#### Argument Parsing #### +parser = argparse.ArgumentParser(description='RoRD ICP evaluation on a DiverseView dataset sequence.') + +parser.add_argument('--dataset', type=str, default='/scratch/udit/realsense/RoRD_data/preprocessed/', + help='path to the dataset folder') + +parser.add_argument('--sequence', type=str, default='data1') + +parser.add_argument( + '--output_dir', type=str, default='out', + help='output directory for RT estimates' +) + +parser.add_argument( + '--model_rord', type=str, help='path to the RoRD model for evaluation' +) + +parser.add_argument( + '--model_d2', type=str, help='path to the vanilla D2-Net model for evaluation' +) + 
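+# Note: only one of --model_rord / --model_d2 / --model_ens / --sift is used per run;
+# the dispatch under __main__ below picks the first of these that is set, in that order.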
+parser.add_argument( + '--model_ens', action='store_true', + help='ensemble model of RoRD + D2-Net' +) + +parser.add_argument( + '--sift', action='store_true', + help='Sift' +) + +parser.add_argument( + '--viz3d', action='store_true', + help='visualize the pointcloud registrations' +) + +parser.add_argument( + '--log_interval', type=int, default=9, + help='Matched image logging interval' +) + +parser.add_argument( + '--camera_file', type=str, default='../../configs/camera.txt', + help='path to the camera intrinsics file. In order: focal_x, focal_y, center_x, center_y, scaling_factor.' +) + +parser.add_argument( + '--persp', action='store_true', default=False, + help='Feature matching on perspective images.' +) + +parser.set_defaults(fp16=False) +args = parser.parse_args() + + +if args.model_ens: # Change default paths accordingly for ensemble + model1_ens = '../../models/rord.pth' + model2_ens = '../../models/d2net.pth' + +def draw_registration_result(source, target, transformation): + source_temp = copy.deepcopy(source) + target_temp = copy.deepcopy(target) + source_temp.transform(transformation) + trgSph.append(source_temp); trgSph.append(target_temp) + axis1 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + axis2 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + axis2.transform(transformation) + trgSph.append(axis1); trgSph.append(axis2) + o3d.visualization.draw_geometries(trgSph) + +def readDepth(depthFile): + depth = Image.open(depthFile) + if depth.mode != "I": + raise Exception("Depth image is not in intensity format") + + return np.asarray(depth) + +def readCamera(camera): + with open (camera, "rt") as file: + contents = file.read().split() + + focalX = float(contents[0]) + focalY = float(contents[1]) + centerX = float(contents[2]) + centerY = float(contents[3]) + scalingFactor = float(contents[4]) + + return focalX, focalY, centerX, centerY, scalingFactor + + +def getPointCloud(rgbFile, depthFile, pts): + thresh = 15.0 + + depth = readDepth(depthFile) + rgb = Image.open(rgbFile) + + points = [] + colors = [] + + corIdx = [-1]*len(pts) + corPts = [None]*len(pts) + ptIdx = 0 + + for v in range(depth.shape[0]): + for u in range(depth.shape[1]): + Z = depth[v, u] / scalingFactor + if Z==0: continue + if (Z > thresh): continue + + X = (u - centerX) * Z / focalX + Y = (v - centerY) * Z / focalY + + points.append((X, Y, Z)) + colors.append(rgb.getpixel((u, v))) + + if((u, v) in pts): + index = pts.index((u, v)) + corIdx[index] = ptIdx + corPts[index] = (X, Y, Z) + + ptIdx = ptIdx+1 + + points = np.asarray(points) + colors = np.asarray(colors) + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + pcd.colors = o3d.utility.Vector3dVector(colors/255) + + return pcd, corIdx, corPts + + +def convertPts(A): + X = A[0]; Y = A[1] + + x = []; y = [] + + for i in range(len(X)): + x.append(int(float(X[i]))) + + for i in range(len(Y)): + y.append(int(float(Y[i]))) + + pts = [] + for i in range(len(x)): + pts.append((x[i], y[i])) + + return pts + + +def getSphere(pts): + sphs = [] + + for element in pts: + if(element is not None): + sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.03) + sphere.paint_uniform_color([0.9, 0.2, 0]) + + trans = np.identity(4) + trans[0, 3] = element[0] + trans[1, 3] = element[1] + trans[2, 3] = element[2] + + sphere.transform(trans) + sphs.append(sphere) + + return sphs + + +def get3dCor(src, trg): + corr = [] + + for sId, tId in zip(src, trg): + if(sId != -1 and tId 
!= -1): + corr.append((sId, tId)) + + corr = np.asarray(corr) + + return corr + +if __name__ == "__main__": + camera_file = args.camera_file + rgb_csv = args.dataset + args.sequence + '/rtImagesRgb.csv' + depth_csv = args.dataset + args.sequence + '/rtImagesDepth.csv' + + os.makedirs(os.path.join(args.output_dir, 'vis'), exist_ok=True) + dir_name = args.output_dir + os.makedirs(args.output_dir, exist_ok=True) + + focalX, focalY, centerX, centerY, scalingFactor = readCamera(camera_file) + + df_rgb = pd.read_csv(rgb_csv) + df_dep = pd.read_csv(depth_csv) + + model1 = D2Net(model_file=args.model_d2).to(device) + model2 = D2Net(model_file=args.model_rord).to(device) + + queryId = 0 + for im_q, dep_q in tqdm(zip(df_rgb['query'], df_dep['query']), total=df_rgb.shape[0]): + filter_list = [] + dbId = 0 + for im_d, dep_d in tqdm(zip(df_rgb.iteritems(), df_dep.iteritems()), total=df_rgb.shape[1]): + if im_d[0] == 'query': + continue + rgb_name_src = os.path.basename(im_q) + H_name_src = os.path.splitext(rgb_name_src)[0] + '.npy' + srcH = args.dataset + args.sequence + '/rgb/' + H_name_src + rgb_name_trg = os.path.basename(im_d[1][1]) + H_name_trg = os.path.splitext(rgb_name_trg)[0] + '.npy' + trgH = args.dataset + args.sequence + '/rgb/' + H_name_trg + + srcImg = srcH.replace('.npy', '.jpg') + trgImg = trgH.replace('.npy', '.jpg') + + if args.model_rord: + if args.persp: + srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model2, device=device) + else: + srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model2, device) + + elif args.model_d2: + if args.persp: + srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model2, device=device) + else: + srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model1, device) + + elif args.model_ens: + model1 = D2Net(model_file=model1_ens) + model1 = model1.to(device) + model2 = D2Net(model_file=model2_ens) + model2 = model2.to(device) + srcPts, trgPts, matchImg = getPerspKeypointsEnsemble(model1, model2, srcImg, trgImg, srcH, trgH, device) + + elif args.sift: + if args.persp: + srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, HFile1=None, HFile2=None, device=device) + else: + srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, srcH, trgH, device) + + if(isinstance(srcPts, list) == True): + print(np.identity(4)) + filter_list.append(np.identity(4)) + continue + + + srcPts = convertPts(srcPts) + trgPts = convertPts(trgPts) + + depth_name_src = os.path.dirname(os.path.dirname(args.dataset)) + '/' + dep_q + depth_name_trg = os.path.dirname(os.path.dirname(args.dataset)) + '/' + dep_d[1][1] + + srcCld, srcIdx, srcCor = getPointCloud(srcImg, depth_name_src, srcPts) + trgCld, trgIdx, trgCor = getPointCloud(trgImg, depth_name_trg, trgPts) + + srcSph = getSphere(srcCor) + trgSph = getSphere(trgCor) + axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0]) + srcSph.append(srcCld); srcSph.append(axis) + trgSph.append(trgCld); trgSph.append(axis) + + corr = get3dCor(srcIdx, trgIdx) + + p2p = o3d.pipelines.registration.TransformationEstimationPointToPoint() + trans_init = p2p.compute_transformation(srcCld, trgCld, o3d.utility.Vector2iVector(corr)) + # print(trans_init) + filter_list.append(trans_init) + + if args.viz3d: + o3d.visualization.draw_geometries(srcSph) + o3d.visualization.draw_geometries(trgSph) + draw_registration_result(srcCld, trgCld, trans_init) + + if(dbId%args.log_interval == 
0): + cv2.imwrite(os.path.join(args.output_dir, 'vis') + "/matchImg.%02d.%02d.jpg"%(queryId, dbId//args.log_interval), matchImg) + dbId += 1 + + + RT = np.stack(filter_list).transpose(1,2,0) + + np.save(os.path.join(dir_name, str(queryId) + '.npy'), RT) + queryId += 1 + print('-----check-------', RT.shape) diff --git a/imcui/third_party/RoRD/extractMatch.py b/imcui/third_party/RoRD/extractMatch.py new file mode 100644 index 0000000000000000000000000000000000000000..b413dde1334b52fef294fb0c10c2acfe5b901534 --- /dev/null +++ b/imcui/third_party/RoRD/extractMatch.py @@ -0,0 +1,195 @@ +import argparse + +import numpy as np + +import imageio + +import torch + +from tqdm import tqdm +import time +import scipy +import scipy.io +import scipy.misc +import os +import sys + +from lib.model_test import D2Net +from lib.utils import preprocess_image +from lib.pyramid import process_multiscale + +import cv2 +import matplotlib.pyplot as plt +from PIL import Image +from skimage.feature import match_descriptors +from skimage.measure import ransac +from skimage.transform import ProjectiveTransform, AffineTransform +import pydegensac + + +parser = argparse.ArgumentParser(description='Feature extraction script') +parser.add_argument('imgs', type=str, nargs=2) +parser.add_argument( + '--preprocessing', type=str, default='caffe', + help='image preprocessing (caffe or torch)' +) + +parser.add_argument( + '--model_file', type=str, + help='path to the full model' +) + +parser.add_argument( + '--no-relu', dest='use_relu', action='store_false', + help='remove ReLU after the dense feature extraction module' +) +parser.set_defaults(use_relu=True) + +parser.add_argument( + '--sift', dest='use_sift', action='store_true', + help='Show sift matching as well' +) +parser.set_defaults(use_sift=False) + + +def extract(image, args, model, device): + if len(image.shape) == 2: + image = image[:, :, np.newaxis] + image = np.repeat(image, 3, -1) + + input_image = preprocess_image( + image, + preprocessing=args.preprocessing + ) + with torch.no_grad(): + keypoints, scores, descriptors = process_multiscale( + torch.tensor( + input_image[np.newaxis, :, :, :].astype(np.float32), + device=device + ), + model, + scales=[1] + ) + + keypoints = keypoints[:, [1, 0, 2]] + + feat = {} + feat['keypoints'] = keypoints + feat['scores'] = scores + feat['descriptors'] = descriptors + + return feat + + +def rordMatching(image1, image2, feat1, feat2, matcher="BF"): + if(matcher == "BF"): + + t0 = time.time() + bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) + matches = bf.match(feat1['descriptors'], feat2['descriptors']) + matches = sorted(matches, key=lambda x:x.distance) + t1 = time.time() + print("Time to extract matches: ", t1-t0) + + print("Number of raw matches:", len(matches)) + + match1 = [m.queryIdx for m in matches] + match2 = [m.trainIdx for m in matches] + + keypoints_left = feat1['keypoints'][match1, : 2] + keypoints_right = feat2['keypoints'][match2, : 2] + + np.random.seed(0) + + t0 = time.time() + + H, inliers = pydegensac.findHomography(keypoints_left, keypoints_right, 10.0, 0.99, 10000) + + t1 = time.time() + print("Time for ransac: ", t1-t0) + + n_inliers = np.sum(inliers) + print('Number of inliers: %d.' 
% n_inliers) + + inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in keypoints_left[inliers]] + inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in keypoints_right[inliers]] + placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)] + + draw_params = dict(matchColor = (0,255,0), + singlePointColor = (255,0,0), + # matchesMask = matchesMask, + flags = 0) + image3 = cv2.drawMatches(image1, inlier_keypoints_left, image2, inlier_keypoints_right, placeholder_matches, None, **draw_params) + + plt.figure(figsize=(20, 20)) + plt.imshow(image3) + plt.axis('off') + plt.show() + + +def siftMatching(img1, img2): + img1 = np.array(cv2.cvtColor(np.array(img1), cv2.COLOR_BGR2RGB)) + img2 = np.array(cv2.cvtColor(np.array(img2), cv2.COLOR_BGR2RGB)) + + # surf = cv2.xfeatures2d.SURF_create(100) + surf = cv2.xfeatures2d.SIFT_create() + + kp1, des1 = surf.detectAndCompute(img1, None) + kp2, des2 = surf.detectAndCompute(img2, None) + + FLANN_INDEX_KDTREE = 0 + index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5) + search_params = dict(checks = 50) + flann = cv2.FlannBasedMatcher(index_params, search_params) + matches = flann.knnMatch(des1,des2,k=2) + good = [] + for m, n in matches: + if m.distance < 0.7*n.distance: + good.append(m) + + src_pts = np.float32([ kp1[m.queryIdx].pt for m in good ]).reshape(-1, 2) + dst_pts = np.float32([ kp2[m.trainIdx].pt for m in good ]).reshape(-1, 2) + + model, inliers = pydegensac.findHomography(src_pts, dst_pts, 10.0, 0.99, 10000) + + n_inliers = np.sum(inliers) + print('Number of inliers: %d.' % n_inliers) + + inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in src_pts[inliers]] + inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in dst_pts[inliers]] + placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)] + image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None) + + cv2.imshow('Matches', image3) + cv2.waitKey(0) + + src_pts = np.float32([ inlier_keypoints_left[m.queryIdx].pt for m in placeholder_matches ]).reshape(-1, 2) + dst_pts = np.float32([ inlier_keypoints_right[m.trainIdx].pt for m in placeholder_matches ]).reshape(-1, 2) + + return src_pts, dst_pts + + +if __name__ == '__main__': + use_cuda = torch.cuda.is_available() + device = torch.device("cuda:0" if use_cuda else "cpu") + args = parser.parse_args() + + model = D2Net( + model_file=args.model_file, + use_relu=args.use_relu, + use_cuda=use_cuda + ) + + image1 = np.array(Image.open(args.imgs[0])) + image2 = np.array(Image.open(args.imgs[1])) + + print('--\nRoRD\n--') + feat1 = extract(image1, args, model, device) + feat2 = extract(image2, args, model, device) + print("Features extracted.") + + rordMatching(image1, image2, feat1, feat2, matcher="BF") + + if(args.use_sift): + print('--\nSIFT\n--') + siftMatching(image1, image2) diff --git a/imcui/third_party/RoRD/scripts/getRTImages.py b/imcui/third_party/RoRD/scripts/getRTImages.py new file mode 100644 index 0000000000000000000000000000000000000000..6972c349c0dc2c046c67e194ba79ea6d7da725bd --- /dev/null +++ b/imcui/third_party/RoRD/scripts/getRTImages.py @@ -0,0 +1,54 @@ +import os +import re +from sys import argv, exit +import csv +import numpy as np + + +def natural_sort(l): + convert = lambda text: int(text) if text.isdigit() else text.lower() + alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] + return sorted(l, key = alphanum_key) + + +def 
getPairs(imgs): + queryIdxs = np.linspace(start=0, stop=len(imgs)-1, num=10).astype(int).tolist() + databaseIdxs = np.linspace(start=10, stop=len(imgs)-10, num=100).astype(int).tolist() + + queryImgs = [imgs[idx] for idx in queryIdxs] + databaseImgs = [imgs[idx] for idx in databaseIdxs] + + return queryImgs, databaseImgs + + +def writeCSV(qImgs, dImgs): + with open('rtImagesDepth.csv', 'w', newline='') as file: + writer = csv.writer(file) + + title = [] + title.append('query') + + for i in range(len(dImgs)): + title.append('data' + str(i+1)) + + writer.writerow(title) + + for qImg in qImgs: + row = [] + row.append(qImg) + + for dImg in dImgs: + row.append(dImg) + + writer.writerow(row) + + +if __name__ == '__main__': + rgbDir = argv[1] + rgbImgs = natural_sort([file for file in os.listdir(rgbDir) if (file.find("jpg") != -1 or file.find("png") != -1)]) + + rgbImgs = [os.path.join(rgbDir, img) for img in rgbImgs] + + queryImgs, databaseImgs = getPairs(rgbImgs) + + writeCSV(queryImgs, databaseImgs) \ No newline at end of file diff --git a/imcui/third_party/RoRD/scripts/metricRT.py b/imcui/third_party/RoRD/scripts/metricRT.py new file mode 100644 index 0000000000000000000000000000000000000000..99a323b269e79d4c8f179bae3227224beff57f6c --- /dev/null +++ b/imcui/third_party/RoRD/scripts/metricRT.py @@ -0,0 +1,63 @@ +import numpy as np +import re +import os +import argparse + + +def natural_sort(l): + convert = lambda text: int(text) if text.isdigit() else text.lower() + alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] + + return sorted(l, key = alphanum_key) + + +def angular_distance_np(R_hat, R): + # measure the angular distance between two rotation matrice + # R1,R2: [n, 3, 3] + if R_hat.shape == (3,3): + R_hat = R_hat[np.newaxis,:] + if R.shape == (3,3): + R = R[np.newaxis,:] + n = R.shape[0] + trace_idx = [0,4,8] + trace = np.matmul(R_hat, R.transpose(0,2,1)).reshape(n,-1)[:,trace_idx].sum(1) + metric = np.arccos(((trace - 1)/2).clip(-1,1)) / np.pi * 180.0 + + return metric + + +def main(): + parser = argparse.ArgumentParser(description='Rotation and translation metric.') + parser.add_argument('--trans1', type=str) + parser.add_argument('--trans2', type=str) + + args = parser.parse_args() + + transFiles1 = natural_sort([file for file in os.listdir(args.trans1) if (file.find("npy") != -1 )]) + transFiles1 = [os.path.join(args.trans1, img) for img in transFiles1] + + transFiles2 = natural_sort([file for file in os.listdir(args.trans2) if (file.find("npy") != -1 )]) + transFiles2 = [os.path.join(args.trans2, img) for img in transFiles2] + + # print(len(transFiles1), transFiles1) + # print(len(transFiles2), transFiles2) + + for T1_file, T2_file in zip(transFiles1, transFiles2): + T1 = np.load(T1_file) + T2 = np.load(T2_file) + print("Shapes: ", T1.shape, T2.shape) + + for i in range(T1.shape[2]): + R1 = T1[:3, :3, i] + R2 = T2[:3, :3, i] + t1 = T1[:4, -1, i] + t2 = T2[:4, -1, i] + + R_norm = angular_distance_np(R1.reshape(1,3,3), R2.reshape(1,3,3))[0] + + print("R norm:", R_norm) + exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/imcui/third_party/RoRD/trainPT_ipr.py b/imcui/third_party/RoRD/trainPT_ipr.py new file mode 100644 index 0000000000000000000000000000000000000000..f730bbb52338509956e9979ddb07d5bef0bd57d0 --- /dev/null +++ b/imcui/third_party/RoRD/trainPT_ipr.py @@ -0,0 +1,225 @@ +import argparse +import numpy as np +import os +import sys + +import shutil + +import torch +import torch.optim as optim + +from 
torch.utils.data import DataLoader + +from tqdm import tqdm + +import warnings + +from lib.exceptions import NoGradientError +from lib.losses.lossPhotoTourism import loss_function +from lib.model import D2Net +from lib.dataloaders.datasetPhotoTourism_ipr import PhotoTourismIPR + + +# CUDA +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if use_cuda else "cpu") + +# Seed +torch.manual_seed(1) +if use_cuda: + torch.cuda.manual_seed(1) +np.random.seed(1) + +# Argument parsing +parser = argparse.ArgumentParser(description='Training script') + +parser.add_argument( + '--dataset_path', type=str, default="/scratch/udit/phototourism/", + help='path to the dataset' +) + +parser.add_argument( + '--preprocessing', type=str, default='caffe', + help='image preprocessing (caffe or torch)' +) + +parser.add_argument( + '--init_model', type=str, default='models/d2net.pth', + help='path to the initial model' +) + +parser.add_argument( + '--num_epochs', type=int, default=10, + help='number of training epochs' +) +parser.add_argument( + '--lr', type=float, default=1e-3, + help='initial learning rate' +) +parser.add_argument( + '--batch_size', type=int, default=1, + help='batch size' +) +parser.add_argument( + '--num_workers', type=int, default=16, + help='number of workers for data loading' +) + +parser.add_argument( + '--log_interval', type=int, default=250, + help='loss logging interval' +) + +parser.add_argument( + '--log_file', type=str, default='log.txt', + help='loss logging file' +) + +parser.add_argument( + '--plot', dest='plot', action='store_true', + help='plot training pairs' +) +parser.set_defaults(plot=False) + +parser.add_argument( + '--checkpoint_directory', type=str, default='checkpoints', + help='directory for training checkpoints' +) +parser.add_argument( + '--checkpoint_prefix', type=str, default='rord', + help='prefix for training checkpoints' +) + +args = parser.parse_args() +print(args) + +# Creating CNN model +model = D2Net( + model_file=args.init_model, + use_cuda=False +) +model = model.to(device) + +# Optimizer +optimizer = optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr +) + +training_dataset = PhotoTourismIPR( + base_path=args.dataset_path, + preprocessing=args.preprocessing +) +training_dataset.build_dataset() + +training_dataloader = DataLoader( + training_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers +) + +# Define epoch function +def process_epoch( + epoch_idx, + model, loss_function, optimizer, dataloader, device, + log_file, args, train=True, plot_path=None +): + epoch_losses = [] + + torch.set_grad_enabled(train) + + progress_bar = tqdm(enumerate(dataloader), total=len(dataloader)) + for batch_idx, batch in progress_bar: + if train: + optimizer.zero_grad() + + batch['train'] = train + batch['epoch_idx'] = epoch_idx + batch['batch_idx'] = batch_idx + batch['batch_size'] = args.batch_size + batch['preprocessing'] = args.preprocessing + batch['log_interval'] = args.log_interval + + try: + loss = loss_function(model, batch, device, plot=args.plot, plot_path=plot_path) + except NoGradientError: + # print("failed") + continue + + current_loss = loss.data.cpu().numpy()[0] + epoch_losses.append(current_loss) + + progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses))) + + if batch_idx % args.log_interval == 0: + log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % ( + 'train' if train else 'valid', + epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses) + )) + + if train: + 
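+            # In training mode, backpropagate the batch loss and apply one optimizer step.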
loss.backward() + optimizer.step() + + log_file.write('[%s] epoch %d - avg_loss: %f\n' % ( + 'train' if train else 'valid', + epoch_idx, + np.mean(epoch_losses) + )) + log_file.flush() + + return np.mean(epoch_losses) + + +# Create the checkpoint directory +checkpoint_directory = os.path.join(args.checkpoint_directory, args.checkpoint_prefix) +if os.path.isdir(checkpoint_directory): + print('[Warning] Checkpoint directory already exists.') +else: + os.makedirs(checkpoint_directory, exist_ok=True) + +# Open the log file for writing +log_file = os.path.join(checkpoint_directory,args.log_file) +if os.path.exists(log_file): + print('[Warning] Log file already exists.') +log_file = open(log_file, 'a+') + +# Create the folders for plotting if need be +plot_path=None +if args.plot: + plot_path = os.path.join(checkpoint_directory,'train_vis') + if os.path.isdir(plot_path): + print('[Warning] Plotting directory already exists.') + else: + os.makedirs(plot_path, exist_ok=True) + + +# Initialize the history +train_loss_history = [] + +# Start the training +for epoch_idx in range(1, args.num_epochs + 1): + # Process epoch + train_loss_history.append( + process_epoch( + epoch_idx, + model, loss_function, optimizer, training_dataloader, device, + log_file, args, train=True, plot_path=plot_path + ) + ) + + # Save the current checkpoint + checkpoint_path = os.path.join( + checkpoint_directory, + '%02d.pth' % (epoch_idx) + ) + checkpoint = { + 'args': args, + 'epoch_idx': epoch_idx, + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'train_loss_history': train_loss_history, + } + torch.save(checkpoint, checkpoint_path) + +# Close the log file +log_file.close() diff --git a/imcui/third_party/RoRD/trainers/trainPT_combined.py b/imcui/third_party/RoRD/trainers/trainPT_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..a32fcf00937a451195270bc5f2e3e4f43af36237 --- /dev/null +++ b/imcui/third_party/RoRD/trainers/trainPT_combined.py @@ -0,0 +1,289 @@ + +import argparse +import numpy as np +import os +import sys +sys.path.append("../") + +import shutil + +import torch +import torch.optim as optim + +from torch.utils.data import DataLoader + +from tqdm import tqdm + +import warnings + +# from lib.dataset import MegaDepthDataset + +from lib.exceptions import NoGradientError +from lib.loss import loss_function as orig_loss +from lib.losses.lossPhotoTourism import loss_function as ipr_loss +from lib.model import D2Net +from lib.dataloaders.datasetPhotoTourism_combined import PhotoTourismCombined + + +# CUDA +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:1" if use_cuda else "cpu") + +# Seed +torch.manual_seed(1) +if use_cuda: + torch.cuda.manual_seed(1) +np.random.seed(1) + +# Argument parsing +parser = argparse.ArgumentParser(description='Training script') + +parser.add_argument( + '--dataset_path', type=str, default="/scratch/udit/phototourism/", + help='path to the dataset' +) +# parser.add_argument( +# '--scene_info_path', type=str, required=True, +# help='path to the processed scenes' +# ) + +parser.add_argument( + '--preprocessing', type=str, default='caffe', + help='image preprocessing (caffe or torch)' +) + +parser.add_argument( + '--model_file', type=str, default='models/d2_ots.pth', + help='path to the full model' +) + +parser.add_argument( + '--num_epochs', type=int, default=10, + help='number of training epochs' +) +parser.add_argument( + '--lr', type=float, default=1e-3, + help='initial learning rate' +) +parser.add_argument( + 
'--batch_size', type=int, default=1, + help='batch size' +) +parser.add_argument( + '--num_workers', type=int, default=16, + help='number of workers for data loading' +) + +parser.add_argument( + '--use_validation', dest='use_validation', action='store_true', + help='use the validation split' +) +parser.set_defaults(use_validation=False) + +parser.add_argument( + '--log_interval', type=int, default=250, + help='loss logging interval' +) + +parser.add_argument( + '--log_file', type=str, default='log.txt', + help='loss logging file' +) + +parser.add_argument( + '--plot', dest='plot', action='store_true', + help='plot training pairs' +) +parser.set_defaults(plot=False) + +parser.add_argument( + '--checkpoint_directory', type=str, default='checkpoints', + help='directory for training checkpoints' +) +parser.add_argument( + '--checkpoint_prefix', type=str, default='d2', + help='prefix for training checkpoints' +) + +args = parser.parse_args() +print(args) + +# Creating CNN model +model = D2Net( + model_file=args.model_file, + use_cuda=False +) +model = model.to(device) + +# Optimizer +optimizer = optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr +) + +# Dataset +if args.use_validation: + validation_dataset = PhotoTourismCombined( + # scene_list_path='megadepth_utils/valid_scenes.txt', + # scene_info_path=args.scene_info_path, + base_path=args.dataset_path, + train=False, + preprocessing=args.preprocessing, + pairs_per_scene=25 + ) + # validation_dataset.build_dataset() + validation_dataloader = DataLoader( + validation_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers + ) + +training_dataset = PhotoTourismCombined( + # scene_list_path='megadepth_utils/train_scenes.txt', + # scene_info_path=args.scene_info_path, + base_path=args.dataset_path, + preprocessing=args.preprocessing +) +# training_dataset.build_dataset() + +training_dataloader = DataLoader( + training_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers +) + + +# Define epoch function +def process_epoch( + epoch_idx, + model, loss_function, optimizer, dataloader, device, + log_file, args, train=True, plot_path=None +): + epoch_losses = [] + + torch.set_grad_enabled(train) + + progress_bar = tqdm(enumerate(dataloader), total=len(dataloader)) + for batch_idx, (batch,method) in progress_bar: + if train: + optimizer.zero_grad() + + batch['train'] = train + batch['epoch_idx'] = epoch_idx + batch['batch_idx'] = batch_idx + batch['batch_size'] = args.batch_size + batch['preprocessing'] = args.preprocessing + batch['log_interval'] = args.log_interval + + try: + loss = loss_function[method](model, batch, device, plot=args.plot, plot_path=plot_path) + except NoGradientError: + # print("failed") + continue + + current_loss = loss.data.cpu().numpy()[0] + epoch_losses.append(current_loss) + + progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses))) + + if batch_idx % args.log_interval == 0: + log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % ( + 'train' if train else 'valid', + epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses) + )) + + if train: + loss.backward() + optimizer.step() + + log_file.write('[%s] epoch %d - avg_loss: %f\n' % ( + 'train' if train else 'valid', + epoch_idx, + np.mean(epoch_losses) + )) + log_file.flush() + + return np.mean(epoch_losses) + + +# Create the checkpoint directory +checkpoint_directory = os.path.join(args.checkpoint_directory, args.checkpoint_prefix) +if os.path.isdir(checkpoint_directory): + print('[Warning] 
Checkpoint directory already exists.') +else: + os.makedirs(checkpoint_directory, exist_ok=True) + +# Open the log file for writing +log_file = os.path.join(checkpoint_directory,args.log_file) +if os.path.exists(log_file): + print('[Warning] Log file already exists.') +log_file = open(log_file, 'a+') + +# Create the folders for plotting if need be +plot_path=None +if args.plot: + plot_path = os.path.join(checkpoint_directory,'train_vis') + if os.path.isdir(plot_path): + print('[Warning] Plotting directory already exists.') + else: + os.makedirs(plot_path, exist_ok=True) + + +# Initialize the history +train_loss_history = [] +validation_loss_history = [] +if args.use_validation: + min_validation_loss = process_epoch( + 0, + model, [orig_loss, ipr_loss], optimizer, validation_dataloader, device, + log_file, args, + train=False + ) + +# Start the training +for epoch_idx in range(1, args.num_epochs + 1): + # Process epoch + train_loss_history.append( + process_epoch( + epoch_idx, + model, [orig_loss, ipr_loss], optimizer, training_dataloader, device, + log_file, args, train=True, plot_path=plot_path + ) + ) + + if args.use_validation: + validation_loss_history.append( + process_epoch( + epoch_idx, + model, [orig_loss, ipr_loss], optimizer, validation_dataloader, device, + log_file, args, + train=False + ) + ) + + # Save the current checkpoint + checkpoint_path = os.path.join( + checkpoint_directory, + '%02d.pth' % (epoch_idx) + ) + checkpoint = { + 'args': args, + 'epoch_idx': epoch_idx, + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'train_loss_history': train_loss_history, + 'validation_loss_history': validation_loss_history + } + torch.save(checkpoint, checkpoint_path) + if ( + args.use_validation and + validation_loss_history[-1] < min_validation_loss + ): + min_validation_loss = validation_loss_history[-1] + best_checkpoint_path = os.path.join( + checkpoint_directory, + '%s.best.pth' % args.checkpoint_prefix + ) + shutil.copy(checkpoint_path, best_checkpoint_path) + +# Close the log file +log_file.close() diff --git a/imcui/third_party/SGMNet/components/__init__.py b/imcui/third_party/SGMNet/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c10d2027efcf985c68abf7185f28b947012cae45 --- /dev/null +++ b/imcui/third_party/SGMNet/components/__init__.py @@ -0,0 +1,3 @@ +from . import extractors +from . 
import matchers +from .load_component import load_component \ No newline at end of file diff --git a/imcui/third_party/SGMNet/components/evaluators.py b/imcui/third_party/SGMNet/components/evaluators.py new file mode 100644 index 0000000000000000000000000000000000000000..59bf0bd7ce3dd085dc86072fc41bad24b9805991 --- /dev/null +++ b/imcui/third_party/SGMNet/components/evaluators.py @@ -0,0 +1,127 @@ +import numpy as np +import sys +import os +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, ROOT_DIR) + +from utils import evaluation_utils,metrics,fm_utils +import cv2 + +class auc_eval: + def __init__(self,config): + self.config=config + self.err_r,self.err_t,self.err=[],[],[] + self.ms=[] + self.precision=[] + + def run(self,info): + E,r_gt,t_gt=info['e'],info['r_gt'],info['t_gt'] + K1,K2,img1,img2=info['K1'],info['K2'],info['img1'],info['img2'] + corr1,corr2=info['corr1'],info['corr2'] + corr1,corr2=evaluation_utils.normalize_intrinsic(corr1,K1),evaluation_utils.normalize_intrinsic(corr2,K2) + size1,size2=max(img1.shape),max(img2.shape) + scale1,scale2=self.config['rescale']/size1,self.config['rescale']/size2 + #ransac + ransac_th=4./((K1[0,0]+K1[1,1])*scale1+(K2[0,0]+K2[1,1])*scale2) + R_hat,t_hat,E_hat=self.estimate(corr1,corr2,ransac_th) + #get pose error + err_r, err_t=metrics.evaluate_R_t(r_gt,t_gt,R_hat,t_hat) + err=max(err_r,err_t) + + if len(corr1)>1: + inlier_mask=metrics.compute_epi_inlier(corr1,corr2,E,self.config['inlier_th']) + precision=inlier_mask.mean() + ms=inlier_mask.sum()/len(info['x1']) + else: + ms=precision=0 + + return {'err_r':err_r,'err_t':err_t,'err':err,'ms':ms,'precision':precision} + + def res_inqueue(self,res): + self.err_r.append(res['err_r']),self.err_t.append(res['err_t']),self.err.append(res['err']) + self.ms.append(res['ms']),self.precision.append(res['precision']) + + def estimate(self,corr1,corr2,th): + num_inlier = -1 + if corr1.shape[0] >= 5: + E, mask_new = cv2.findEssentialMat(corr1, corr2,method=cv2.RANSAC, threshold=th,prob=1-1e-5) + if E is None: + E=[np.eye(3)] + for _E in np.split(E, len(E) / 3): + _num_inlier, _R, _t, _ = cv2.recoverPose(_E, corr1, corr2,np.eye(3), 1e9,mask=mask_new) + if _num_inlier > num_inlier: + num_inlier = _num_inlier + R = _R + t = _t + E = _E + else: + E,R,t=np.eye(3),np.eye(3),np.zeros(3) + return R,t,E + + def parse(self): + ths = np.arange(7) * 5 + approx_auc=metrics.approx_pose_auc(self.err,ths) + exact_auc=metrics.pose_auc(self.err,ths) + mean_pre,mean_ms=np.mean(np.asarray(self.precision)),np.mean(np.asarray(self.ms)) + + print('auc th: ',ths[1:]) + print('approx auc: ',approx_auc) + print('exact auc: ', exact_auc) + print('mean match score: ',mean_ms*100) + print('mean precision: ',mean_pre*100) + + + +class FMbench_eval: + + def __init__(self,config): + self.config=config + self.pre,self.pre_post,self.sgd=[],[],[] + self.num_corr,self.num_corr_post=[],[] + + def run(self,info): + corr1,corr2=info['corr1'],info['corr2'] + F=info['f'] + img1,img2=info['img1'],info['img2'] + + if len(corr1)>1: + pre_bf=fm_utils.compute_inlier_rate(corr1,corr2,np.flip(img1.shape[:2]),np.flip(img2.shape[:2]),F,th=self.config['inlier_th']).mean() + F_hat,mask_F=cv2.findFundamentalMat(corr1,corr2,method=cv2.FM_RANSAC,ransacReprojThreshold=1,confidence=1-1e-5) + if F_hat is None: + F_hat=np.ones([3,3]) + mask_F=np.ones([len(corr1)]).astype(bool) + else: + mask_F=mask_F.squeeze().astype(bool) + F_hat=F_hat[:3] + 
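+            # Precision after RANSAC ('pre_post'): the inlier rate is re-evaluated on the
+            # mask_F-filtered correspondences only; F_hat[:3] above keeps a single 3x3
+            # estimate in case cv2.findFundamentalMat stacked several candidate solutions.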
pre_af=fm_utils.compute_inlier_rate(corr1[mask_F],corr2[mask_F],np.flip(img1.shape[:2]),np.flip(img2.shape[:2]),F,th=self.config['inlier_th']).mean() + num_corr_af=mask_F.sum() + num_corr=len(corr1) + sgd=fm_utils.compute_SGD(F,F_hat,np.flip(img1.shape[:2]),np.flip(img2.shape[:2])) + else: + pre_bf,pre_af,sgd=0,0,1e8 + num_corr,num_corr_af=0,0 + return {'pre':pre_bf,'pre_post':pre_af,'sgd':sgd,'num_corr':num_corr,'num_corr_post':num_corr_af} + + + def res_inqueue(self,res): + self.pre.append(res['pre']),self.pre_post.append(res['pre_post']),self.sgd.append(res['sgd']) + self.num_corr.append(res['num_corr']),self.num_corr_post.append(res['num_corr_post']) + + def parse(self): + for seq_index in range(len(self.config['seq'])): + seq=self.config['seq'][seq_index] + offset=seq_index*1000 + pre=np.asarray(self.pre)[offset:offset+1000].mean() + pre_post=np.asarray(self.pre_post)[offset:offset+1000].mean() + num_corr=np.asarray(self.num_corr)[offset:offset+1000].mean() + num_corr_post=np.asarray(self.num_corr_post)[offset:offset+1000].mean() + f_recall=(np.asarray(self.sgd)[offset:offset+1000]self.p_th,index[:,0],index2.squeeze(0) + mask_mc=index2[index] == torch.arange(len(p)).cuda() + mask=mask_th&mask_mc + index1,index2=torch.nonzero(mask).squeeze(1),index[mask] + return index1,index2 + + +class NN_Matcher(object): + + def __init__(self,config): + config=namedtuple('config',config.keys())(*config.values()) + self.mutual_check=config.mutual_check + self.ratio_th=config.ratio_th + + def run(self,test_data): + desc1,desc2,x1,x2=test_data['desc1'],test_data['desc2'],test_data['x1'],test_data['x2'] + desc_mat=np.sqrt(abs((desc1**2).sum(-1)[:,np.newaxis]+(desc2**2).sum(-1)[np.newaxis]-2*desc1@desc2.T)) + nn_index=np.argpartition(desc_mat,kth=(1,2),axis=-1) + dis_value12=np.take_along_axis(desc_mat,nn_index, axis=-1) + ratio_score=dis_value12[:,0]/dis_value12[:,1] + nn_index1=nn_index[:,0] + nn_index2=np.argmin(desc_mat,axis=0) + mask_ratio,mask_mutual=ratio_scoreself.config['angle_th'][0],angle_listself.config['overlap_th'][0],overlap_scoreself.config['min_corr'] and len(incorr_index1)>self.config['min_incorr'] and len(incorr_index2)>self.config['min_incorr']: + info['corr'].append(corr_index),info['incorr1'].append(incorr_index1),info['incorr2'].append(incorr_index2) + info['dR'].append(dR),info['dt'].append(dt),info['K1'].append(K1),info['K2'].append(K2),info['img_path1'].append(img_path1),info['img_path2'].append(img_path2) + info['fea_path1'].append(fea_path1),info['fea_path2'].append(fea_path2),info['size1'].append(size1),info['size2'].append(size2) + sample_number+=1 + if sample_number==sample_target: + break + info['pair_num']=sample_number + #dump info + self.dump_info(seq,info) + + + def collect_meta(self): + print('collecting meta info...') + dump_path,seq_list=[],[] + if self.config['dump_train']: + dump_path.append(os.path.join(self.config['dataset_dump_dir'],'train')) + seq_list.append(self.train_list) + if self.config['dump_valid']: + dump_path.append(os.path.join(self.config['dataset_dump_dir'],'valid')) + seq_list.append(self.valid_list) + for pth,seqs in zip(dump_path,seq_list): + if not os.path.exists(pth): + os.mkdir(pth) + pair_num_list,total_pair=[],0 + for seq_index in range(len(seqs)): + seq=seqs[seq_index] + pair_num=np.loadtxt(os.path.join(self.config['dataset_dump_dir'],seq,'pair_num.txt'),dtype=int) + pair_num_list.append(str(pair_num)) + total_pair+=pair_num + pair_num_list=np.stack([np.asarray(seqs,dtype=str),np.asarray(pair_num_list,dtype=str)],axis=1) + 
pair_num_list=np.concatenate([np.asarray([['total',str(total_pair)]]),pair_num_list],axis=0) + np.savetxt(os.path.join(pth,'pair_num.txt'),pair_num_list,fmt='%s') + + def format_dump_data(self): + print('Formatting data...') + iteration_num=len(self.seq_list)//self.config['num_process'] + if len(self.seq_list)%self.config['num_process']!=0: + iteration_num+=1 + pool=Pool(self.config['num_process']) + for index in trange(iteration_num): + indices=range(index*self.config['num_process'],min((index+1)*self.config['num_process'],len(self.seq_list))) + pool.map(self.format_seq,indices) + pool.close() + pool.join() + + self.collect_meta() \ No newline at end of file diff --git a/imcui/third_party/SGMNet/datadump/dumper/scannet.py b/imcui/third_party/SGMNet/datadump/dumper/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..2556f727fcc9b4c621e44d9ee5cb4e99cb19b7e8 --- /dev/null +++ b/imcui/third_party/SGMNet/datadump/dumper/scannet.py @@ -0,0 +1,72 @@ +import os +import glob +import pickle +from posixpath import basename +import numpy as np +import h5py +from .base_dumper import BaseDumper + +import sys +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.insert(0, ROOT_DIR) +import utils + +class scannet(BaseDumper): + def get_seqs(self): + self.pair_list=np.loadtxt('../assets/scannet_eval_list.txt',dtype=str) + self.seq_list=np.unique(np.asarray([path.split('/')[0] for path in self.pair_list[:,0]],dtype=str)) + self.dump_seq,self.img_seq=[],[] + for seq in self.seq_list: + dump_dir=os.path.join(self.config['feature_dump_dir'],seq) + cur_img_seq=glob.glob(os.path.join(os.path.join(self.config['rawdata_dir'],seq,'img','*.jpg'))) + cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\ + +'.hdf5' for path in cur_img_seq] + self.img_seq+=cur_img_seq + self.dump_seq+=cur_dump_seq + + def format_dump_folder(self): + if not os.path.exists(self.config['feature_dump_dir']): + os.mkdir(self.config['feature_dump_dir']) + for seq in self.seq_list: + seq_dir=os.path.join(self.config['feature_dump_dir'],seq) + if not os.path.exists(seq_dir): + os.mkdir(seq_dir) + + def format_dump_data(self): + print('Formatting data...') + self.data={'K1':[],'K2':[],'R':[],'T':[],'e':[],'f':[],'fea_path1':[],'fea_path2':[],'img_path1':[],'img_path2':[]} + + for pair in self.pair_list: + img_path1,img_path2=pair[0],pair[1] + seq=img_path1.split('/')[0] + index1,index2=int(img_path1.split('/')[-1][:-4]),int(img_path2.split('/')[-1][:-4]) + ex1,ex2=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'extrinsic',str(index1)+'.txt'),dtype=float),\ + np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'extrinsic',str(index2)+'.txt'),dtype=float) + K1,K2=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'intrinsic',str(index1)+'.txt'),dtype=float),\ + np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'intrinsic',str(index2)+'.txt'),dtype=float) + + + relative_extrinsic=np.matmul(np.linalg.inv(ex2),ex1) + dR,dt=relative_extrinsic[:3,:3],relative_extrinsic[:3,3] + dt /= np.sqrt(np.sum(dt**2)) + + e_gt_unnorm = np.reshape(np.matmul( + np.reshape(utils.evaluation_utils.np_skew_symmetric(dt.astype('float64').reshape(1, 3)), (3, 3)), + np.reshape(dR.astype('float64'), (3, 3))), (3, 3)) + e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm) + f_gt_unnorm=np.linalg.inv(K2.T)@e_gt@np.linalg.inv(K1) + f_gt = f_gt_unnorm / np.linalg.norm(f_gt_unnorm) + + 
self.data['K1'].append(K1),self.data['K2'].append(K2) + self.data['R'].append(dR),self.data['T'].append(dt) + self.data['e'].append(e_gt),self.data['f'].append(f_gt) + + dump_seq_dir=os.path.join(self.config['feature_dump_dir'],seq) + fea_path1,fea_path2=os.path.join(dump_seq_dir,img_path1.split('/')[-1]+'_'+self.config['extractor']['name'] + +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\ + os.path.join(dump_seq_dir,img_path2.split('/')[-1]+'_'+self.config['extractor']['name'] + +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5') + self.data['img_path1'].append(img_path1),self.data['img_path2'].append(img_path2) + self.data['fea_path1'].append(fea_path1),self.data['fea_path2'].append(fea_path2) + + self.form_standard_dataset() diff --git a/imcui/third_party/SGMNet/datadump/dumper/yfcc.py b/imcui/third_party/SGMNet/datadump/dumper/yfcc.py new file mode 100644 index 0000000000000000000000000000000000000000..0c52e4324bba3e5ed424fe58af7a94fd3132b1e5 --- /dev/null +++ b/imcui/third_party/SGMNet/datadump/dumper/yfcc.py @@ -0,0 +1,87 @@ +import os +import glob +import pickle +import numpy as np +import h5py +from .base_dumper import BaseDumper + +import sys +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.insert(0, ROOT_DIR) +import utils + +class yfcc(BaseDumper): + + def get_seqs(self): + data_dir=os.path.join(self.config['rawdata_dir'],'yfcc100m') + for seq in self.config['data_seq']: + for split in self.config['data_split']: + split_dir=os.path.join(data_dir,seq,split) + dump_dir=os.path.join(self.config['feature_dump_dir'],seq,split) + cur_img_seq=glob.glob(os.path.join(split_dir,'images','*.jpg')) + cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\ + +'.hdf5' for path in cur_img_seq] + self.img_seq+=cur_img_seq + self.dump_seq+=cur_dump_seq + + def format_dump_folder(self): + if not os.path.exists(self.config['feature_dump_dir']): + os.mkdir(self.config['feature_dump_dir']) + for seq in self.config['data_seq']: + seq_dir=os.path.join(self.config['feature_dump_dir'],seq) + if not os.path.exists(seq_dir): + os.mkdir(seq_dir) + for split in self.config['data_split']: + split_dir=os.path.join(seq_dir,split) + if not os.path.exists(split_dir): + os.mkdir(split_dir) + + def format_dump_data(self): + print('Formatting data...') + pair_path=os.path.join(self.config['rawdata_dir'],'pairs') + self.data={'K1':[],'K2':[],'R':[],'T':[],'e':[],'f':[],'fea_path1':[],'fea_path2':[],'img_path1':[],'img_path2':[]} + + for seq in self.config['data_seq']: + pair_name=os.path.join(pair_path,seq+'-te-1000-pairs.pkl') + with open(pair_name, 'rb') as f: + pairs=pickle.load(f) + + #generate id list + seq_dir=os.path.join(self.config['rawdata_dir'],'yfcc100m',seq,'test') + name_list=np.loadtxt(os.path.join(seq_dir,'images.txt'),dtype=str) + cam_name_list=np.loadtxt(os.path.join(seq_dir,'calibration.txt'),dtype=str) + + for cur_pair in pairs: + index1,index2=cur_pair[0],cur_pair[1] + cam1,cam2=h5py.File(os.path.join(seq_dir,cam_name_list[index1]),'r'),h5py.File(os.path.join(seq_dir,cam_name_list[index2]),'r') + K1,K2=cam1['K'][()],cam2['K'][()] + [w1,h1],[w2,h2]=cam1['imsize'][()][0],cam2['imsize'][()][0] + cx1,cy1,cx2,cy2 = (w1 - 1.0) * 0.5,(h1 - 1.0) * 0.5, (w2 - 1.0) * 0.5,(h2 - 1.0) * 0.5 + K1[0,2],K1[1,2],K2[0,2],K2[1,2]=cx1,cy1,cx2,cy2 + + R1,R2,t1,t2=cam1['R'][()],cam2['R'][()],cam1['T'][()].reshape([3,1]),cam2['T'][()].reshape([3,1]) + dR = np.dot(R2, R1.T) + dt = t2 - 
np.dot(dR, t1) + dt /= np.sqrt(np.sum(dt**2)) + + e_gt_unnorm = np.reshape(np.matmul( + np.reshape(utils.evaluation_utils.np_skew_symmetric(dt.astype('float64').reshape(1, 3)), (3, 3)), + np.reshape(dR.astype('float64'), (3, 3))), (3, 3)) + e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm) + f_gt_unnorm=np.linalg.inv(K2.T)@e_gt@np.linalg.inv(K1) + f_gt = f_gt_unnorm / np.linalg.norm(f_gt_unnorm) + + self.data['K1'].append(K1),self.data['K2'].append(K2) + self.data['R'].append(dR),self.data['T'].append(dt) + self.data['e'].append(e_gt),self.data['f'].append(f_gt) + + img_path1,img_path2=os.path.join('yfcc100m',seq,'test',name_list[index1]),os.path.join('yfcc100m',seq,'test',name_list[index2]) + dump_seq_dir=os.path.join(self.config['feature_dump_dir'],seq,'test') + fea_path1,fea_path2=os.path.join(dump_seq_dir,name_list[index1].split('/')[-1]+'_'+self.config['extractor']['name'] + +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\ + os.path.join(dump_seq_dir,name_list[index2].split('/')[-1]+'_'+self.config['extractor']['name'] + +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5') + self.data['img_path1'].append(img_path1),self.data['img_path2'].append(img_path2) + self.data['fea_path1'].append(fea_path1),self.data['fea_path2'].append(fea_path2) + + self.form_standard_dataset() diff --git a/imcui/third_party/SGMNet/demo/configs/nn_config.yaml b/imcui/third_party/SGMNet/demo/configs/nn_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a87bfafce0cb7f8ab64e59311923d309aabcfab9 --- /dev/null +++ b/imcui/third_party/SGMNet/demo/configs/nn_config.yaml @@ -0,0 +1,10 @@ +extractor: + name: root + num_kpt: 4000 + resize: [-1] + det_th: 0.00001 + +matcher: + name: NN + ratio_th: 0.9 + mutual_check: True \ No newline at end of file diff --git a/imcui/third_party/SGMNet/demo/configs/sgm_config.yaml b/imcui/third_party/SGMNet/demo/configs/sgm_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91de752010daa54ef0b508ef79d2dc4ac23945ec --- /dev/null +++ b/imcui/third_party/SGMNet/demo/configs/sgm_config.yaml @@ -0,0 +1,21 @@ +extractor: + name: root + num_kpt: 4000 + resize: [-1] + det_th: 0.00001 + +matcher: + name: SGM + model_dir: ../weights/sgm/root + seed_top_k: [256,256] + seed_radius_coe: 0.01 + net_channels: 128 + layer_num: 9 + head: 4 + seedlayer: [0,6] + use_mc_seeding: True + use_score_encoding: False + conf_bar: [1.11,0.1] + sink_iter: [10,100] + detach_iter: 1000000 + p_th: 0.2 diff --git a/imcui/third_party/SGMNet/demo/demo.py b/imcui/third_party/SGMNet/demo/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe277e26d09121f5517854a7ea014b0797a2bde --- /dev/null +++ b/imcui/third_party/SGMNet/demo/demo.py @@ -0,0 +1,45 @@ +import cv2 +import yaml +import numpy as np +import os +import sys + +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, ROOT_DIR) +from components import load_component +from utils import evaluation_utils + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--config_path', type=str, default='configs/sgm_config.yaml', + help='number of processes.') +parser.add_argument('--img1_path', type=str, default='demo_1.jpg', + help='number of processes.') +parser.add_argument('--img2_path', type=str, default='demo_2.jpg', + help='number of processes.') + + +args = parser.parse_args() + +if __name__=='__main__': + with open(args.config_path, 'r') as f: + demo_config = yaml.load(f) + + 
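One caveat about the config loading just above: yaml.load(f) with no explicit Loader is deprecated in PyYAML 5.x and rejected by PyYAML 6.x, so the demo (and the other scripts in this patch that use the same call) may need a small tweak on newer environments. A hedged sketch of the safer call, not part of the patch itself:

```python
import yaml

# equivalent to yaml.load(f, Loader=yaml.SafeLoader)
with open('configs/sgm_config.yaml', 'r') as f:
    demo_config = yaml.safe_load(f)
print(demo_config['matcher']['name'])   # 'SGM' for the config shipped above
```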
extractor=load_component('extractor',demo_config['extractor']['name'],demo_config['extractor']) + + img1,img2=cv2.imread(args.img1_path),cv2.imread(args.img2_path) + size1,size2=np.flip(np.asarray(img1.shape[:2])),np.flip(np.asarray(img2.shape[:2])) + kpt1,desc1=extractor.run(args.img1_path) + kpt2,desc2=extractor.run(args.img2_path) + + matcher=load_component('matcher',demo_config['matcher']['name'],demo_config['matcher']) + test_data={'x1':kpt1,'x2':kpt2,'desc1':desc1,'desc2':desc2,'size1':size1,'size2':size2} + corr1,corr2= matcher.run(test_data) + + #draw points + dis_points_1 = evaluation_utils.draw_points(img1, kpt1) + dis_points_2 = evaluation_utils.draw_points(img2, kpt2) + + #visualize match + display=evaluation_utils.draw_match(dis_points_1,dis_points_2,corr1,corr2) + cv2.imwrite('match.png',display) diff --git a/imcui/third_party/SGMNet/evaluation/configs/cost/sg_cost.yaml b/imcui/third_party/SGMNet/evaluation/configs/cost/sg_cost.yaml new file mode 100644 index 0000000000000000000000000000000000000000..05ea5ddc7bce8ad94d3ef3ec350363b5cc846ed8 --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/cost/sg_cost.yaml @@ -0,0 +1,4 @@ +net_channels: 128 +layer_num: 9 +head: 4 +use_score_encoding: True \ No newline at end of file diff --git a/imcui/third_party/SGMNet/evaluation/configs/cost/sgm_cost.yaml b/imcui/third_party/SGMNet/evaluation/configs/cost/sgm_cost.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f43193fb63fb26d50a8c3abd3cf53c43734dbca --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/cost/sgm_cost.yaml @@ -0,0 +1,11 @@ +seed_top_k: [256,256] +seed_radius_coe: 0.01 +net_channels: 128 +layer_num: 9 +head: 4 +seedlayer: [0,6] +use_mc_seeding: True +use_score_encoding: False +conf_bar: [1,0] +sink_iter: [10,10] +detach_iter: 1000000 \ No newline at end of file diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_nn.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_nn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d467a814559a27938f010dbf79a8e208551b2b5 --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_nn.yaml @@ -0,0 +1,18 @@ +reader: + name: standard + rawdata_dir: FM-Bench/Dataset + dataset_dir: test_fmbench_root/fmbench_root_4000.hdf5 + num_kpt: 4000 + +matcher: + name: NN + mutual_check: False + ratio_th: 0.8 + +evaluator: + name: FM + seq: ['CPC','KITTI','TUM','Tanks_and_Temples'] + num_pair: 4000 + inlier_th: 0.003 + sgd_inlier_th: 0.05 + diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sg.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec22b1340d62fad20f22584ddbded30fcc59d1c9 --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sg.yaml @@ -0,0 +1,22 @@ +reader: + name: standard + rawdata_dir: FM-Bench/Dataset + dataset_dir: test_fmbench_root/fmbench_root_4000.hdf5 + num_kpt: 4000 + +matcher: + name: SG + model_dir: ../weights/sg/root + net_channels: 128 + layer_num: 9 + head: 4 + use_score_encoding: True + sink_iter: [100] + p_th: 0.2 + +evaluator: + name: FM + seq: ['CPC','KITTI','TUM','Tanks_and_Temples'] + num_pair: 4000 + inlier_th: 0.003 + sgd_inlier_th: 0.05 diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sgm.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sgm.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..cd23165c95451cd44063a2b6cccea21c68fb6fa0 --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sgm.yaml @@ -0,0 +1,28 @@ +reader: + name: standard + rawdata_dir: FM-Bench/Dataset + dataset_dir: test_fmbench_root/fmbench_root_4000.hdf5 + num_kpt: 4000 + +matcher: + name: SGM + model_dir: ../weights/sgm/root + seed_top_k: [256,256] + seed_radius_coe: 0.01 + net_channels: 128 + layer_num: 9 + head: 4 + seedlayer: [0,6] + use_mc_seeding: True + use_score_encoding: False + conf_bar: [1.11,0.1] #set to [1,0.1] for sp + sink_iter: [10,100] + detach_iter: 1000000 + p_th: 0.2 + +evaluator: + name: FM + seq: ['CPC','KITTI','TUM','Tanks_and_Temples'] + num_pair: 4000 + inlier_th: 0.003 + sgd_inlier_th: 0.05 diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_nn.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_nn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51ad5402b6266b60a365181371be8a5e64751d2f --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_nn.yaml @@ -0,0 +1,17 @@ +reader: + name: standard + rawdata_dir: scannet_eval + dataset_dir: scannet_test_root/scannet_root_2000.hdf5 + num_kpt: 2000 + +matcher: + name: NN + mutual_check: False + ratio_th: 0.8 + +evaluator: + name: AUC + rescale: 640 + num_pair: 1500 + inlier_th: 0.005 + diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sg.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d0ef70cfa07b1471816cc7905d6a632599d134c --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sg.yaml @@ -0,0 +1,22 @@ +reader: + name: standard + rawdata_dir: scannet_eval + dataset_dir: scannet_test_root/scannet_root_2000.hdf5 + num_kpt: 2000 + +matcher: + name: SG + model_dir: ../weights/sg/root + net_channels: 128 + layer_num: 9 + head: 4 + use_score_encoding: True + sink_iter: [100] + p_th: 0.2 + +evaluator: + name: AUC + rescale: 640 + num_pair: 1500 + inlier_th: 0.005 + diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sgm.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sgm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e524845a514e6d8d50f97bced5c9beeaed26ebe5 --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sgm.yaml @@ -0,0 +1,28 @@ +reader: + name: standard + rawdata_dir: scannet_eval + dataset_dir: scannet_test_root/scannet_root_2000.hdf5 + num_kpt: 2000 + +matcher: + name: SGM + model_dir: ../weights/sgm/root + seed_top_k: [128,128] + seed_radius_coe: 0.01 + net_channels: 128 + layer_num: 9 + head: 4 + seedlayer: [0,6] + use_mc_seeding: True + use_score_encoding: False + conf_bar: [1.11,0.1] + sink_iter: [10,100] + detach_iter: 1000000 + p_th: 0.2 + +evaluator: + name: AUC + rescale: 640 + num_pair: 1500 + inlier_th: 0.005 + diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_nn.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_nn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8ecd1eef2cff9b93f3665a9cf4af6bc9f68339f0 --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_nn.yaml @@ -0,0 +1,17 @@ +reader: + name: standard + rawdata_dir: yfcc_rawdata + dataset_dir: yfcc_test_root/yfcc_root_2000.hdf5 + num_kpt: 2000 + +matcher: + name: NN + mutual_check: False + 
ratio_th: 0.8 + +evaluator: + name: AUC + rescale: 1600 + num_pair: 4000 + inlier_th: 0.005 + diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sg.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..beb2b93639160448dd955cd576e5a19a936b08f1 --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sg.yaml @@ -0,0 +1,22 @@ +reader: + name: standard + rawdata_dir: yfcc_rawdata + dataset_dir: yfcc_test_root/yfcc_root_2000.hdf5 + num_kpt: 2000 + +matcher: + name: SG + model_dir: ../weights/sg/root + net_channels: 128 + layer_num: 9 + head: 4 + use_score_encoding: True + sink_iter: [100] + p_th: 0.2 + +evaluator: + name: AUC + rescale: 1600 + num_pair: 4000 + inlier_th: 0.005 + diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sgm.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sgm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c9aee8a8aa786ff209a5afadf0469f62ef2a50f --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sgm.yaml @@ -0,0 +1,28 @@ +reader: + name: standard + rawdata_dir: yfcc_rawdata + dataset_dir: yfcc_test_root/yfcc_root_2000.hdf5 + num_kpt: 2000 + +matcher: + name: SGM + model_dir: ../weights/sgm/root + seed_top_k: [128,128] + seed_radius_coe: 0.01 + net_channels: 128 + layer_num: 9 + head: 4 + seedlayer: [0,6] + use_mc_seeding: True + use_score_encoding: False + conf_bar: [1.11,0.1] #set to [1,0.1] for sp + sink_iter: [10,100] + detach_iter: 1000000 + p_th: 0.2 + +evaluator: + name: AUC + rescale: 1600 + num_pair: 4000 + inlier_th: 0.005 + diff --git a/imcui/third_party/SGMNet/evaluation/eval_cost.py b/imcui/third_party/SGMNet/evaluation/eval_cost.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3f88abc93290c96ed3d7fa8624c3534e006911 --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/eval_cost.py @@ -0,0 +1,60 @@ +import torch +import yaml +import time +from collections import OrderedDict,namedtuple +import os +import sys +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, ROOT_DIR) + +from sgmnet import matcher as SGM_Model +from superglue import matcher as SG_Model + + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--matcher_name', type=str, default='SGM', + help='number of processes.') +parser.add_argument('--config_path', type=str, default='configs/cost/sgm_cost.yaml', + help='number of processes.') +parser.add_argument('--num_kpt', type=int, default=4000, + help='keypoint number, default:100') +parser.add_argument('--iter_num', type=int, default=100, + help='keypoint number, default:100') + + +def test_cost(test_data,model): + with torch.no_grad(): + #warm up call + _=model(test_data) + torch.cuda.synchronize() + a=time.time() + for _ in range(int(args.iter_num)): + _=model(test_data) + torch.cuda.synchronize() + b=time.time() + print('Average time per run(ms): ',(b-a)/args.iter_num*1e3) + print('Peak memory(MB): ',torch.cuda.max_memory_allocated()/1e6) + + +if __name__=='__main__': + torch.backends.cudnn.benchmark=False + args = parser.parse_args() + with open(args.config_path, 'r') as f: + model_config = yaml.load(f) + model_config=namedtuple('model_config',model_config.keys())(*model_config.values()) + + if args.matcher_name=='SGM': + model = SGM_Model(model_config) + elif args.matcher_name=='SG': + model = SG_Model(model_config) + 
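test_cost() above follows the usual recipe for timing GPU code: one warm-up forward pass, torch.cuda.synchronize() before starting and after finishing the timed loop, and torch.cuda.max_memory_allocated() for peak memory. The same pattern reduced to a generic helper (assumes a CUDA device; the tiny linear model is only a stand-in for the matcher):

```python
import time
import torch

def benchmark(model, inputs, iters=100):
    model.eval()
    with torch.no_grad():
        model(*inputs)                    # warm-up call
        torch.cuda.synchronize()          # drain queued kernels before timing
        start = time.time()
        for _ in range(iters):
            model(*inputs)
        torch.cuda.synchronize()          # wait for the last kernel to finish
    print('avg per run (ms):', (time.time() - start) / iters * 1e3)
    print('peak memory (MB):', torch.cuda.max_memory_allocated() / 1e6)

net = torch.nn.Linear(128, 128).cuda()
benchmark(net, (torch.rand(4000, 128, device='cuda'),))
```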
model.cuda(),model.eval() + + test_data = { + 'x1':torch.rand(1,args.num_kpt,2).cuda()-0.5, + 'x2':torch.rand(1,args.num_kpt,2).cuda()-0.5, + 'desc1': torch.rand(1,args.num_kpt,128).cuda(), + 'desc2': torch.rand(1,args.num_kpt,128).cuda() + } + + test_cost(test_data,model) diff --git a/imcui/third_party/SGMNet/evaluation/evaluate.py b/imcui/third_party/SGMNet/evaluation/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..dd5229375caa03b2763bf37a266fb76e80f8e25e --- /dev/null +++ b/imcui/third_party/SGMNet/evaluation/evaluate.py @@ -0,0 +1,117 @@ +import os +from torch.multiprocessing import Process,Manager,set_start_method,Pool +import functools +import argparse +import yaml +import numpy as np +import sys +import cv2 +from tqdm import trange +set_start_method('spawn',force=True) + + +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, ROOT_DIR) + +from components import load_component +from utils import evaluation_utils,metrics + +parser = argparse.ArgumentParser(description='dump eval data.') +parser.add_argument('--config_path', type=str, default='configs/eval/scannet_eval_sgm.yaml') +parser.add_argument('--num_process_match', type=int, default=4) +parser.add_argument('--num_process_eval', type=int, default=4) +parser.add_argument('--vis_folder',type=str,default=None) +args=parser.parse_args() + +def feed_match(info,matcher): + x1,x2,desc1,desc2,size1,size2=info['x1'],info['x2'],info['desc1'],info['desc2'],info['img1'].shape[:2],info['img2'].shape[:2] + test_data = {'x1': x1,'x2': x2,'desc1': desc1,'desc2': desc2,'size1':np.flip(np.asarray(size1)),'size2':np.flip(np.asarray(size2)) } + corr1,corr2=matcher.run(test_data) + return [corr1,corr2] + + +def reader_handler(config,read_que): + reader=load_component('reader',config['name'],config) + for index in range(len(reader)): + index+=0 + info=reader.run(index) + read_que.put(info) + read_que.put('over') + + +def match_handler(config,read_que,match_que): + matcher=load_component('matcher',config['name'],config) + match_func=functools.partial(feed_match,matcher=matcher) + pool = Pool(args.num_process_match) + cache=[] + while True: + item=read_que.get() + #clear cache + if item=='over': + if len(cache)!=0: + results=pool.map(match_func,cache) + for cur_item,cur_result in zip(cache,results): + cur_item['corr1'],cur_item['corr2']=cur_result[0],cur_result[1] + match_que.put(cur_item) + match_que.put('over') + break + cache.append(item) + #print(len(cache)) + if len(cache)==args.num_process_match: + #matching in parallel + results=pool.map(match_func,cache) + for cur_item,cur_result in zip(cache,results): + cur_item['corr1'],cur_item['corr2']=cur_result[0],cur_result[1] + match_que.put(cur_item) + cache=[] + pool.close() + pool.join() + + +def evaluate_handler(config,match_que): + evaluator=load_component('evaluator',config['name'],config) + pool = Pool(args.num_process_eval) + cache=[] + for _ in trange(config['num_pair']): + item=match_que.get() + if item=='over': + if len(cache)!=0: + results=pool.map(evaluator.run,cache) + for cur_res in results: + evaluator.res_inqueue(cur_res) + break + cache.append(item) + if len(cache)==args.num_process_eval: + results=pool.map(evaluator.run,cache) + for cur_res in results: + evaluator.res_inqueue(cur_res) + cache=[] + if args.vis_folder is not None: + #dump visualization + corr1_norm,corr2_norm=evaluation_utils.normalize_intrinsic(item['corr1'],item['K1']),\ + evaluation_utils.normalize_intrinsic(item['corr2'],item['K2']) + 
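evaluate.py strings the reader, matcher and evaluator together as three processes joined by bounded queues, with the string 'over' as an end-of-stream sentinel and small batches handed to worker pools. A stripped-down sketch of that producer/consumer shape with dummy stages in place of the real components (stdlib multiprocessing instead of torch.multiprocessing, purely for illustration):

```python
from multiprocessing import Process, Queue

def producer(q_out, n=10):
    for i in range(n):
        q_out.put({'index': i})
    q_out.put('over')                         # end-of-stream sentinel

def worker(q_in, q_out):
    while True:
        item = q_in.get()
        if item == 'over':
            q_out.put('over')
            break
        item['score'] = item['index'] * 2     # stand-in for feature matching
        q_out.put(item)

def consumer(q_in):
    while True:
        item = q_in.get()
        if item == 'over':
            break
        print('consumed', item)

if __name__ == '__main__':
    q1, q2 = Queue(maxsize=100), Queue(maxsize=100)
    procs = [Process(target=producer, args=(q1,)),
             Process(target=worker, args=(q1, q2)),
             Process(target=consumer, args=(q2,))]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```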
inlier_mask=metrics.compute_epi_inlier(corr1_norm,corr2_norm,item['e'],config['inlier_th']) + display=evaluation_utils.draw_match(item['img1'],item['img2'],item['corr1'],item['corr2'],inlier_mask) + cv2.imwrite(os.path.join(args.vis_folder,str(item['index'])+'.png'),display) + evaluator.parse() + + +if __name__=='__main__': + with open(args.config_path, 'r') as f: + config = yaml.load(f) + if args.vis_folder is not None and not os.path.exists(args.vis_folder): + os.mkdir(args.vis_folder) + + read_que,match_que,estimate_que=Manager().Queue(maxsize=100),Manager().Queue(maxsize=100),Manager().Queue(maxsize=100) + + read_process=Process(target=reader_handler,args=(config['reader'],read_que)) + match_process=Process(target=match_handler,args=(config['matcher'],read_que,match_que)) + evaluate_process=Process(target=evaluate_handler,args=(config['evaluator'],match_que)) + + read_process.start() + match_process.start() + evaluate_process.start() + + read_process.join() + match_process.join() + evaluate_process.join() \ No newline at end of file diff --git a/imcui/third_party/SGMNet/sgmnet/__init__.py b/imcui/third_party/SGMNet/sgmnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..828543beceebb10d05fd9d5fdfcc4b1c91e5af6b --- /dev/null +++ b/imcui/third_party/SGMNet/sgmnet/__init__.py @@ -0,0 +1 @@ +from .match_model import matcher \ No newline at end of file diff --git a/imcui/third_party/SGMNet/sgmnet/match_model.py b/imcui/third_party/SGMNet/sgmnet/match_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8760815cba9e34749b748cdb485bdc73b1cc9edb --- /dev/null +++ b/imcui/third_party/SGMNet/sgmnet/match_model.py @@ -0,0 +1,223 @@ +import torch +import torch.nn as nn +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +eps=1e-8 + +def sinkhorn(M,r,c,iteration): + p = torch.softmax(M, dim=-1) + u = torch.ones_like(r) + v = torch.ones_like(c) + for _ in range(iteration): + u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps) + v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps) + p = p * u.unsqueeze(-1) * v.unsqueeze(-2) + return p + +def sink_algorithm(M,dustbin,iteration): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + r = torch.ones([M.shape[0], M.shape[1] - 1],device=device) + r = torch.cat([r, torch.ones([M.shape[0], 1],device=device) * M.shape[1]], dim=-1) + c = torch.ones([M.shape[0], M.shape[2] - 1],device=device) + c = torch.cat([c, torch.ones([M.shape[0], 1],device=device) * M.shape[2]], dim=-1) + p=sinkhorn(M,r,c,iteration) + return p + + +def seeding(nn_index1,nn_index2,x1,x2,topk,match_score,confbar,nms_radius,use_mc=True,test=False): + + #apply mutual check before nms + if use_mc: + mask_not_mutual=nn_index2.gather(dim=-1,index=nn_index1)!=torch.arange(nn_index1.shape[1],device=device) + match_score[mask_not_mutual]=-1 + #NMS + pos_dismat1=((x1.norm(p=2,dim=-1)**2).unsqueeze_(-1)+(x1.norm(p=2,dim=-1)**2).unsqueeze_(-2)-2*(x1@x1.transpose(1,2))).abs_().sqrt_() + x2=x2.gather(index=nn_index1.unsqueeze(-1).expand(-1,-1,2),dim=1) + pos_dismat2=((x2.norm(p=2,dim=-1)**2).unsqueeze_(-1)+(x2.norm(p=2,dim=-1)**2).unsqueeze_(-2)-2*(x2@x2.transpose(1,2))).abs_().sqrt_() + radius1, radius2 = nms_radius * pos_dismat1.mean(dim=(1,2),keepdim=True), nms_radius * pos_dismat2.mean(dim=(1,2),keepdim=True) + nms_mask = (pos_dismat1 >= radius1) & (pos_dismat2 >= radius2) + 
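sinkhorn() above alternately rescales rows and columns of an initial softmax so the transport plan converges to the target marginals r and c, and sink_algorithm() first appends a learned dustbin row and column so that unmatched keypoints have somewhere to send their mass. A small CPU-only check of that behaviour, reusing the same update rule; the sizes, iteration count and marginals here are illustrative (the repo's sink_algorithm builds r and c slightly differently, but the mechanics are identical):

```python
import torch

def sinkhorn(M, r, c, iteration=100, eps=1e-8):
    p = torch.softmax(M, dim=-1)
    u, v = torch.ones_like(r), torch.ones_like(c)
    for _ in range(iteration):
        u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps)   # row scaling
        v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps)   # column scaling
    return p * u.unsqueeze(-1) * v.unsqueeze(-2)

b, n, m = 1, 5, 7
M = torch.randn(b, n + 1, m + 1)                        # last row/column act as the dustbin
r = torch.cat([torch.ones(b, n), torch.full((b, 1), float(m))], dim=-1)
c = torch.cat([torch.ones(b, m), torch.full((b, 1), float(n))], dim=-1)
p = sinkhorn(M, r, c)
print(p.sum(-1))   # ~[1, 1, 1, 1, 1, 7]: unit mass per real row, the rest in the dustbin
print(p.sum(-2))   # ~[1, ..., 1, 5]
```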
mask_not_local_max=(match_score.unsqueeze(-1)>=match_score.unsqueeze(-2))|nms_mask + mask_not_local_max=~(mask_not_local_max.min(dim=-1).values) + match_score[mask_not_local_max] = -1 + + #confidence bar + match_score[match_score<confbar]=-1 + mask_survive=match_score>0 + if test: + topk=min(mask_survive.sum(dim=1)[0]+2,topk) + _,topindex = torch.topk(match_score,topk,dim=-1)#b*k + seed_index1,seed_index2=topindex,nn_index1.gather(index=topindex,dim=-1) + return seed_index1,seed_index2 + + + +class PointCN(nn.Module): + def __init__(self, channels,out_channels): + nn.Module.__init__(self) + self.shot_cut = nn.Conv1d(channels, out_channels, kernel_size=1) + self.conv = nn.Sequential( + nn.InstanceNorm1d(channels, eps=1e-3), + nn.SyncBatchNorm(channels), + nn.ReLU(), + nn.Conv1d(channels, channels, kernel_size=1), + nn.InstanceNorm1d(channels, eps=1e-3), + nn.SyncBatchNorm(channels), + nn.ReLU(), + nn.Conv1d(channels, out_channels, kernel_size=1) + ) + + def forward(self, x): + return self.conv(x) + self.shot_cut(x) + + +class attention_propagantion(nn.Module): + + def __init__(self,channel,head): + nn.Module.__init__(self) + self.head=head + self.head_dim=channel//head + self.query_filter,self.key_filter,self.value_filter=nn.Conv1d(channel,channel,kernel_size=1),nn.Conv1d(channel,channel,kernel_size=1),\ + nn.Conv1d(channel,channel,kernel_size=1) + self.mh_filter=nn.Conv1d(channel,channel,kernel_size=1) + self.cat_filter=nn.Sequential(nn.Conv1d(2*channel,2*channel, kernel_size=1), nn.SyncBatchNorm(2*channel), nn.ReLU(), + nn.Conv1d(2*channel, channel, kernel_size=1)) + + def forward(self,desc1,desc2,weight_v=None): + #desc1(q) attend to desc2(k,v) + batch_size=desc1.shape[0] + query,key,value=self.query_filter(desc1).view(batch_size,self.head,self.head_dim,-1),self.key_filter(desc2).view(batch_size,self.head,self.head_dim,-1),\ + self.value_filter(desc2).view(batch_size,self.head,self.head_dim,-1) + if weight_v is not None: + value=value*weight_v.view(batch_size,1,1,-1) + score=torch.softmax(torch.einsum('bhdn,bhdm->bhnm',query,key)/ self.head_dim ** 0.5,dim=-1) + add_value=torch.einsum('bhnm,bhdm->bhdn',score,value).reshape(batch_size,self.head_dim*self.head,-1) + add_value=self.mh_filter(add_value) + desc1_new=desc1+self.cat_filter(torch.cat([desc1,add_value],dim=1)) + return desc1_new + + +class hybrid_block(nn.Module): + def __init__(self,channel,head): + nn.Module.__init__(self) + self.head=head + self.channel=channel + self.attention_block_down = attention_propagantion(channel, head) + self.cluster_filter=nn.Sequential(nn.Conv1d(2*channel,2*channel, kernel_size=1), nn.SyncBatchNorm(2*channel), nn.ReLU(), + nn.Conv1d(2*channel, 2*channel, kernel_size=1)) + self.cross_filter=attention_propagantion(channel,head) + self.confidence_filter=PointCN(2*channel,1) + self.attention_block_self=attention_propagantion(channel,head) + self.attention_block_up=attention_propagantion(channel,head) + + def forward(self,desc1,desc2,seed_index1,seed_index2): + cluster1, cluster2 = desc1.gather(dim=-1, index=seed_index1.unsqueeze(1).expand(-1, self.channel, -1)), \ + desc2.gather(dim=-1, index=seed_index2.unsqueeze(1).expand(-1, self.channel, -1)) + + #pooling + cluster1, cluster2 = self.attention_block_down(cluster1, desc1), self.attention_block_down(cluster2, desc2) + concate_cluster=self.cluster_filter(torch.cat([cluster1,cluster2],dim=1)) + #filtering + cluster1,cluster2=self.cross_filter(concate_cluster[:,:self.channel],concate_cluster[:,self.channel:]),\ + self.cross_filter(concate_cluster[:,self.channel:],concate_cluster[:,:self.channel])
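attention_propagantion above is standard multi-head attention written with 1x1 convolutions over (batch, channel, num_points) feature maps: project to queries/keys/values, split the channel dimension into heads, softmax the scaled dot products, and aggregate the values. The einsum bookkeeping in isolation, without the projections, residual and concat filter (a minimal sketch; shapes follow the code above):

```python
import torch

def mh_attention(q_feat, kv_feat, head=4):
    """q_feat: (b, c, n), kv_feat: (b, c, m) -> aggregated values of shape (b, c, n)."""
    b, c, n = q_feat.shape
    m = kv_feat.shape[-1]
    d = c // head
    q = q_feat.view(b, head, d, n)     # the 1x1 conv projections are omitted here
    k = kv_feat.view(b, head, d, m)
    v = kv_feat.view(b, head, d, m)
    score = torch.softmax(torch.einsum('bhdn,bhdm->bhnm', q, k) / d ** 0.5, dim=-1)
    out = torch.einsum('bhnm,bhdm->bhdn', score, v)
    return out.reshape(b, c, n)

desc1, desc2 = torch.rand(2, 128, 1000), torch.rand(2, 128, 800)
print(mh_attention(desc1, desc2).shape)   # torch.Size([2, 128, 1000])
```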
+ cluster1,cluster2=self.attention_block_self(cluster1,cluster1),self.attention_block_self(cluster2,cluster2) + #unpooling + seed_weight=self.confidence_filter(torch.cat([cluster1,cluster2],dim=1)) + seed_weight=torch.sigmoid(seed_weight).squeeze(1) + desc1_new,desc2_new=self.attention_block_up(desc1,cluster1,seed_weight),self.attention_block_up(desc2,cluster2,seed_weight) + return desc1_new,desc2_new,seed_weight + + + +class matcher(nn.Module): + def __init__(self,config): + nn.Module.__init__(self) + self.seed_top_k=config.seed_top_k + self.conf_bar=config.conf_bar + self.seed_radius_coe=config.seed_radius_coe + self.use_score_encoding=config.use_score_encoding + self.detach_iter=config.detach_iter + self.seedlayer=config.seedlayer + self.layer_num=config.layer_num + self.sink_iter=config.sink_iter + + self.position_encoder = nn.Sequential(nn.Conv1d(3, 32, kernel_size=1) if config.use_score_encoding else nn.Conv1d(2, 32, kernel_size=1), + nn.SyncBatchNorm(32),nn.ReLU(), + nn.Conv1d(32, 64, kernel_size=1), nn.SyncBatchNorm(64),nn.ReLU(), + nn.Conv1d(64, 128, kernel_size=1), nn.SyncBatchNorm(128),nn.ReLU(), + nn.Conv1d(128, 256, kernel_size=1), nn.SyncBatchNorm(256),nn.ReLU(), + nn.Conv1d(256, config.net_channels, kernel_size=1)) + + + self.hybrid_block=nn.Sequential(*[hybrid_block(config.net_channels, config.head) for _ in range(config.layer_num)]) + self.final_project = nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1) + self.dustbin=nn.Parameter(torch.tensor(1.5,dtype=torch.float32)) + + #if reseeding + if len(config.seedlayer)!=1: + self.mid_dustbin=nn.ParameterDict({str(i):nn.Parameter(torch.tensor(2,dtype=torch.float32)) for i in config.seedlayer[1:]}) + self.mid_final_project = nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1) + + def forward(self,data,test_mode=True): + x1, x2, desc1, desc2 = data['x1'][:,:,:2], data['x2'][:,:,:2], data['desc1'], data['desc2'] + desc1, desc2 = torch.nn.functional.normalize(desc1,dim=-1), torch.nn.functional.normalize(desc2,dim=-1) + if test_mode: + encode_x1,encode_x2=data['x1'],data['x2'] + else: + encode_x1,encode_x2=data['aug_x1'], data['aug_x2'] + + #preparation + desc_dismat=(2-2*torch.matmul(desc1,desc2.transpose(1,2))).sqrt_() + values,nn_index=torch.topk(desc_dismat,k=2,largest=False,dim=-1,sorted=True) + nn_index2=torch.min(desc_dismat,dim=1).indices.squeeze(1) + inverse_ratio_score,nn_index1=values[:,:,1]/values[:,:,0],nn_index[:,:,0]#get inverse score + + #initial seeding + seed_index1,seed_index2=seeding(nn_index1,nn_index2,x1,x2,self.seed_top_k[0],inverse_ratio_score,self.conf_bar[0],\ + self.seed_radius_coe,test=test_mode) + + #position encoding + desc1,desc2=desc1.transpose(1,2),desc2.transpose(1,2) + if not self.use_score_encoding: + encode_x1,encode_x2=encode_x1[:,:,:2],encode_x2[:,:,:2] + encode_x1,encode_x2=encode_x1.transpose(1,2),encode_x2.transpose(1,2) + x1_pos_embedding, x2_pos_embedding = self.position_encoder(encode_x1), self.position_encoder(encode_x2) + aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding + desc2 + + seed_weight_tower,mid_p_tower,seed_index_tower,nn_index_tower=[],[],[],[] + seed_index_tower.append(torch.stack([seed_index1, seed_index2],dim=-1)) + nn_index_tower.append(nn_index1) + + seed_para_index=0 + for i in range(self.layer_num): + #mid seeding + if i in self.seedlayer and i!= 0: + seed_para_index+=1 + aug_desc1,aug_desc2=self.mid_final_project(aug_desc1),self.mid_final_project(aug_desc2) + M=torch.matmul(aug_desc1.transpose(1,2),aug_desc2) + 
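The forward pass in this hunk leans on the identity ||d1 - d2||^2 = 2 - 2 d1·d2 for L2-normalised descriptors: that is why desc_dismat is computed as sqrt(2 - 2 d1 d2^T), and why the inverse ratio score is simply the second-smallest distance over the smallest per row. A quick numerical check of the identity (sizes arbitrary):

```python
import torch

d1 = torch.nn.functional.normalize(torch.rand(1, 500, 128), dim=-1)
d2 = torch.nn.functional.normalize(torch.rand(1, 600, 128), dim=-1)

dist_dot = (2 - 2 * torch.matmul(d1, d2.transpose(1, 2))).clamp(min=0).sqrt()
dist_ref = torch.cdist(d1, d2)                           # explicit pairwise L2 distances
print(torch.allclose(dist_dot, dist_ref, atol=1e-4))     # True up to float error

values, nn_index = torch.topk(dist_dot, k=2, largest=False, dim=-1, sorted=True)
inverse_ratio = values[:, :, 1] / values[:, :, 0]        # large value => distinctive match
print(inverse_ratio.shape)                               # torch.Size([1, 500])
```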
p=sink_algorithm(M,self.mid_dustbin[str(i)],self.sink_iter[seed_para_index-1]) + mid_p_tower.append(p) + #rematching with p + values,nn_index=torch.topk(p[:,:-1,:-1],k=1,dim=-1) + nn_index2=torch.max(p[:,:-1,:-1],dim=1).indices.squeeze(1) + p_match_score,nn_index1=values[:,:,0],nn_index[:,:,0] + #reseeding + seed_index1, seed_index2 = seeding(nn_index1,nn_index2,x1,x2,self.seed_top_k[seed_para_index],p_match_score,\ + self.conf_bar[seed_para_index],self.seed_radius_coe,test=test_mode) + seed_index_tower.append(torch.stack([seed_index1, seed_index2],dim=-1)), nn_index_tower.append(nn_index1) + if not test_mode and data['step']bhnm',query1,key1)/self.head_dim**0.5,dim=-1),\ + torch.softmax(torch.einsum('bdhn,bdhm->bhnm',query2,key2)/self.head_dim**0.5,dim=-1) + add_value1, add_value2 = torch.einsum('bhnm,bdhm->bdhn', score1, value1), torch.einsum('bhnm,bdhm->bdhn',score2, value2) + else: + score1,score2 = torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query1, key2) / self.head_dim ** 0.5,dim=-1), \ + torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query2, key1) / self.head_dim ** 0.5, dim=-1) + add_value1, add_value2 =torch.einsum('bhnm,bdhm->bdhn',score1,value2),torch.einsum('bhnm,bdhm->bdhn',score2,value1) + add_value1,add_value2=self.mh_filter(add_value1.contiguous().view(batch_size,self.head*self.head_dim,n)),self.mh_filter(add_value2.contiguous().view(batch_size,self.head*self.head_dim,m)) + fea11, fea22 = torch.cat([fea1, add_value1], dim=1), torch.cat([fea2, add_value2], dim=1) + fea1, fea2 = fea1+self.attention_filter(fea11), fea2+self.attention_filter(fea22) + + return fea1,fea2 + + +class matcher(nn.Module): + def __init__(self, config): + nn.Module.__init__(self) + self.use_score_encoding=config.use_score_encoding + self.layer_num=config.layer_num + self.sink_iter=config.sink_iter + self.position_encoder = nn.Sequential(nn.Conv1d(3, 32, kernel_size=1) if config.use_score_encoding else nn.Conv1d(2, 32, kernel_size=1), + nn.SyncBatchNorm(32), nn.ReLU(), + nn.Conv1d(32, 64, kernel_size=1), nn.SyncBatchNorm(64),nn.ReLU(), + nn.Conv1d(64, 128, kernel_size=1), nn.SyncBatchNorm(128), nn.ReLU(), + nn.Conv1d(128, 256, kernel_size=1), nn.SyncBatchNorm(256), nn.ReLU(), + nn.Conv1d(256, config.net_channels, kernel_size=1)) + + self.dustbin=nn.Parameter(torch.tensor(1,dtype=torch.float32,device='cuda')) + self.self_attention_block=nn.Sequential(*[attention_block(config.net_channels,config.head,'self') for _ in range(config.layer_num)]) + self.cross_attention_block=nn.Sequential(*[attention_block(config.net_channels,config.head,'cross') for _ in range(config.layer_num)]) + self.final_project=nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1) + + def forward(self,data,test_mode=True): + desc1, desc2 = data['desc1'], data['desc2'] + desc1, desc2 = torch.nn.functional.normalize(desc1,dim=-1), torch.nn.functional.normalize(desc2,dim=-1) + desc1,desc2=desc1.transpose(1,2),desc2.transpose(1,2) + if test_mode: + encode_x1,encode_x2=data['x1'],data['x2'] + else: + encode_x1,encode_x2=data['aug_x1'], data['aug_x2'] + if not self.use_score_encoding: + encode_x1,encode_x2=encode_x1[:,:,:2],encode_x2[:,:,:2] + + encode_x1,encode_x2=encode_x1.transpose(1,2),encode_x2.transpose(1,2) + + x1_pos_embedding, x2_pos_embedding = self.position_encoder(encode_x1), self.position_encoder(encode_x2) + aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding+desc2 + for i in range(self.layer_num): + aug_desc1,aug_desc2=self.self_attention_block[i](aug_desc1,aug_desc2) + 
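Both matchers in this patch inject keypoint geometry the same way: normalised coordinates (plus the detection score as a third channel when use_score_encoding is on) pass through a small Conv1d MLP and the output is added to the descriptors before any attention layer. A minimal sketch of that positional encoder; the channel sizes follow the code above, with BatchNorm1d standing in for SyncBatchNorm so it runs outside a distributed process group:

```python
import torch
import torch.nn as nn

def make_position_encoder(in_ch=2, out_ch=128):
    return nn.Sequential(
        nn.Conv1d(in_ch, 32, kernel_size=1), nn.BatchNorm1d(32), nn.ReLU(),
        nn.Conv1d(32, 64, kernel_size=1), nn.BatchNorm1d(64), nn.ReLU(),
        nn.Conv1d(64, 128, kernel_size=1), nn.BatchNorm1d(128), nn.ReLU(),
        nn.Conv1d(128, 256, kernel_size=1), nn.BatchNorm1d(256), nn.ReLU(),
        nn.Conv1d(256, out_ch, kernel_size=1))

encoder = make_position_encoder()
kpts = torch.rand(4, 1000, 2) - 0.5           # normalised keypoint coordinates
desc = torch.rand(4, 128, 1000)               # descriptors laid out as (b, c, n)
aug_desc = desc + encoder(kpts.transpose(1, 2))
print(aug_desc.shape)                         # torch.Size([4, 128, 1000])
```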
aug_desc1,aug_desc2=self.cross_attention_block[i](aug_desc1,aug_desc2) + + aug_desc1,aug_desc2=self.final_project(aug_desc1),self.final_project(aug_desc2) + desc_mat = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2) + p = sink_algorithm(desc_mat, self.dustbin,self.sink_iter[0]) + return {'p':p} + + diff --git a/imcui/third_party/SGMNet/superpoint/__init__.py b/imcui/third_party/SGMNet/superpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..111c8882a7bc7512c6191ca86a0e71c3b1404233 --- /dev/null +++ b/imcui/third_party/SGMNet/superpoint/__init__.py @@ -0,0 +1 @@ +from .superpoint import SuperPoint \ No newline at end of file diff --git a/imcui/third_party/SGMNet/superpoint/superpoint.py b/imcui/third_party/SGMNet/superpoint/superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d4e3ce481409264a3188270ad01aa62b1614377f --- /dev/null +++ b/imcui/third_party/SGMNet/superpoint/superpoint.py @@ -0,0 +1,140 @@ +import torch +from torch import nn + + +def simple_nms(scores, nms_radius): + assert(nms_radius >= 0) + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def remove_borders(keypoints, scores, b, h, w): + mask_h = (keypoints[:, 0] >= b) & (keypoints[:, 0] < (h - b)) + mask_w = (keypoints[:, 1] >= b) & (keypoints[:, 1] < (w - b)) + mask = mask_h & mask_w + return keypoints[mask], scores[mask] + + +def top_k_keypoints(keypoints, scores, k): + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0) + return keypoints[indices], scores + + +def sample_descriptors(keypoints, descriptors, s): + b, c, h, w = descriptors.shape + keypoints = keypoints - s / 2 + 0.5 + keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], + ).to(keypoints)[None] + keypoints = keypoints*2 - 1 # normalize to (-1, 1) + args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + return descriptors + + +class SuperPoint(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = {**config} + + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 + + self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) + self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) + + self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) + + self.convDa = nn.Conv2d(c4, 
c5, kernel_size=3, stride=1, padding=1) + self.convDb = nn.Conv2d( + c5, self.config['descriptor_dim'], + kernel_size=1, stride=1, padding=0) + + self.load_state_dict(torch.load(config['model_path'])) + + mk = self.config['max_keypoints'] + if mk == 0 or mk < -1: + raise ValueError('\"max_keypoints\" must be positive or \"-1\"') + + print('Loaded SuperPoint model') + + def forward(self, data): + # Shared Encoder + x = self.relu(self.conv1a(data)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + b, c, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) + scores = simple_nms(scores, self.config['nms_radius']) + + # Extract keypoints + keypoints = [ + torch.nonzero(s > self.config['detection_threshold']) + for s in scores] + scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] + + # Discard keypoints near the image borders + keypoints, scores = list(zip(*[ + remove_borders(k, s, self.config['remove_borders'], h*8, w*8) + for k, s in zip(keypoints, scores)])) + + # Keep the k keypoints with highest score + if self.config['max_keypoints'] >= 0: + keypoints, scores = list(zip(*[ + top_k_keypoints(k, s, self.config['max_keypoints']) + for k, s in zip(keypoints, scores)])) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) + + # Extract descriptors + descriptors = [sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, descriptors)] + + return { + 'keypoints': keypoints, + 'scores': scores, + 'descriptors': descriptors, + } diff --git a/imcui/third_party/SGMNet/train/config.py b/imcui/third_party/SGMNet/train/config.py new file mode 100644 index 0000000000000000000000000000000000000000..31c4c1c6deef3d6dd568897f4202d96456586376 --- /dev/null +++ b/imcui/third_party/SGMNet/train/config.py @@ -0,0 +1,126 @@ +import argparse + +def str2bool(v): + return v.lower() in ("true", "1") + + +arg_lists = [] +parser = argparse.ArgumentParser() + + +def add_argument_group(name): + arg = parser.add_argument_group(name) + arg_lists.append(arg) + return arg + + +# ----------------------------------------------------------------------------- +# Network +net_arg = add_argument_group("Network") +net_arg.add_argument( + "--model_name", type=str,default='SGM', help="" + "model for training") +net_arg.add_argument( + "--config_path", type=str,default='configs/sgm.yaml', help="" + "config path for model") + +# ----------------------------------------------------------------------------- +# Data +data_arg = add_argument_group("Data") +data_arg.add_argument( + "--rawdata_path", type=str, default='rawdata', help="" + "path for rawdata") +data_arg.add_argument( + "--dataset_path", type=str, default='dataset', help="" + "path for dataset") +data_arg.add_argument( + "--desc_path", type=str, default='desc', help="" + "path for descriptor(kpt) dir") +data_arg.add_argument( + "--num_kpt", type=int, default=1000, 
help="" + "number of kpt for training") +data_arg.add_argument( + "--input_normalize", type=str, default='img', help="" + "normalize type for input kpt, img or intrinsic") +data_arg.add_argument( + "--data_aug", type=str2bool, default=True, help="" + "apply kpt coordinate homography augmentation") +data_arg.add_argument( + "--desc_suffix", type=str, default='suffix', help="" + "desc file suffix") + + +# ----------------------------------------------------------------------------- +# Loss +loss_arg = add_argument_group("loss") +loss_arg.add_argument( + "--momentum", type=float, default=0.9, help="" + "momentum") +loss_arg.add_argument( + "--seed_loss_weight", type=float, default=250, help="" + "confidence loss weight for sgm") +loss_arg.add_argument( + "--mid_loss_weight", type=float, default=1, help="" + "midseeding loss weight for sgm") +loss_arg.add_argument( + "--inlier_th", type=float, default=5e-3, help="" + "inlier threshold for epipolar distance (for sgm and visualization)") + + +# ----------------------------------------------------------------------------- +# Training +train_arg = add_argument_group("Train") +train_arg.add_argument( + "--train_lr", type=float, default=1e-4, help="" + "learning rate") +train_arg.add_argument( + "--train_batch_size", type=int, default=16, help="" + "batch size") +train_arg.add_argument( + "--gpu_id", type=str,default='0', help='id(s) for CUDA_VISIBLE_DEVICES') +train_arg.add_argument( + "--train_iter", type=int, default=1000000, help="" + "training iterations to perform") +train_arg.add_argument( + "--log_base", type=str, default="./log/", help="" + "log path") +train_arg.add_argument( + "--val_intv", type=int, default=20000, help="" + "validation interval") +train_arg.add_argument( + "--save_intv", type=int, default=1000, help="" + "summary interval") +train_arg.add_argument( + "--log_intv", type=int, default=100, help="" + "log interval") +train_arg.add_argument( + "--decay_rate", type=float, default=0.999996, help="" + "lr decay rate") +train_arg.add_argument( + "--decay_iter", type=float, default=300000, help="" + "lr decay iter") +train_arg.add_argument( + "--local_rank", type=int, default=0, help="" + "local rank for ddp") +train_arg.add_argument( + "--train_vis_folder", type=str, default='.', help="" + "visualization folder during training") + +# ----------------------------------------------------------------------------- +# Visualization +vis_arg = add_argument_group('Visualization') +vis_arg.add_argument( + "--tqdm_width", type=int, default=79, help="" + "width of the tqdm bar" +) + +def get_config(): + config, unparsed = parser.parse_known_args() + return config, unparsed + + +def print_usage(): + parser.print_usage() + +# +# config.py ends here \ No newline at end of file diff --git a/imcui/third_party/SGMNet/train/configs/sg.yaml b/imcui/third_party/SGMNet/train/configs/sg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bb03f39f9d8445b1e345d8f8f6ac17eb6d981bc1 --- /dev/null +++ b/imcui/third_party/SGMNet/train/configs/sg.yaml @@ -0,0 +1,5 @@ +net_channels: 128 +layer_num: 9 +head: 4 +use_score_encoding: True +p_th: 0.2 \ No newline at end of file diff --git a/imcui/third_party/SGMNet/train/configs/sgm.yaml b/imcui/third_party/SGMNet/train/configs/sgm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d674adf562a8932192a0a3bb1a993cf90d28e989 --- /dev/null +++ b/imcui/third_party/SGMNet/train/configs/sgm.yaml @@ -0,0 +1,12 @@ +seed_top_k: [128,128] +seed_radius_coe: 0.01 +net_channels: 128 
+layer_num: 9 +head: 4 +seedlayer: [0,6] +use_mc_seeding: True +use_score_encoding: False +conf_bar: [1,0.1] +sink_iter: [10,100] +detach_iter: 140000 +p_th: 0.2 \ No newline at end of file diff --git a/imcui/third_party/SGMNet/train/dataset.py b/imcui/third_party/SGMNet/train/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d07a84e9588b755a86119363f08860187d1668c0 --- /dev/null +++ b/imcui/third_party/SGMNet/train/dataset.py @@ -0,0 +1,143 @@ +import numpy as np +import torch +import torch.utils.data as data +import cv2 +import os +import h5py +import random + +import sys +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) +sys.path.insert(0, ROOT_DIR) + +from utils import train_utils,evaluation_utils + +torch.multiprocessing.set_sharing_strategy('file_system') + + +class Offline_Dataset(data.Dataset): + def __init__(self,config,mode): + assert mode=='train' or mode=='valid' + + self.config = config + self.mode = mode + metadir=os.path.join(config.dataset_path,'valid') if mode=='valid' else os.path.join(config.dataset_path,'train') + + pair_num_list=np.loadtxt(os.path.join(metadir,'pair_num.txt'),dtype=str) + self.total_pairs=int(pair_num_list[0,1]) + self.pair_seq_list,self.accu_pair_num=train_utils.parse_pair_seq(pair_num_list) + + + def collate_fn(self, batch): + batch_size, num_pts = len(batch), batch[0]['x1'].shape[0] + + data = {} + dtype=['x1','x2','kpt1','kpt2','desc1','desc2','num_corr','num_incorr1','num_incorr2','e_gt','pscore1','pscore2','img_path1','img_path2'] + for key in dtype: + data[key]=[] + for sample in batch: + for key in dtype: + data[key].append(sample[key]) + + for key in ['x1', 'x2','kpt1','kpt2', 'desc1', 'desc2','e_gt','pscore1','pscore2']: + data[key] = torch.from_numpy(np.stack(data[key])).float() + for key in ['num_corr', 'num_incorr1', 'num_incorr2']: + data[key] = torch.from_numpy(np.stack(data[key])).int() + + # kpt augmentation with random homography + if (self.mode == 'train' and self.config.data_aug): + homo_mat = torch.from_numpy(train_utils.get_rnd_homography(batch_size)).unsqueeze(1) + aug_seed=random.random() + if aug_seed<0.5: + x1_homo = torch.cat([data['x1'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1) + x1_homo = torch.matmul(homo_mat.float(), x1_homo.float()).squeeze(-1) + data['aug_x1'] = x1_homo[:, :, :2] / x1_homo[:, :, 2].unsqueeze(-1) + data['aug_x2']=data['x2'] + else: + x2_homo = torch.cat([data['x2'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1) + x2_homo = torch.matmul(homo_mat.float(), x2_homo.float()).squeeze(-1) + data['aug_x2'] = x2_homo[:, :, :2] / x2_homo[:, :, 2].unsqueeze(-1) + data['aug_x1']=data['x1'] + else: + data['aug_x1'],data['aug_x2']=data['x1'],data['x2'] + return data + + + def __getitem__(self, index): + seq=self.pair_seq_list[index] + index_within_seq=index-self.accu_pair_num[seq] + + with h5py.File(os.path.join(self.config.dataset_path,seq,'info.h5py'),'r') as data: + R,t = data['dR'][str(index_within_seq)][()], data['dt'][str(index_within_seq)][()] + egt = np.reshape(np.matmul(np.reshape(evaluation_utils.np_skew_symmetric(t.astype('float64').reshape(1, 3)), (3, 3)),np.reshape(R.astype('float64'), (3, 3))), (3, 3)) + egt = egt / np.linalg.norm(egt) + K1, K2 = data['K1'][str(index_within_seq)][()],data['K2'][str(index_within_seq)][()] + size1,size2=data['size1'][str(index_within_seq)][()],data['size2'][str(index_within_seq)][()] + + 
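The collate_fn above augments one side of each training pair by pushing its keypoints through a random homography in homogeneous coordinates (the matrices come from train_utils.get_rnd_homography). The warp itself is just a matmul followed by dehomogenisation, sketched here with a hand-written example homography rather than the repo's sampler:

```python
import torch

def warp_keypoints(x, homo):
    """x: (b, n, 2) keypoints, homo: (b, 3, 3) homographies -> warped (b, n, 2)."""
    b, n, _ = x.shape
    x_h = torch.cat([x, torch.ones(b, n, 1)], dim=-1)     # lift to homogeneous coordinates
    x_w = torch.einsum('bij,bnj->bni', homo, x_h)         # apply the homography
    return x_w[:, :, :2] / x_w[:, :, 2:]                  # divide out the third coordinate

homo = torch.tensor([[[1.0, 0.1, 0.02],
                      [-0.05, 1.0, 0.01],
                      [0.001, 0.002, 1.0]]])              # made-up homography for illustration
kpts = torch.rand(1, 1000, 2) - 0.5
print(warp_keypoints(kpts, homo).shape)                   # torch.Size([1, 1000, 2])
```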
img_path1,img_path2=data['img_path1'][str(index_within_seq)][()][0].decode(),data['img_path2'][str(index_within_seq)][()][0].decode() + img_name1,img_name2=img_path1.split('/')[-1],img_path2.split('/')[-1] + img_path1,img_path2=os.path.join(self.config.rawdata_path,img_path1),os.path.join(self.config.rawdata_path,img_path2) + fea_path1,fea_path2=os.path.join(self.config.desc_path,seq,img_name1+self.config.desc_suffix),\ + os.path.join(self.config.desc_path,seq,img_name2+self.config.desc_suffix) + with h5py.File(fea_path1,'r') as fea1, h5py.File(fea_path2,'r') as fea2: + desc1,kpt1,pscore1=fea1['descriptors'][()],fea1['keypoints'][()][:,:2],fea1['keypoints'][()][:,2] + desc2,kpt2,pscore2=fea2['descriptors'][()],fea2['keypoints'][()][:,:2],fea2['keypoints'][()][:,2] + kpt1,kpt2,desc1,desc2=kpt1[:self.config.num_kpt],kpt2[:self.config.num_kpt],desc1[:self.config.num_kpt],desc2[:self.config.num_kpt] + + # normalize kpt + if self.config.input_normalize=='intrinsic': + x1, x2 = np.concatenate([kpt1, np.ones([kpt1.shape[0], 1])], axis=-1), np.concatenate( + [kpt2, np.ones([kpt2.shape[0], 1])], axis=-1) + x1, x2 = np.matmul(np.linalg.inv(K1), x1.T).T[:, :2], np.matmul(np.linalg.inv(K2), x2.T).T[:, :2] + elif self.config.input_normalize=='img' : + x1,x2=(kpt1-size1/2)/size1,(kpt2-size2/2)/size2 + S1_inv,S2_inv=np.asarray([[size1[0],0,0.5*size1[0]],[0,size1[1],0.5*size1[1]],[0,0,1]]),\ + np.asarray([[size2[0],0,0.5*size2[0]],[0,size2[1],0.5*size2[1]],[0,0,1]]) + M1,M2=np.matmul(np.linalg.inv(K1),S1_inv),np.matmul(np.linalg.inv(K2),S2_inv) + egt=np.matmul(np.matmul(M2.transpose(),egt),M1) + egt = egt / np.linalg.norm(egt) + else: + raise NotImplementedError + + corr=data['corr'][str(index_within_seq)][()] + incorr1,incorr2=data['incorr1'][str(index_within_seq)][()],data['incorr2'][str(index_within_seq)][()] + + #permute kpt + valid_corr=corr[corr.max(axis=-1)= cur_kpt1): + sub_idx1 =np.random.choice(len(invalid_index1), cur_kpt1,replace=False) + if (invalid_index2.shape[0] < cur_kpt2): + sub_idx2 = np.concatenate([np.arange(len(invalid_index2)),np.random.randint(len(invalid_index2),size=cur_kpt2-len(invalid_index2))]) + if (invalid_index2.shape[0] >= cur_kpt2): + sub_idx2 = np.random.choice(len(invalid_index2), cur_kpt2,replace=False) + + per_idx1,per_idx2=np.concatenate([valid_corr[:,0],valid_incorr1,invalid_index1[sub_idx1]]),\ + np.concatenate([valid_corr[:,1],valid_incorr2,invalid_index2[sub_idx2]]) + + pscore1,pscore2=pscore1[per_idx1][:,np.newaxis],pscore2[per_idx2][:,np.newaxis] + x1,x2=x1[per_idx1][:,:2],x2[per_idx2][:,:2] + desc1,desc2=desc1[per_idx1],desc2[per_idx2] + kpt1,kpt2=kpt1[per_idx1],kpt2[per_idx2] + + return {'x1': x1, 'x2': x2, 'kpt1':kpt1,'kpt2':kpt2,'desc1': desc1, 'desc2': desc2, 'num_corr': num_corr, 'num_incorr1': num_incorr1,'num_incorr2': num_incorr2,'e_gt':egt,\ + 'pscore1':pscore1,'pscore2':pscore2,'img_path1':img_path1,'img_path2':img_path2} + + def __len__(self): + return self.total_pairs + + diff --git a/imcui/third_party/SGMNet/train/loss.py b/imcui/third_party/SGMNet/train/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fad4234fc5827321c31e72c08ad4a3466bad1c30 --- /dev/null +++ b/imcui/third_party/SGMNet/train/loss.py @@ -0,0 +1,125 @@ +import torch +import numpy as np + + +def batch_episym(x1, x2, F): + batch_size, num_pts = x1.shape[0], x1.shape[1] + x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, num_pts,3,1) + x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, 
num_pts,3,1) + F = F.reshape(-1,1,3,3).repeat(1,num_pts,1,1) + x2Fx1 = torch.matmul(x2.transpose(2,3), torch.matmul(F, x1)).reshape(batch_size,num_pts) + Fx1 = torch.matmul(F,x1).reshape(batch_size,num_pts,3) + Ftx2 = torch.matmul(F.transpose(2,3),x2).reshape(batch_size,num_pts,3) + ys = (x2Fx1**2 * ( + 1.0 / (Fx1[:, :, 0]**2 + Fx1[:, :, 1]**2 + 1e-15) + + 1.0 / (Ftx2[:, :, 0]**2 + Ftx2[:, :, 1]**2 + 1e-15))).sqrt() + return ys + + +def CELoss(seed_x1,seed_x2,e,confidence,inlier_th,batch_mask=1): + #seed_x: b*k*2 + ys=batch_episym(seed_x1,seed_x2,e) + mask_pos,mask_neg=(ys<=inlier_th).float(),(ys>inlier_th).float() + num_pos,num_neg=torch.relu(torch.sum(mask_pos, dim=1) - 1.0) + 1.0,torch.relu(torch.sum(mask_neg, dim=1) - 1.0) + 1.0 + loss_pos,loss_neg=-torch.log(abs(confidence) + 1e-8)*mask_pos,-torch.log(abs(1-confidence)+1e-8)*mask_neg + classif_loss = torch.mean(loss_pos * 0.5 / num_pos.unsqueeze(-1) + loss_neg * 0.5 / num_neg.unsqueeze(-1),dim=-1) + classif_loss =classif_loss*batch_mask + classif_loss=classif_loss.mean() + precision = torch.mean( + torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) / + (torch.sum((confidence > 0.5).type(confidence.type()), dim=1)+1e-8) + ) + recall = torch.mean( + torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) / + num_pos + ) + return classif_loss,precision,recall + + +def CorrLoss(desc_mat,batch_num_corr,batch_num_incorr1,batch_num_incorr2): + total_loss_corr,total_loss_incorr=0,0 + total_acc_corr,total_acc_incorr=0,0 + batch_size = desc_mat.shape[0] + log_p=torch.log(abs(desc_mat)+1e-8) + + for i in range(batch_size): + cur_log_p=log_p[i] + num_corr=batch_num_corr[i] + num_incorr1,num_incorr2=batch_num_incorr1[i],batch_num_incorr2[i] + + #loss and acc + loss_corr = -torch.diag(cur_log_p)[:num_corr].mean() + loss_incorr=(-cur_log_p[num_corr:num_corr+num_incorr1,-1].mean()-cur_log_p[-1,num_corr:num_corr+num_incorr2].mean())/2 + + value_row, row_index = torch.max(desc_mat[i,:-1,:-1], dim=-1) + value_col, col_index = torch.max(desc_mat[i,:-1,:-1], dim=-2) + acc_incorr=((value_row[num_corr:num_corr+num_incorr1]<0.2).float().mean()+ + (value_col[num_corr:num_corr+num_incorr2]<0.2).float().mean())/2 + + acc_row_mask = row_index[:num_corr] == torch.arange(num_corr).cuda() + acc_col_mask = col_index[:num_corr] == torch.arange(num_corr).cuda() + acc = (acc_col_mask & acc_row_mask).float().mean() + + total_loss_corr+=loss_corr + total_loss_incorr+=loss_incorr + total_acc_corr += acc + total_acc_incorr+=acc_incorr + + total_acc_corr/=batch_size + total_acc_incorr/=batch_size + total_loss_corr/=batch_size + total_loss_incorr/=batch_size + return total_loss_corr,total_loss_incorr,total_acc_corr,total_acc_incorr + + +class SGMLoss: + def __init__(self,config,model_config): + self.config=config + self.model_config=model_config + + def run(self,data,result): + loss_corr,loss_incorr,acc_corr,acc_incorr=CorrLoss(result['p'],data['num_corr'],data['num_incorr1'],data['num_incorr2']) + loss_mid_corr_tower,loss_mid_incorr_tower,acc_mid_tower=[],[],[] + + #mid loss + for i in range(len(result['mid_p'])): + mid_p=result['mid_p'][i] + loss_mid_corr,loss_mid_incorr,mid_acc_corr,mid_acc_incorr=CorrLoss(mid_p,data['num_corr'],data['num_incorr1'],data['num_incorr2']) + loss_mid_corr_tower.append(loss_mid_corr),loss_mid_incorr_tower.append(loss_mid_incorr),acc_mid_tower.append(mid_acc_corr) + if len(result['mid_p']) != 0: + loss_mid_corr_tower,loss_mid_incorr_tower, acc_mid_tower = torch.stack(loss_mid_corr_tower), 
torch.stack(loss_mid_incorr_tower), torch.stack(acc_mid_tower) + else: + loss_mid_corr_tower,loss_mid_incorr_tower, acc_mid_tower= torch.zeros(1).cuda(), torch.zeros(1).cuda(),torch.zeros(1).cuda() + + #seed confidence loss + classif_loss_tower,classif_precision_tower,classif_recall_tower=[],[],[] + for layer in range(len(result['seed_conf'])): + confidence=result['seed_conf'][layer] + seed_index=result['seed_index'][(np.asarray(self.model_config.seedlayer)<=layer).nonzero()[0][-1]] + seed_x1,seed_x2=data['x1'].gather(dim=1, index=seed_index[:,:,0,None].expand(-1, -1,2)),\ + data['x2'].gather(dim=1, index=seed_index[:,:,1,None].expand(-1, -1,2)) + classif_loss,classif_precision,classif_recall=CELoss(seed_x1,seed_x2,data['e_gt'],confidence,self.config.inlier_th) + classif_loss_tower.append(classif_loss), classif_precision_tower.append(classif_precision), classif_recall_tower.append(classif_recall) + classif_loss, classif_precision_tower, classif_recall_tower=torch.stack(classif_loss_tower).mean(),torch.stack(classif_precision_tower), \ + torch.stack(classif_recall_tower) + + + classif_loss*=self.config.seed_loss_weight + loss_mid_corr_tower*=self.config.mid_loss_weight + loss_mid_incorr_tower*=self.config.mid_loss_weight + total_loss=loss_corr+loss_incorr+classif_loss+loss_mid_corr_tower.sum()+loss_mid_incorr_tower.sum() + + return {'loss_corr':loss_corr,'loss_incorr':loss_incorr,'acc_corr':acc_corr,'acc_incorr':acc_incorr,'loss_seed_conf':classif_loss, + 'pre_seed_conf':classif_precision_tower,'recall_seed_conf':classif_recall_tower,'loss_corr_mid':loss_mid_corr_tower, + 'loss_incorr_mid':loss_mid_incorr_tower,'mid_acc_corr':acc_mid_tower,'total_loss':total_loss} + +class SGLoss: + def __init__(self,config,model_config): + self.config=config + self.model_config=model_config + + def run(self,data,result): + loss_corr,loss_incorr,acc_corr,acc_incorr=CorrLoss(result['p'],data['num_corr'],data['num_incorr1'],data['num_incorr2']) + total_loss=loss_corr+loss_incorr + return {'loss_corr':loss_corr,'loss_incorr':loss_incorr,'acc_corr':acc_corr,'acc_incorr':acc_incorr,'total_loss':total_loss} + \ No newline at end of file diff --git a/imcui/third_party/SGMNet/train/main.py b/imcui/third_party/SGMNet/train/main.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4c8fff432a3b2d58c82b9e5f2897a4e702b2dd --- /dev/null +++ b/imcui/third_party/SGMNet/train/main.py @@ -0,0 +1,61 @@ +import torch.utils.data +from dataset import Offline_Dataset +import yaml +from sgmnet.match_model import matcher as SGM_Model +from superglue.match_model import matcher as SG_Model +import torch.distributed as dist +import torch +import os +from collections import namedtuple +from train import train +from config import get_config, print_usage + + +def main(config,model_config): + """The main function.""" + # Initialize network + if config.model_name=='SGM': + model = SGM_Model(model_config) + elif config.model_name=='SG': + model= SG_Model(model_config) + else: + raise NotImplementedError + + #initialize ddp + torch.cuda.set_device(config.local_rank) + device = torch.device(f'cuda:{config.local_rank}') + model.to(device) + dist.init_process_group(backend='nccl',init_method='env://') + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.local_rank]) + + if config.local_rank==0: + os.system('nvidia-smi') + + #initialize dataset + train_dataset = Offline_Dataset(config,'train') + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,shuffle=True) + 
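batch_episym above is the symmetric epipolar distance that CELoss thresholds (inlier_th defaults to 5e-3 in config.py) to label seeds as inliers or outliers, and that the visualisation paths reuse. The same quantity for a single pair in plain NumPy, convenient for sanity-checking a dumped essential matrix; names here are illustrative:

```python
import numpy as np

def sym_epipolar_dist(x1, x2, E):
    """x1, x2: (n, 2) correspondences in normalised camera coords, E: (3, 3)."""
    x1h = np.concatenate([x1, np.ones((len(x1), 1))], axis=-1)
    x2h = np.concatenate([x2, np.ones((len(x2), 1))], axis=-1)
    Ex1 = x1h @ E.T                         # epipolar lines in image 2
    Etx2 = x2h @ E                          # epipolar lines in image 1
    x2tEx1 = np.sum(x2h * Ex1, axis=-1)     # epipolar constraint residual
    d2 = x2tEx1**2 * (1.0 / (Ex1[:, 0]**2 + Ex1[:, 1]**2 + 1e-15)
                      + 1.0 / (Etx2[:, 0]**2 + Etx2[:, 1]**2 + 1e-15))
    return np.sqrt(d2)

# pure x-translation: identical normalised coordinates satisfy the constraint exactly
E = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])   # [t]_x with t=(1,0,0)
x1 = np.random.rand(10, 2)
print(sym_epipolar_dist(x1, x1.copy(), E).max())   # ~0
```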
train_loader=torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size//torch.distributed.get_world_size(), + num_workers=8//dist.get_world_size(), pin_memory=False,sampler=train_sampler,collate_fn=train_dataset.collate_fn) + + valid_dataset = Offline_Dataset(config,'valid') + valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,shuffle=False) + valid_loader=torch.utils.data.DataLoader(valid_dataset, batch_size=config.train_batch_size, + num_workers=8//dist.get_world_size(), pin_memory=False,collate_fn=valid_dataset.collate_fn,sampler=valid_sampler) + + if config.local_rank==0: + print('start training .....') + train(model,train_loader, valid_loader, config,model_config) + +if __name__ == "__main__": + # ---------------------------------------- + # Parse configuration + config, unparsed = get_config() + with open(config.config_path, 'r') as f: + model_config = yaml.load(f) + model_config=namedtuple('model_config',model_config.keys())(*model_config.values()) + # If we have unparsed arguments, print usage and exit + if len(unparsed) > 0: + print_usage() + exit(1) + + main(config,model_config) diff --git a/imcui/third_party/SGMNet/train/train.py b/imcui/third_party/SGMNet/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..31e848e1d2e5f028d4ff3abaf0cc446be7d89c65 --- /dev/null +++ b/imcui/third_party/SGMNet/train/train.py @@ -0,0 +1,160 @@ +import torch +import torch.optim as optim +from tqdm import trange +import os +from tensorboardX import SummaryWriter +import numpy as np +import cv2 +from loss import SGMLoss,SGLoss +from valid import valid,dump_train_vis + +import sys +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, ROOT_DIR) + + +from utils import train_utils + +def train_step(optimizer, model, match_loss, data,step,pre_avg_loss): + data['step']=step + result=model(data,test_mode=False) + loss_res=match_loss.run(data,result) + + optimizer.zero_grad() + loss_res['total_loss'].backward() + #apply reduce on all record tensor + for key in loss_res.keys(): + loss_res[key]=train_utils.reduce_tensor(loss_res[key],'mean') + + if loss_res['total_loss']<7*pre_avg_loss or step<200 or pre_avg_loss==0: + optimizer.step() + unusual_loss=False + else: + optimizer.zero_grad() + unusual_loss=True + return loss_res,unusual_loss + + +def train(model, train_loader, valid_loader, config,model_config): + model.train() + optimizer = optim.Adam(model.parameters(), lr=config.train_lr) + + if config.model_name=='SGM': + match_loss = SGMLoss(config,model_config) + elif config.model_name=='SG': + match_loss= SGLoss(config,model_config) + else: + raise NotImplementedError + + checkpoint_path = os.path.join(config.log_base, 'checkpoint.pth') + config.resume = os.path.isfile(checkpoint_path) + if config.resume: + if config.local_rank==0: + print('==> Resuming from checkpoint..') + checkpoint = torch.load(checkpoint_path,map_location='cuda:{}'.format(config.local_rank)) + model.load_state_dict(checkpoint['state_dict']) + best_acc = checkpoint['best_acc'] + start_step = checkpoint['step'] + optimizer.load_state_dict(checkpoint['optimizer']) + else: + best_acc = -1 + start_step = 0 + train_loader_iter = iter(train_loader) + + if config.local_rank==0: + writer=SummaryWriter(os.path.join(config.log_base,'log_file')) + + train_loader.sampler.set_epoch(start_step*config.train_batch_size//len(train_loader.dataset)) + pre_avg_loss=0 + + progress_bar=trange(start_step, 
config.train_iter,ncols=config.tqdm_width) if config.local_rank==0 else range(start_step, config.train_iter) + for step in progress_bar: + try: + train_data = next(train_loader_iter) + except StopIteration: + if config.local_rank==0: + print('epoch: ',step*config.train_batch_size//len(train_loader.dataset)) + train_loader.sampler.set_epoch(step*config.train_batch_size//len(train_loader.dataset)) + train_loader_iter = iter(train_loader) + train_data = next(train_loader_iter) + + train_data = train_utils.tocuda(train_data) + lr=min(config.train_lr*config.decay_rate**(step-config.decay_iter),config.train_lr) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + # run training + loss_res,unusual_loss = train_step(optimizer, model, match_loss, train_data,step-start_step,pre_avg_loss) + if (step-start_step)<=200: + pre_avg_loss=loss_res['total_loss'].data + if (step-start_step)>200 and not unusual_loss: + pre_avg_loss=pre_avg_loss.data*0.9+loss_res['total_loss'].data*0.1 + if unusual_loss and config.local_rank==0: + print('unusual loss! pre_avg_loss: ',pre_avg_loss,'cur_loss: ',loss_res['total_loss'].data) + #log + if config.local_rank==0 and step%config.log_intv==0 and not unusual_loss: + writer.add_scalar('TotalLoss',loss_res['total_loss'],step) + writer.add_scalar('CorrLoss',loss_res['loss_corr'],step) + writer.add_scalar('InCorrLoss', loss_res['loss_incorr'], step) + writer.add_scalar('dustbin', model.module.dustbin, step) + + if config.model_name=='SGM': + writer.add_scalar('SeedConfLoss', loss_res['loss_seed_conf'], step) + writer.add_scalar('MidCorrLoss', loss_res['loss_corr_mid'].sum(), step) + writer.add_scalar('MidInCorrLoss', loss_res['loss_incorr_mid'].sum(), step) + + + # valid ans save + b_save = ((step + 1) % config.save_intv) == 0 + b_validate = ((step + 1) % config.val_intv) == 0 + if b_validate: + total_loss,acc_corr,acc_incorr,seed_precision_tower,seed_recall_tower,acc_mid=valid(valid_loader, model, match_loss, config,model_config) + if config.local_rank==0: + writer.add_scalar('ValidAcc', acc_corr, step) + writer.add_scalar('ValidLoss', total_loss, step) + + if config.model_name=='SGM': + for i in range(len(seed_recall_tower)): + writer.add_scalar('seed_conf_pre_%d'%i,seed_precision_tower[i],step) + writer.add_scalar('seed_conf_recall_%d' % i, seed_precision_tower[i], step) + for i in range(len(acc_mid)): + writer.add_scalar('acc_mid%d'%i,acc_mid[i],step) + print('acc_corr: ',acc_corr.data,'acc_incorr: ',acc_incorr.data,'seed_conf_pre: ',seed_precision_tower.mean().data, + 'seed_conf_recall: ',seed_recall_tower.mean().data,'acc_mid: ',acc_mid.mean().data) + else: + print('acc_corr: ',acc_corr.data,'acc_incorr: ',acc_incorr.data) + + #saving best + if acc_corr > best_acc: + print("Saving best model with va_res = {}".format(acc_corr)) + best_acc = acc_corr + save_dict={'step': step + 1, + 'state_dict': model.state_dict(), + 'best_acc': best_acc, + 'optimizer' : optimizer.state_dict()} + save_dict.update(save_dict) + torch.save(save_dict, os.path.join(config.log_base, 'model_best.pth')) + + if b_save: + if config.local_rank==0: + save_dict={'step': step + 1, + 'state_dict': model.state_dict(), + 'best_acc': best_acc, + 'optimizer' : optimizer.state_dict()} + torch.save(save_dict, checkpoint_path) + + #draw match results + model.eval() + with torch.no_grad(): + if config.local_rank==0: + if not os.path.exists(os.path.join(config.train_vis_folder,'train_vis')): + os.mkdir(os.path.join(config.train_vis_folder,'train_vis')) + if not 
os.path.exists(os.path.join(config.train_vis_folder,'train_vis',config.log_base)): + os.mkdir(os.path.join(config.train_vis_folder,'train_vis',config.log_base)) + os.mkdir(os.path.join(config.train_vis_folder,'train_vis',config.log_base,str(step))) + res=model(train_data) + dump_train_vis(res,train_data,step,config) + model.train() + + if config.local_rank==0: + writer.close() diff --git a/imcui/third_party/SGMNet/train/valid.py b/imcui/third_party/SGMNet/train/valid.py new file mode 100644 index 0000000000000000000000000000000000000000..443694d85104730cd50aeb342326ce593dc5684d --- /dev/null +++ b/imcui/third_party/SGMNet/train/valid.py @@ -0,0 +1,77 @@ +import torch +import numpy as np +import cv2 +import os +from loss import batch_episym +from tqdm import tqdm + +import sys +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, ROOT_DIR) + +from utils import evaluation_utils,train_utils + + +def valid(valid_loader, model,match_loss, config,model_config): + model.eval() + loader_iter = iter(valid_loader) + num_pair = 0 + total_loss,total_acc_corr,total_acc_incorr=0,0,0 + total_precision,total_recall=torch.zeros(model_config.layer_num ,device='cuda'),\ + torch.zeros(model_config.layer_num ,device='cuda') + total_acc_mid=torch.zeros(len(model_config.seedlayer)-1,device='cuda') + + with torch.no_grad(): + if config.local_rank==0: + loader_iter=tqdm(loader_iter) + print('validating...') + for test_data in loader_iter: + num_pair+= 1 + test_data = train_utils.tocuda(test_data) + res= model(test_data) + loss_res=match_loss.run(test_data,res) + + total_acc_corr+=loss_res['acc_corr'] + total_acc_incorr+=loss_res['acc_incorr'] + total_loss+=loss_res['total_loss'] + + if config.model_name=='SGM': + total_acc_mid+=loss_res['mid_acc_corr'] + total_precision,total_recall=total_precision+loss_res['pre_seed_conf'],total_recall+loss_res['recall_seed_conf'] + + total_acc_corr/=num_pair + total_acc_incorr /= num_pair + total_precision/=num_pair + total_recall/=num_pair + total_acc_mid/=num_pair + + #apply tensor reduction + total_loss,total_acc_corr,total_acc_incorr,total_precision,total_recall,total_acc_mid=train_utils.reduce_tensor(total_loss,'sum'),\ + train_utils.reduce_tensor(total_acc_corr,'mean'),train_utils.reduce_tensor(total_acc_incorr,'mean'),\ + train_utils.reduce_tensor(total_precision,'mean'),train_utils.reduce_tensor(total_recall,'mean'),train_utils.reduce_tensor(total_acc_mid,'mean') + model.train() + return total_loss,total_acc_corr,total_acc_incorr,total_precision,total_recall,total_acc_mid + + + +def dump_train_vis(res,data,step,config): + #batch matching + p=res['p'][:,:-1,:-1] + score,index1=torch.max(p,dim=-1) + _,index2=torch.max(p,dim=-2) + mask_th=score>0.2 + mask_mc=index2.gather(index=index1,dim=1) == torch.arange(len(p[0])).cuda()[None] + mask_p=mask_th&mask_mc#B*N + + corr1,corr2=data['x1'],data['x2'].gather(index=index1[:,:,None].expand(-1,-1,2),dim=1) + corr1_kpt,corr2_kpt=data['kpt1'],data['kpt2'].gather(index=index1[:,:,None].expand(-1,-1,2),dim=1) + epi_dis=batch_episym(corr1,corr2,data['e_gt']) + mask_inlier=epi_dis0,i0,j 0, + depth_top_right > 0 + ), + np.logical_and( + depth_down_left > 0, + depth_down_left > 0 + ) + ) + ids=ids[valid_depth] + depth_top_left,depth_top_right,depth_down_left,depth_down_right=depth_top_left[valid_depth],depth_top_right[valid_depth],\ + depth_down_left[valid_depth],depth_down_right[valid_depth] + + i,j,i_top_left,j_top_left=i[valid_depth],j[valid_depth],i_top_left[valid_depth],j_top_left[valid_depth] 
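The lines that follow blend the four gathered corner depths bilinearly. For reference, a single-point sketch of the same blend; it skips the bounds and zero-depth filtering done above, and `depth`, `i`, `j` are hypothetical inputs:

```python
import numpy as np

def bilinear_depth(depth, i, j):
    """Bilinearly interpolate a depth map at one sub-pixel location (i, j).

    Simplified sketch: assumes (i, j) lies strictly inside the map and that
    all four neighbouring depth samples are valid (non-zero).
    """
    i0, j0 = int(np.floor(i)), int(np.floor(j))   # top-left integer corner
    i1, j1 = i0 + 1, j0 + 1                       # bottom-right integer corner
    di, dj = i - i0, j - j0                       # fractional offsets
    w_tl = (1 - di) * (1 - dj)
    w_tr = (1 - di) * dj
    w_bl = di * (1 - dj)
    w_br = di * dj
    return (w_tl * depth[i0, j0] + w_tr * depth[i0, j1]
            + w_bl * depth[i1, j0] + w_br * depth[i1, j1])
```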
+ + # Interpolation + dist_i_top_left = i - i_top_left.astype(np.float32) + dist_j_top_left = j - j_top_left.astype(np.float32) + w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left) + w_top_right = (1 - dist_i_top_left) * dist_j_top_left + w_bottom_left = dist_i_top_left * (1 - dist_j_top_left) + w_bottom_right = dist_i_top_left * dist_j_top_left + + interpolated_depth = ( + w_top_left * depth_top_left + + w_top_right * depth_top_right+ + w_bottom_left * depth_down_left + + w_bottom_right * depth_down_right + ) + return [interpolated_depth, ids] + + +def reprojection(depth_map,kpt,dR,dt,K1_img2depth,K1,K2): + #warp kpt from img1 to img2 + def swap_axis(data): + return np.stack([data[:, 1], data[:, 0]], axis=-1) + + kp_depth = unnorm_kp(K1_img2depth,kpt) + uv_depth = swap_axis(kp_depth) + z,valid_idx = interpolate_depth(uv_depth, depth_map) + + norm_kp=norm_kpt(K1,kpt) + norm_kp_valid = np.concatenate([norm_kp[valid_idx, :], np.ones((len(valid_idx), 1))], axis=-1) + xyz_valid = norm_kp_valid * z.reshape(-1, 1) + xyz2 = np.matmul(xyz_valid, dR.T) + dt.reshape(1, 3) + xy2 = xyz2[:, :2] / xyz2[:, 2:] + kp2, valid = np.ones(kpt.shape) * 1e5, np.zeros(kpt.shape[0]) + kp2[valid_idx] = unnorm_kp(K2,xy2) + valid[valid_idx] = 1 + return kp2, valid.astype(bool) + +def reprojection_2s(kp1, kp2,depth1, depth2, K1, K2, dR, dt, size1,size2): + #size:H*W + depth_size1,depth_size2 = [depth1.shape[0], depth1.shape[1]], [depth2.shape[0], depth2.shape[1]] + scale_1= [float(depth_size1[0]) / size1[0], float(depth_size1[1]) / size1[1], 1] + scale_2= [float(depth_size2[0]) / size2[0], float(depth_size2[1]) / size2[1], 1] + K1_img2depth, K2_img2depth = np.diag(np.asarray(scale_1)), np.diag(np.asarray(scale_2)) + kp1_2_proj, valid1_2 = reprojection(depth1, kp1, dR, dt, K1_img2depth,K1,K2) + kp2_1_proj, valid2_1 = reprojection(depth2, kp2, dR.T, -np.matmul(dR.T, dt), K2_img2depth,K2,K1) + return [kp1_2_proj,kp2_1_proj],[valid1_2,valid2_1] + +def make_corr(kp1,kp2,desc1,desc2,depth1,depth2,K1,K2,dR,dt,size1,size2,corr_th,incorr_th,check_desc=False): + #make reprojection + [kp1_2,kp2_1],[valid1_2,valid2_1]=reprojection_2s(kp1,kp2,depth1,depth2,K1,K2,dR,dt,size1,size2) + num_pts1, num_pts2 = kp1.shape[0], kp2.shape[0] + #reprojection error + dis_mat1=np.sqrt(abs((kp1 ** 2).sum(1,keepdims=True) + (kp2_1 ** 2).sum(1,keepdims=False)[np.newaxis] - 2 * np.matmul(kp1, kp2_1.T))) + dis_mat2 =np.sqrt(abs((kp2 ** 2).sum(1,keepdims=True) + (kp1_2 ** 2).sum(1,keepdims=False)[np.newaxis] - 2 * np.matmul(kp2,kp1_2.T))) + repro_error = np.maximum(dis_mat1,dis_mat2.T) #n1*n2 + + # find corr index + nn_sort1 = np.argmin(repro_error, axis=1) + nn_sort2 = np.argmin(repro_error, axis=0) + mask_mutual = nn_sort2[nn_sort1] == np.arange(kp1.shape[0]) + mask_inlier=np.take_along_axis(repro_error,indices=nn_sort1[:,np.newaxis],axis=-1).squeeze(1)1,mask_samepos2.sum(-1)>1) + duplicated_index=np.nonzero(duplicated_mask)[0] + + unique_corr_index=corr_index[~duplicated_mask] + clean_duplicated_corr=[] + for index in duplicated_index: + cur_desc1, cur_desc2 = desc1[mask_samepos1[index]], desc2[mask_samepos2[index]] + cur_desc_mat = np.matmul(cur_desc1, cur_desc2.T) + cur_max_index =[np.argmax(cur_desc_mat)//cur_desc_mat.shape[1],np.argmax(cur_desc_mat)%cur_desc_mat.shape[1]] + clean_duplicated_corr.append(np.stack([np.arange(num_pts1)[mask_samepos1[index]][cur_max_index[0]], + np.arange(num_pts2)[mask_samepos2[index]][cur_max_index[1]]])) + + clean_corr_index=unique_corr_index + if len(clean_duplicated_corr)!=0: + 
clean_duplicated_corr=np.stack(clean_duplicated_corr,axis=0) + clean_corr_index=np.concatenate([clean_corr_index,clean_duplicated_corr],axis=0) + else: + clean_corr_index=corr_index + # find incorr + mask_incorr1 = np.min(dis_mat2.T[valid1_2], axis=-1) > incorr_th + mask_incorr2 = np.min(dis_mat1.T[valid2_1], axis=-1) > incorr_th + incorr_index1, incorr_index2 = np.arange(num_pts1)[valid1_2][mask_incorr1.squeeze()], \ + np.arange(num_pts2)[valid2_1][mask_incorr2.squeeze()] + + return clean_corr_index,incorr_index1,incorr_index2 + diff --git a/imcui/third_party/SGMNet/utils/evaluation_utils.py b/imcui/third_party/SGMNet/utils/evaluation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..82c4715a192d3c361c849896b035cd91ee56dc42 --- /dev/null +++ b/imcui/third_party/SGMNet/utils/evaluation_utils.py @@ -0,0 +1,58 @@ +import numpy as np +import h5py +import cv2 + +def normalize_intrinsic(x,K): + #print(x,K) + return (x-K[:2,2])/np.diag(K)[:2] + +def normalize_size(x,size,scale=1): + size=size.reshape([1,2]) + norm_fac=size.max() + return (x-size/2+0.5)/(norm_fac*scale) + +def np_skew_symmetric(v): + zero = np.zeros_like(v[:, 0]) + M = np.stack([ + zero, -v[:, 2], v[:, 1], + v[:, 2], zero, -v[:, 0], + -v[:, 1], v[:, 0], zero, + ], axis=1) + return M + +def draw_points(img,points,color=(0,255,0),radius=3): + dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])] + for i in range(points.shape[0]): + cv2.circle(img, dp[i],radius=radius,color=color) + return img + + +def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None): + if resize is not None: + scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]] + img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA) + corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis] + corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])] + corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])] + + assert len(corr1) == len(corr2) + + draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))] + if color is None: + color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier] + if len(color)==1: + display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None, + matchColor=color[0], + singlePointColor=color[0], + flags=4 + ) + else: + height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1] + display=np.zeros([height,width,3],np.uint8) + display[:img1.shape[0],:img1.shape[1]]=img1 + display[:img2.shape[0],img1.shape[1]:]=img2 + for i in range(len(corr1)): + left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1]) + cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2])) + cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA) + return display \ No newline at end of file diff --git a/imcui/third_party/SGMNet/utils/fm_utils.py b/imcui/third_party/SGMNet/utils/fm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f9cbbeefe5d6b59c1ae1fa26cdaa42146ad22a74 --- /dev/null +++ b/imcui/third_party/SGMNet/utils/fm_utils.py @@ -0,0 +1,95 @@ +import numpy as np + + +def line_to_border(line,size): + #line:(a,b,c), ax+by+c=0 + #size:(W,H) + H,W=size[1],size[0] + a,b,c=line[0],line[1],line[2] + 
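Referring back to `make_corr` above: ground-truth matches are the mutual nearest neighbours of the symmetric reprojection-error matrix that also fall below `corr_th`. A minimal NumPy sketch of that mutual-NN selection with toy inputs; the duplicate-keypoint cleanup from the original is omitted and the helper name is hypothetical:

```python
import numpy as np

def mutual_nn(err, th):
    """Return index pairs (i, j) that are mutual nearest neighbours in the
    error matrix err (n1 x n2) and whose error is below th."""
    nn12 = np.argmin(err, axis=1)        # best column for every row
    nn21 = np.argmin(err, axis=0)        # best row for every column
    i = np.arange(err.shape[0])
    mutual = nn21[nn12] == i             # cycle-consistency check
    close = err[i, nn12] < th            # reprojection-error threshold
    keep = mutual & close
    return np.stack([i[keep], nn12[keep]], axis=-1)

# toy usage
err = np.array([[0.1, 2.0, 3.0],
                [2.5, 0.2, 1.5],
                [3.0, 4.0, 5.0]])
print(mutual_nn(err, th=1.0))            # -> [[0 0] [1 1]]
```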
epsa=1e-8 if a>=0 else -1e-8 + epsb=1e-8 if b>=0 else -1e-8 + intersection_list=[] + + y_left=-c/(b+epsb) + y_right=(-c-a*(W-1))/(b+epsb) + x_top=-c/(a+epsa) + x_down=(-c-b*(H-1))/(a+epsa) + + if y_left>=0 and y_left<=H-1: + intersection_list.append([0,y_left]) + if y_right>=0 and y_right<=H-1: + intersection_list.append([W-1,y_right]) + if x_top>=0 and x_top<=W-1: + intersection_list.append([x_top,0]) + if x_down>=0 and x_down<=W-1: + intersection_list.append([x_down,H-1]) + if len(intersection_list)!=2: + return None + intersection_list=np.asarray(intersection_list) + return intersection_list + +def find_point_in_line(end_point): + x_span,y_span=end_point[1,0]-end_point[0,0],end_point[1,1]-end_point[0,1] + mv=np.random.uniform() + point=np.asarray([end_point[0,0]+x_span*mv,end_point[0,1]+y_span*mv]) + return point + +def epi_line(point,F): + homo=np.concatenate([point,np.ones([len(point),1])],axis=-1) + epi=np.matmul(homo,F.T) + return epi + +def dis_point_to_line(line,point): + homo=np.concatenate([point,np.ones([len(point),1])],axis=-1) + dis=line*homo + dis=dis.sum(axis=-1)/(np.linalg.norm(line[:,:2],axis=-1)+1e-8) + return abs(dis) + +def SGD_oneiter(F1,F2,size1,size2): + H1,W1=size1[1],size1[0] + factor1 = 1 / np.linalg.norm(size1) + factor2 = 1 / np.linalg.norm(size2) + p0=np.asarray([(W1-1)*np.random.uniform(),(H1-1)*np.random.uniform()]) + epi1=epi_line(p0[np.newaxis],F1)[0] + border_point1=line_to_border(epi1,size2) + if border_point1 is None: + return -1 + + p1=find_point_in_line(border_point1) + epi2=epi_line(p0[np.newaxis],F2) + d1=dis_point_to_line(epi2,p1[np.newaxis])[0]*factor2 + epi3=epi_line(p1[np.newaxis],F2.T) + d2=dis_point_to_line(epi3,p0[np.newaxis])[0]*factor1 + return (d1+d2)/2 + +def compute_SGD(F1,F2,size1,size2): + np.random.seed(1234) + N=1000 + max_iter=N*10 + count,sgd=0,0 + for i in range(max_iter): + d1=SGD_oneiter(F1,F2,size1,size2) + if d1<0: + continue + d2=SGD_oneiter(F2,F1,size1,size2) + if d2<0: + continue + count+=1 + sgd+=(d1+d2)/2 + if count==N: + break + if count==0: + return 1 + else: + return sgd/count + +def compute_inlier_rate(x1,x2,size1,size2,F_gt,th=0.003): + t1,t2=np.linalg.norm(size1)*th,np.linalg.norm(size2)*th + epi1,epi2=epi_line(x1,F_gt),epi_line(x2,F_gt.T) + dis1,dis2=dis_point_to_line(epi1,x2),dis_point_to_line(epi2,x1) + mask_inlier=np.logical_and(dis1`_ + +:Organization: + Laboratory for Fluorescence Dynamics, University of California, Irvine + +:Version: 2015.07.18 + +Requirements +------------ +* `CPython 2.7 or 3.4 `_ +* `Numpy 1.9 `_ +* `Transformations.c 2015.07.18 `_ + (recommended for speedup of some functions) + +Notes +----- +The API is not stable yet and is expected to change between revisions. + +This Python code is not optimized for speed. Refer to the transformations.c +module for a faster implementation of some functions. + +Documentation in HTML format can be generated with epydoc. + +Matrices (M) can be inverted using numpy.linalg.inv(M), be concatenated using +numpy.dot(M0, M1), or transform homogeneous coordinate arrays (v) using +numpy.dot(M, v) for shape (4, \*) column vectors, respectively +numpy.dot(v, M.T) for shape (\*, 4) row vectors ("array of points"). + +This module follows the "column vectors on the right" and "row major storage" +(C contiguous) conventions. The translation components are in the right column +of the transformation matrix, i.e. M[:3, 3]. +The transpose of the transformation matrices may have to be used to interface +with other graphics systems, e.g. with OpenGL's glMultMatrixd(). 
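A short concrete illustration of the convention just described, with the translation in M[:3, 3], column vectors multiplied on the right, and row-vector arrays multiplied by the transpose. Plain NumPy, independent of the module itself:

```python
import numpy as np

# 4x4 homogeneous transform with translation (1, 2, 3) in the right column
M = np.identity(4)
M[:3, 3] = [1.0, 2.0, 3.0]

v_col = np.array([10.0, 20.0, 30.0, 1.0])    # one point, column-vector style
print(np.dot(M, v_col)[:3])                  # -> [11. 22. 33.]

v_rows = np.array([[10.0, 20.0, 30.0, 1.0],  # "array of points", row-vector style
                   [0.0, 0.0, 0.0, 1.0]])
print(np.dot(v_rows, M.T)[:, :3])            # same points, transposed convention
```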
See also [16]. + +Calculations are carried out with numpy.float64 precision. + +Vector, point, quaternion, and matrix function arguments are expected to be +"array like", i.e. tuple, list, or numpy arrays. + +Return types are numpy arrays unless specified otherwise. + +Angles are in radians unless specified otherwise. + +Quaternions w+ix+jy+kz are represented as [w, x, y, z]. + +A triple of Euler angles can be applied/interpreted in 24 ways, which can +be specified using a 4 character string or encoded 4-tuple: + + *Axes 4-string*: e.g. 'sxyz' or 'ryxy' + + - first character : rotations are applied to 's'tatic or 'r'otating frame + - remaining characters : successive rotation axis 'x', 'y', or 'z' + + *Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1) + + - inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix. + - parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed + by 'z', or 'z' is followed by 'x'. Otherwise odd (1). + - repetition : first and last axis are same (1) or different (0). + - frame : rotations are applied to static (0) or rotating (1) frame. + +Other Python packages and modules for 3D transformations and quaternions: + +* `Transforms3d `_ + includes most code of this module. +* `Blender.mathutils `_ +* `numpy-dtypes `_ + +References +---------- +(1) Matrices and transformations. Ronald Goldman. + In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990. +(2) More matrices and transformations: shear and pseudo-perspective. + Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991. +(3) Decomposing a matrix into simple transformations. Spencer Thomas. + In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991. +(4) Recovering the data from the transformation matrix. Ronald Goldman. + In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991. +(5) Euler angle conversion. Ken Shoemake. + In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994. +(6) Arcball rotation control. Ken Shoemake. + In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994. +(7) Representing attitude: Euler angles, unit quaternions, and rotation + vectors. James Diebel. 2006. +(8) A discussion of the solution for the best rotation to relate two sets + of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828. +(9) Closed-form solution of absolute orientation using unit quaternions. + BKP Horn. J Opt Soc Am A. 1987. 4(4):629-642. +(10) Quaternions. Ken Shoemake. + http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf +(11) From quaternion to matrix and back. JMP van Waveren. 2005. + http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm +(12) Uniform random rotations. Ken Shoemake. + In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992. +(13) Quaternion in molecular modeling. CFF Karney. + J Mol Graph Mod, 25(5):595-604 +(14) New method for extracting the quaternion from a rotation matrix. + Itzhack Y Bar-Itzhack, J Guid Contr Dynam. 2000. 23(6): 1085-1087. +(15) Multiple View Geometry in Computer Vision. Hartley and Zissermann. + Cambridge University Press; 2nd Ed. 2004. Chapter 4, Algorithm 4.7, p 130. +(16) Column Vectors vs. Row Vectors. 
+ http://steve.hollasch.net/cgindex/math/matrix/column-vec.html + +Examples +-------- +>>> alpha, beta, gamma = 0.123, -1.234, 2.345 +>>> origin, xaxis, yaxis, zaxis = [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1] +>>> I = identity_matrix() +>>> Rx = rotation_matrix(alpha, xaxis) +>>> Ry = rotation_matrix(beta, yaxis) +>>> Rz = rotation_matrix(gamma, zaxis) +>>> R = concatenate_matrices(Rx, Ry, Rz) +>>> euler = euler_from_matrix(R, 'rxyz') +>>> numpy.allclose([alpha, beta, gamma], euler) +True +>>> Re = euler_matrix(alpha, beta, gamma, 'rxyz') +>>> is_same_transform(R, Re) +True +>>> al, be, ga = euler_from_matrix(Re, 'rxyz') +>>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz')) +True +>>> qx = quaternion_about_axis(alpha, xaxis) +>>> qy = quaternion_about_axis(beta, yaxis) +>>> qz = quaternion_about_axis(gamma, zaxis) +>>> q = quaternion_multiply(qx, qy) +>>> q = quaternion_multiply(q, qz) +>>> Rq = quaternion_matrix(q) +>>> is_same_transform(R, Rq) +True +>>> S = scale_matrix(1.23, origin) +>>> T = translation_matrix([1, 2, 3]) +>>> Z = shear_matrix(beta, xaxis, origin, zaxis) +>>> R = random_rotation_matrix(numpy.random.rand(3)) +>>> M = concatenate_matrices(T, R, Z, S) +>>> scale, shear, angles, trans, persp = decompose_matrix(M) +>>> numpy.allclose(scale, 1.23) +True +>>> numpy.allclose(trans, [1, 2, 3]) +True +>>> numpy.allclose(shear, [0, math.tan(beta), 0]) +True +>>> is_same_transform(R, euler_matrix(axes='sxyz', *angles)) +True +>>> M1 = compose_matrix(scale, shear, angles, trans, persp) +>>> is_same_transform(M, M1) +True +>>> v0, v1 = random_vector(3), random_vector(3) +>>> M = rotation_matrix(angle_between_vectors(v0, v1), vector_product(v0, v1)) +>>> v2 = numpy.dot(v0, M[:3,:3].T) +>>> numpy.allclose(unit_vector(v1), unit_vector(v2)) +True + +""" + +from __future__ import division, print_function + +import math + +import numpy + +__version__ = '2015.07.18' +__docformat__ = 'restructuredtext en' +__all__ = () + + +def identity_matrix(): + """Return 4x4 identity/unit matrix. + + >>> I = identity_matrix() + >>> numpy.allclose(I, numpy.dot(I, I)) + True + >>> numpy.sum(I), numpy.trace(I) + (4.0, 4.0) + >>> numpy.allclose(I, numpy.identity(4)) + True + + """ + return numpy.identity(4) + + +def translation_matrix(direction): + """Return matrix to translate by direction vector. + + >>> v = numpy.random.random(3) - 0.5 + >>> numpy.allclose(v, translation_matrix(v)[:3, 3]) + True + + """ + M = numpy.identity(4) + M[:3, 3] = direction[:3] + return M + + +def translation_from_matrix(matrix): + """Return translation vector from translation matrix. + + >>> v0 = numpy.random.random(3) - 0.5 + >>> v1 = translation_from_matrix(translation_matrix(v0)) + >>> numpy.allclose(v0, v1) + True + + """ + return numpy.array(matrix, copy=False)[:3, 3].copy() + + +def reflection_matrix(point, normal): + """Return matrix to mirror at plane defined by point and normal vector. + + >>> v0 = numpy.random.random(4) - 0.5 + >>> v0[3] = 1. 
+ >>> v1 = numpy.random.random(3) - 0.5 + >>> R = reflection_matrix(v0, v1) + >>> numpy.allclose(2, numpy.trace(R)) + True + >>> numpy.allclose(v0, numpy.dot(R, v0)) + True + >>> v2 = v0.copy() + >>> v2[:3] += v1 + >>> v3 = v0.copy() + >>> v2[:3] -= v1 + >>> numpy.allclose(v2, numpy.dot(R, v3)) + True + + """ + normal = unit_vector(normal[:3]) + M = numpy.identity(4) + M[:3, :3] -= 2.0 * numpy.outer(normal, normal) + M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal + return M + + +def reflection_from_matrix(matrix): + """Return mirror plane point and normal vector from reflection matrix. + + >>> v0 = numpy.random.random(3) - 0.5 + >>> v1 = numpy.random.random(3) - 0.5 + >>> M0 = reflection_matrix(v0, v1) + >>> point, normal = reflection_from_matrix(M0) + >>> M1 = reflection_matrix(point, normal) + >>> is_same_transform(M0, M1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + # normal: unit eigenvector corresponding to eigenvalue -1 + w, V = numpy.linalg.eig(M[:3, :3]) + i = numpy.where(abs(numpy.real(w) + 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no unit eigenvector corresponding to eigenvalue -1") + normal = numpy.real(V[:, i[0]]).squeeze() + # point: any unit eigenvector corresponding to eigenvalue 1 + w, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no unit eigenvector corresponding to eigenvalue 1") + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + return point, normal + + +def rotation_matrix(angle, direction, point=None): + """Return matrix to rotate about axis defined by point and direction. + + >>> R = rotation_matrix(math.pi/2, [0, 0, 1], [1, 0, 0]) + >>> numpy.allclose(numpy.dot(R, [0, 0, 0, 1]), [1, -1, 0, 1]) + True + >>> angle = (random.random() - 0.5) * (2*math.pi) + >>> direc = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> R0 = rotation_matrix(angle, direc, point) + >>> R1 = rotation_matrix(angle-2*math.pi, direc, point) + >>> is_same_transform(R0, R1) + True + >>> R0 = rotation_matrix(angle, direc, point) + >>> R1 = rotation_matrix(-angle, -direc, point) + >>> is_same_transform(R0, R1) + True + >>> I = numpy.identity(4, numpy.float64) + >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc)) + True + >>> numpy.allclose(2, numpy.trace(rotation_matrix(math.pi/2, + ... direc, point))) + True + + """ + sina = math.sin(angle) + cosa = math.cos(angle) + direction = unit_vector(direction[:3]) + # rotation matrix around unit vector + R = numpy.diag([cosa, cosa, cosa]) + R += numpy.outer(direction, direction) * (1.0 - cosa) + direction *= sina + R += numpy.array([[ 0.0, -direction[2], direction[1]], + [ direction[2], 0.0, -direction[0]], + [-direction[1], direction[0], 0.0]]) + M = numpy.identity(4) + M[:3, :3] = R + if point is not None: + # rotation not around origin + point = numpy.array(point[:3], dtype=numpy.float64, copy=False) + M[:3, 3] = point - numpy.dot(R, point) + return M + + +def rotation_from_matrix(matrix): + """Return rotation angle and axis from rotation matrix. 
+ + >>> angle = (random.random() - 0.5) * (2*math.pi) + >>> direc = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> R0 = rotation_matrix(angle, direc, point) + >>> angle, direc, point = rotation_from_matrix(R0) + >>> R1 = rotation_matrix(angle, direc, point) + >>> is_same_transform(R0, R1) + True + + """ + R = numpy.array(matrix, dtype=numpy.float64, copy=False) + R33 = R[:3, :3] + # direction: unit eigenvector of R33 corresponding to eigenvalue of 1 + w, W = numpy.linalg.eig(R33.T) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no unit eigenvector corresponding to eigenvalue 1") + direction = numpy.real(W[:, i[-1]]).squeeze() + # point: unit eigenvector of R33 corresponding to eigenvalue of 1 + w, Q = numpy.linalg.eig(R) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no unit eigenvector corresponding to eigenvalue 1") + point = numpy.real(Q[:, i[-1]]).squeeze() + point /= point[3] + # rotation angle depending on direction + cosa = (numpy.trace(R33) - 1.0) / 2.0 + if abs(direction[2]) > 1e-8: + sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2] + elif abs(direction[1]) > 1e-8: + sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1] + else: + sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0] + angle = math.atan2(sina, cosa) + return angle, direction, point + + +def scale_matrix(factor, origin=None, direction=None): + """Return matrix to scale by factor around origin in direction. + + Use factor -1 for point symmetry. + + >>> v = (numpy.random.rand(4, 5) - 0.5) * 20 + >>> v[3] = 1 + >>> S = scale_matrix(-1.234) + >>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3]) + True + >>> factor = random.random() * 10 - 5 + >>> origin = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> S = scale_matrix(factor, origin) + >>> S = scale_matrix(factor, origin, direct) + + """ + if direction is None: + # uniform scaling + M = numpy.diag([factor, factor, factor, 1.0]) + if origin is not None: + M[:3, 3] = origin[:3] + M[:3, 3] *= 1.0 - factor + else: + # nonuniform scaling + direction = unit_vector(direction[:3]) + factor = 1.0 - factor + M = numpy.identity(4) + M[:3, :3] -= factor * numpy.outer(direction, direction) + if origin is not None: + M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction + return M + + +def scale_from_matrix(matrix): + """Return scaling factor, origin and direction from scaling matrix. 
+ + >>> factor = random.random() * 10 - 5 + >>> origin = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> S0 = scale_matrix(factor, origin) + >>> factor, origin, direction = scale_from_matrix(S0) + >>> S1 = scale_matrix(factor, origin, direction) + >>> is_same_transform(S0, S1) + True + >>> S0 = scale_matrix(factor, origin, direct) + >>> factor, origin, direction = scale_from_matrix(S0) + >>> S1 = scale_matrix(factor, origin, direction) + >>> is_same_transform(S0, S1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + M33 = M[:3, :3] + factor = numpy.trace(M33) - 2.0 + try: + # direction: unit eigenvector corresponding to eigenvalue factor + w, V = numpy.linalg.eig(M33) + i = numpy.where(abs(numpy.real(w) - factor) < 1e-8)[0][0] + direction = numpy.real(V[:, i]).squeeze() + direction /= vector_norm(direction) + except IndexError: + # uniform scaling + factor = (factor + 2.0) / 3.0 + direction = None + # origin: any eigenvector corresponding to eigenvalue 1 + w, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no eigenvector corresponding to eigenvalue 1") + origin = numpy.real(V[:, i[-1]]).squeeze() + origin /= origin[3] + return factor, origin, direction + + +def projection_matrix(point, normal, direction=None, + perspective=None, pseudo=False): + """Return matrix to project onto plane defined by point and normal. + + Using either perspective point, projection direction, or none of both. + + If pseudo is True, perspective projections will preserve relative depth + such that Perspective = dot(Orthogonal, PseudoPerspective). + + >>> P = projection_matrix([0, 0, 0], [1, 0, 0]) + >>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:]) + True + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> persp = numpy.random.random(3) - 0.5 + >>> P0 = projection_matrix(point, normal) + >>> P1 = projection_matrix(point, normal, direction=direct) + >>> P2 = projection_matrix(point, normal, perspective=persp) + >>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True) + >>> is_same_transform(P2, numpy.dot(P0, P3)) + True + >>> P = projection_matrix([3, 0, 0], [1, 1, 0], [1, 0, 0]) + >>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20 + >>> v0[3] = 1 + >>> v1 = numpy.dot(P, v0) + >>> numpy.allclose(v1[1], v0[1]) + True + >>> numpy.allclose(v1[0], 3-v1[1]) + True + + """ + M = numpy.identity(4) + point = numpy.array(point[:3], dtype=numpy.float64, copy=False) + normal = unit_vector(normal[:3]) + if perspective is not None: + # perspective projection + perspective = numpy.array(perspective[:3], dtype=numpy.float64, + copy=False) + M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal) + M[:3, :3] -= numpy.outer(perspective, normal) + if pseudo: + # preserve relative depth + M[:3, :3] -= numpy.outer(normal, normal) + M[:3, 3] = numpy.dot(point, normal) * (perspective+normal) + else: + M[:3, 3] = numpy.dot(point, normal) * perspective + M[3, :3] = -normal + M[3, 3] = numpy.dot(perspective, normal) + elif direction is not None: + # parallel projection + direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False) + scale = numpy.dot(direction, normal) + M[:3, :3] -= numpy.outer(direction, normal) / scale + M[:3, 3] = direction * (numpy.dot(point, normal) / scale) + else: + # orthogonal projection + M[:3, :3] -= numpy.outer(normal, normal) + M[:3, 3] = numpy.dot(point, normal) * 
normal + return M + + +def projection_from_matrix(matrix, pseudo=False): + """Return projection plane and perspective point from projection matrix. + + Return values are same as arguments for projection_matrix function: + point, normal, direction, perspective, and pseudo. + + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.random.random(3) - 0.5 + >>> direct = numpy.random.random(3) - 0.5 + >>> persp = numpy.random.random(3) - 0.5 + >>> P0 = projection_matrix(point, normal) + >>> result = projection_from_matrix(P0) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + >>> P0 = projection_matrix(point, normal, direct) + >>> result = projection_from_matrix(P0) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False) + >>> result = projection_from_matrix(P0, pseudo=False) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True) + >>> result = projection_from_matrix(P0, pseudo=True) + >>> P1 = projection_matrix(*result) + >>> is_same_transform(P0, P1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + M33 = M[:3, :3] + w, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not pseudo and len(i): + # point: any eigenvector corresponding to eigenvalue 1 + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + # direction: unit eigenvector corresponding to eigenvalue 0 + w, V = numpy.linalg.eig(M33) + i = numpy.where(abs(numpy.real(w)) < 1e-8)[0] + if not len(i): + raise ValueError("no eigenvector corresponding to eigenvalue 0") + direction = numpy.real(V[:, i[0]]).squeeze() + direction /= vector_norm(direction) + # normal: unit eigenvector of M33.T corresponding to eigenvalue 0 + w, V = numpy.linalg.eig(M33.T) + i = numpy.where(abs(numpy.real(w)) < 1e-8)[0] + if len(i): + # parallel projection + normal = numpy.real(V[:, i[0]]).squeeze() + normal /= vector_norm(normal) + return point, normal, direction, None, False + else: + # orthogonal projection, where normal equals direction vector + return point, direction, None, None, False + else: + # perspective projection + i = numpy.where(abs(numpy.real(w)) > 1e-8)[0] + if not len(i): + raise ValueError( + "no eigenvector not corresponding to eigenvalue 0") + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + normal = - M[3, :3] + perspective = M[:3, 3] / numpy.dot(point[:3], normal) + if pseudo: + perspective -= normal + return point, normal, None, perspective, pseudo + + +def clip_matrix(left, right, bottom, top, near, far, perspective=False): + """Return matrix to obtain normalized device coordinates from frustum. + + The frustum bounds are axis-aligned along x (left, right), + y (bottom, top) and z (near, far). + + Normalized device coordinates are in range [-1, 1] if coordinates are + inside the frustum. + + If perspective is True the frustum is a truncated pyramid with the + perspective point at origin and direction along z axis, otherwise an + orthographic canonical view volume (a box). + + Homogeneous coordinates transformed by the perspective clip matrix + need to be dehomogenized (divided by w coordinate). 
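The doctest below performs the division by v[3] by hand; as a tiny generic helper for the same dehomogenization step (a sketch, not part of the module):

```python
import numpy as np

def dehomogenize(v):
    """Return the 3D point obtained by dividing homogeneous coordinates by w."""
    v = np.asarray(v, dtype=np.float64)
    return v[:3] / v[3]

# e.g. a point that has already been multiplied by a perspective clip matrix
print(dehomogenize([2.0, -2.0, 2.0, 2.0]))   # -> [ 1. -1.  1.]
```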
+ + >>> frustum = numpy.random.rand(6) + >>> frustum[1] += frustum[0] + >>> frustum[3] += frustum[2] + >>> frustum[5] += frustum[4] + >>> M = clip_matrix(perspective=False, *frustum) + >>> numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1]) + array([-1., -1., -1., 1.]) + >>> numpy.dot(M, [frustum[1], frustum[3], frustum[5], 1]) + array([ 1., 1., 1., 1.]) + >>> M = clip_matrix(perspective=True, *frustum) + >>> v = numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1]) + >>> v / v[3] + array([-1., -1., -1., 1.]) + >>> v = numpy.dot(M, [frustum[1], frustum[3], frustum[4], 1]) + >>> v / v[3] + array([ 1., 1., -1., 1.]) + + """ + if left >= right or bottom >= top or near >= far: + raise ValueError("invalid frustum") + if perspective: + if near <= _EPS: + raise ValueError("invalid frustum: near <= 0") + t = 2.0 * near + M = [[t/(left-right), 0.0, (right+left)/(right-left), 0.0], + [0.0, t/(bottom-top), (top+bottom)/(top-bottom), 0.0], + [0.0, 0.0, (far+near)/(near-far), t*far/(far-near)], + [0.0, 0.0, -1.0, 0.0]] + else: + M = [[2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)], + [0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)], + [0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)], + [0.0, 0.0, 0.0, 1.0]] + return numpy.array(M) + + +def shear_matrix(angle, direction, point, normal): + """Return matrix to shear by angle along direction vector on shear plane. + + The shear plane is defined by a point and normal vector. The direction + vector must be orthogonal to the plane's normal vector. + + A point P is transformed by the shear matrix into P" such that + the vector P-P" is parallel to the direction vector and its extent is + given by the angle of P-P'-P", where P' is the orthogonal projection + of P onto the shear plane. + + >>> angle = (random.random() - 0.5) * 4*math.pi + >>> direct = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.cross(direct, numpy.random.random(3)) + >>> S = shear_matrix(angle, direct, point, normal) + >>> numpy.allclose(1, numpy.linalg.det(S)) + True + + """ + normal = unit_vector(normal[:3]) + direction = unit_vector(direction[:3]) + if abs(numpy.dot(normal, direction)) > 1e-6: + raise ValueError("direction and normal vectors are not orthogonal") + angle = math.tan(angle) + M = numpy.identity(4) + M[:3, :3] += angle * numpy.outer(direction, normal) + M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction + return M + + +def shear_from_matrix(matrix): + """Return shear angle, direction and plane from shear matrix. 
+ + >>> angle = (random.random() - 0.5) * 4*math.pi + >>> direct = numpy.random.random(3) - 0.5 + >>> point = numpy.random.random(3) - 0.5 + >>> normal = numpy.cross(direct, numpy.random.random(3)) + >>> S0 = shear_matrix(angle, direct, point, normal) + >>> angle, direct, point, normal = shear_from_matrix(S0) + >>> S1 = shear_matrix(angle, direct, point, normal) + >>> is_same_transform(S0, S1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False) + M33 = M[:3, :3] + # normal: cross independent eigenvectors corresponding to the eigenvalue 1 + w, V = numpy.linalg.eig(M33) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-4)[0] + if len(i) < 2: + raise ValueError("no two linear independent eigenvectors found %s" % w) + V = numpy.real(V[:, i]).squeeze().T + lenorm = -1.0 + for i0, i1 in ((0, 1), (0, 2), (1, 2)): + n = numpy.cross(V[i0], V[i1]) + w = vector_norm(n) + if w > lenorm: + lenorm = w + normal = n + normal /= lenorm + # direction and angle + direction = numpy.dot(M33 - numpy.identity(3), normal) + angle = vector_norm(direction) + direction /= angle + angle = math.atan(angle) + # point: eigenvector corresponding to eigenvalue 1 + w, V = numpy.linalg.eig(M) + i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] + if not len(i): + raise ValueError("no eigenvector corresponding to eigenvalue 1") + point = numpy.real(V[:, i[-1]]).squeeze() + point /= point[3] + return angle, direction, point, normal + + +def decompose_matrix(matrix): + """Return sequence of transformations from transformation matrix. + + matrix : array_like + Non-degenerative homogeneous transformation matrix + + Return tuple of: + scale : vector of 3 scaling factors + shear : list of shear factors for x-y, x-z, y-z axes + angles : list of Euler angles about static x, y, z axes + translate : translation vector along x, y, z axes + perspective : perspective partition of matrix + + Raise ValueError if matrix is of wrong type or degenerative. 
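A small illustration of the degenerate case mentioned just above: a matrix whose upper 3x3 block is singular cannot be decomposed and is rejected with ValueError. The import path `transformations` is an assumption about how the module is exposed:

```python
import numpy
# import path assumed; decompose_matrix is defined in this module
from transformations import decompose_matrix

singular = numpy.diag([1.0, 1.0, 0.0, 1.0])   # zero scale on z -> singular upper 3x3
try:
    decompose_matrix(singular)
except ValueError as exc:
    print('degenerate matrix rejected:', exc)  # "matrix is singular"
```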
+ + >>> T0 = translation_matrix([1, 2, 3]) + >>> scale, shear, angles, trans, persp = decompose_matrix(T0) + >>> T1 = translation_matrix(trans) + >>> numpy.allclose(T0, T1) + True + >>> S = scale_matrix(0.123) + >>> scale, shear, angles, trans, persp = decompose_matrix(S) + >>> scale[0] + 0.123 + >>> R0 = euler_matrix(1, 2, 3) + >>> scale, shear, angles, trans, persp = decompose_matrix(R0) + >>> R1 = euler_matrix(*angles) + >>> numpy.allclose(R0, R1) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=True).T + if abs(M[3, 3]) < _EPS: + raise ValueError("M[3, 3] is zero") + M /= M[3, 3] + P = M.copy() + P[:, 3] = 0.0, 0.0, 0.0, 1.0 + if not numpy.linalg.det(P): + raise ValueError("matrix is singular") + + scale = numpy.zeros((3, )) + shear = [0.0, 0.0, 0.0] + angles = [0.0, 0.0, 0.0] + + if any(abs(M[:3, 3]) > _EPS): + perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T)) + M[:, 3] = 0.0, 0.0, 0.0, 1.0 + else: + perspective = numpy.array([0.0, 0.0, 0.0, 1.0]) + + translate = M[3, :3].copy() + M[3, :3] = 0.0 + + row = M[:3, :3].copy() + scale[0] = vector_norm(row[0]) + row[0] /= scale[0] + shear[0] = numpy.dot(row[0], row[1]) + row[1] -= row[0] * shear[0] + scale[1] = vector_norm(row[1]) + row[1] /= scale[1] + shear[0] /= scale[1] + shear[1] = numpy.dot(row[0], row[2]) + row[2] -= row[0] * shear[1] + shear[2] = numpy.dot(row[1], row[2]) + row[2] -= row[1] * shear[2] + scale[2] = vector_norm(row[2]) + row[2] /= scale[2] + shear[1:] /= scale[2] + + if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0: + numpy.negative(scale, scale) + numpy.negative(row, row) + + angles[1] = math.asin(-row[0, 2]) + if math.cos(angles[1]): + angles[0] = math.atan2(row[1, 2], row[2, 2]) + angles[2] = math.atan2(row[0, 1], row[0, 0]) + else: + #angles[0] = math.atan2(row[1, 0], row[1, 1]) + angles[0] = math.atan2(-row[2, 1], row[1, 1]) + angles[2] = 0.0 + + return scale, shear, angles, translate, perspective + + +def compose_matrix(scale=None, shear=None, angles=None, translate=None, + perspective=None): + """Return transformation matrix from sequence of transformations. + + This is the inverse of the decompose_matrix function. 
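As a side note on the shear convention shared by decompose_matrix and compose_matrix (factors for the x-y, x-z and y-z axes, i.e. the off-diagonal entries of an upper-triangular unit matrix), a small round-trip sketch; the import path `transformations` and the factor values are assumptions:

```python
import numpy
# import path assumed; decompose_matrix is defined in this module
from transformations import decompose_matrix

Z = numpy.identity(4)
Z[0, 1], Z[0, 2], Z[1, 2] = 0.2, 0.3, 0.4     # x-y, x-z, y-z shear factors

scale, shear, angles, translate, perspective = decompose_matrix(Z)
assert numpy.allclose(shear, [0.2, 0.3, 0.4])  # factors recovered unchanged
assert numpy.allclose(scale, 1.0)              # pure shear: unit scale
assert numpy.allclose(angles, 0.0)             # and no rotation
```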
+ + Sequence of transformations: + scale : vector of 3 scaling factors + shear : list of shear factors for x-y, x-z, y-z axes + angles : list of Euler angles about static x, y, z axes + translate : translation vector along x, y, z axes + perspective : perspective partition of matrix + + >>> scale = numpy.random.random(3) - 0.5 + >>> shear = numpy.random.random(3) - 0.5 + >>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi) + >>> trans = numpy.random.random(3) - 0.5 + >>> persp = numpy.random.random(4) - 0.5 + >>> M0 = compose_matrix(scale, shear, angles, trans, persp) + >>> result = decompose_matrix(M0) + >>> M1 = compose_matrix(*result) + >>> is_same_transform(M0, M1) + True + + """ + M = numpy.identity(4) + if perspective is not None: + P = numpy.identity(4) + P[3, :] = perspective[:4] + M = numpy.dot(M, P) + if translate is not None: + T = numpy.identity(4) + T[:3, 3] = translate[:3] + M = numpy.dot(M, T) + if angles is not None: + R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz') + M = numpy.dot(M, R) + if shear is not None: + Z = numpy.identity(4) + Z[1, 2] = shear[2] + Z[0, 2] = shear[1] + Z[0, 1] = shear[0] + M = numpy.dot(M, Z) + if scale is not None: + S = numpy.identity(4) + S[0, 0] = scale[0] + S[1, 1] = scale[1] + S[2, 2] = scale[2] + M = numpy.dot(M, S) + M /= M[3, 3] + return M + + +def orthogonalization_matrix(lengths, angles): + """Return orthogonalization matrix for crystallographic cell coordinates. + + Angles are expected in degrees. + + The de-orthogonalization matrix is the inverse. + + >>> O = orthogonalization_matrix([10, 10, 10], [90, 90, 90]) + >>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10) + True + >>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7]) + >>> numpy.allclose(numpy.sum(O), 43.063229) + True + + """ + a, b, c = lengths + angles = numpy.radians(angles) + sina, sinb, _ = numpy.sin(angles) + cosa, cosb, cosg = numpy.cos(angles) + co = (cosa * cosb - cosg) / (sina * sinb) + return numpy.array([ + [ a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0], + [-a*sinb*co, b*sina, 0.0, 0.0], + [ a*cosb, b*cosa, c, 0.0], + [ 0.0, 0.0, 0.0, 1.0]]) + + +def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True): + """Return affine transform matrix to register two point sets. + + v0 and v1 are shape (ndims, \*) arrays of at least ndims non-homogeneous + coordinates, where ndims is the dimensionality of the coordinate space. + + If shear is False, a similarity transformation matrix is returned. + If also scale is False, a rigid/Euclidean transformation matrix + is returned. + + By default the algorithm by Hartley and Zissermann [15] is used. + If usesvd is True, similarity and Euclidean transformation matrices + are calculated by minimizing the weighted sum of squared deviations + (RMSD) according to the algorithm by Kabsch [8]. + Otherwise, and if ndims is 3, the quaternion based algorithm by Horn [9] + is used, which is slower when using this Python implementation. + + The returned matrix performs rotation, translation and uniform scaling + (if specified). + + >>> v0 = [[0, 1031, 1031, 0], [0, 0, 1600, 1600]] + >>> v1 = [[675, 826, 826, 677], [55, 52, 281, 277]] + >>> affine_matrix_from_points(v0, v1) + array([[ 0.14549, 0.00062, 675.50008], + [ 0.00048, 0.14094, 53.24971], + [ 0. , 0. , 1. 
]]) + >>> T = translation_matrix(numpy.random.random(3)-0.5) + >>> R = random_rotation_matrix(numpy.random.random(3)) + >>> S = scale_matrix(random.random()) + >>> M = concatenate_matrices(T, R, S) + >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20 + >>> v0[3] = 1 + >>> v1 = numpy.dot(M, v0) + >>> v0[:3] += numpy.random.normal(0, 1e-8, 300).reshape(3, -1) + >>> M = affine_matrix_from_points(v0[:3], v1[:3]) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + + More examples in superimposition_matrix() + + """ + v0 = numpy.array(v0, dtype=numpy.float64, copy=True) + v1 = numpy.array(v1, dtype=numpy.float64, copy=True) + + ndims = v0.shape[0] + if ndims < 2 or v0.shape[1] < ndims or v0.shape != v1.shape: + raise ValueError("input arrays are of wrong shape or type") + + # move centroids to origin + t0 = -numpy.mean(v0, axis=1) + M0 = numpy.identity(ndims+1) + M0[:ndims, ndims] = t0 + v0 += t0.reshape(ndims, 1) + t1 = -numpy.mean(v1, axis=1) + M1 = numpy.identity(ndims+1) + M1[:ndims, ndims] = t1 + v1 += t1.reshape(ndims, 1) + + if shear: + # Affine transformation + A = numpy.concatenate((v0, v1), axis=0) + u, s, vh = numpy.linalg.svd(A.T) + vh = vh[:ndims].T + B = vh[:ndims] + C = vh[ndims:2*ndims] + t = numpy.dot(C, numpy.linalg.pinv(B)) + t = numpy.concatenate((t, numpy.zeros((ndims, 1))), axis=1) + M = numpy.vstack((t, ((0.0,)*ndims) + (1.0,))) + elif usesvd or ndims != 3: + # Rigid transformation via SVD of covariance matrix + u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T)) + # rotation matrix from SVD orthonormal bases + R = numpy.dot(u, vh) + if numpy.linalg.det(R) < 0.0: + # R does not constitute right handed system + R -= numpy.outer(u[:, ndims-1], vh[ndims-1, :]*2.0) + s[-1] *= -1.0 + # homogeneous transformation matrix + M = numpy.identity(ndims+1) + M[:ndims, :ndims] = R + else: + # Rigid transformation matrix via quaternion + # compute symmetric matrix N + xx, yy, zz = numpy.sum(v0 * v1, axis=1) + xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1) + xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1) + N = [[xx+yy+zz, 0.0, 0.0, 0.0], + [yz-zy, xx-yy-zz, 0.0, 0.0], + [zx-xz, xy+yx, yy-xx-zz, 0.0], + [xy-yx, zx+xz, yz+zy, zz-xx-yy]] + # quaternion: eigenvector corresponding to most positive eigenvalue + w, V = numpy.linalg.eigh(N) + q = V[:, numpy.argmax(w)] + q /= vector_norm(q) # unit quaternion + # homogeneous transformation matrix + M = quaternion_matrix(q) + + if scale and not shear: + # Affine transformation; scale is ratio of RMS deviations from centroid + v0 *= v0 + v1 *= v1 + M[:ndims, :ndims] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0)) + + # move centroids back + M = numpy.dot(numpy.linalg.inv(M1), numpy.dot(M, M0)) + M /= M[ndims, ndims] + return M + + +def superimposition_matrix(v0, v1, scale=False, usesvd=True): + """Return matrix to transform given 3D point set into second point set. + + v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 points. + + The parameters scale and usesvd are explained in the more general + affine_matrix_from_points function. + + The returned matrix is a similarity or Euclidean transformation matrix. + This function has a fast C implementation in transformations.c. 
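For the usesvd branch of affine_matrix_from_points above (rotation from the SVD of the covariance matrix, with a determinant check to avoid returning a reflection), a standalone NumPy sketch of the same Kabsch-style estimate on centred 3D point sets; the toy data and seed are arbitrary illustration values:

```python
import numpy as np

def kabsch_rotation(p, q):
    """Best-fit rotation R (3x3) with q ~= R @ p for centred (3, n) point sets."""
    u, s, vh = np.linalg.svd(np.dot(q, p.T))   # SVD of the 3x3 covariance
    r = np.dot(u, vh)
    if np.linalg.det(r) < 0.0:                 # right-handedness fix-up
        u[:, -1] *= -1.0
        r = np.dot(u, vh)
    return r

# toy check: recover a known rotation about z
np.random.seed(0)
theta = 0.5
rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0,            0.0,           1.0]])
p = np.random.rand(3, 20) - 0.5
p -= p.mean(axis=1, keepdims=True)             # centre the source set
q = rz @ p                                     # target set (already centred)
assert np.allclose(kabsch_rotation(p, q), rz)
```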
+ + >>> v0 = numpy.random.rand(3, 10) + >>> M = superimposition_matrix(v0, v0) + >>> numpy.allclose(M, numpy.identity(4)) + True + >>> R = random_rotation_matrix(numpy.random.random(3)) + >>> v0 = [[1,0,0], [0,1,0], [0,0,1], [1,1,1]] + >>> v1 = numpy.dot(R, v0) + >>> M = superimposition_matrix(v0, v1) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20 + >>> v0[3] = 1 + >>> v1 = numpy.dot(R, v0) + >>> M = superimposition_matrix(v0, v1) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> S = scale_matrix(random.random()) + >>> T = translation_matrix(numpy.random.random(3)-0.5) + >>> M = concatenate_matrices(T, R, S) + >>> v1 = numpy.dot(M, v0) + >>> v0[:3] += numpy.random.normal(0, 1e-9, 300).reshape(3, -1) + >>> M = superimposition_matrix(v0, v1, scale=True) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False) + >>> numpy.allclose(v1, numpy.dot(M, v0)) + True + >>> v = numpy.empty((4, 100, 3)) + >>> v[:, :, 0] = v0 + >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False) + >>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0])) + True + + """ + v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3] + v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3] + return affine_matrix_from_points(v0, v1, shear=False, + scale=scale, usesvd=usesvd) + + +def euler_matrix(ai, aj, ak, axes='sxyz'): + """Return homogeneous rotation matrix from Euler angles and axis sequence. + + ai, aj, ak : Euler's roll, pitch and yaw angles + axes : One of 24 axis sequences as string or encoded tuple + + >>> R = euler_matrix(1, 2, 3, 'syxz') + >>> numpy.allclose(numpy.sum(R[0]), -1.34786452) + True + >>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1)) + >>> numpy.allclose(numpy.sum(R[0]), -0.383436184) + True + >>> ai, aj, ak = (4*math.pi) * (numpy.random.random(3) - 0.5) + >>> for axes in _AXES2TUPLE.keys(): + ... R = euler_matrix(ai, aj, ak, axes) + >>> for axes in _TUPLE2AXES.keys(): + ... R = euler_matrix(ai, aj, ak, axes) + + """ + try: + firstaxis, parity, repetition, frame = _AXES2TUPLE[axes] + except (AttributeError, KeyError): + _TUPLE2AXES[axes] # validation + firstaxis, parity, repetition, frame = axes + + i = firstaxis + j = _NEXT_AXIS[i+parity] + k = _NEXT_AXIS[i-parity+1] + + if frame: + ai, ak = ak, ai + if parity: + ai, aj, ak = -ai, -aj, -ak + + si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak) + ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak) + cc, cs = ci*ck, ci*sk + sc, ss = si*ck, si*sk + + M = numpy.identity(4) + if repetition: + M[i, i] = cj + M[i, j] = sj*si + M[i, k] = sj*ci + M[j, i] = sj*sk + M[j, j] = -cj*ss+cc + M[j, k] = -cj*cs-sc + M[k, i] = -sj*ck + M[k, j] = cj*sc+cs + M[k, k] = cj*cc-ss + else: + M[i, i] = cj*ck + M[i, j] = sj*sc-cs + M[i, k] = sj*cc+ss + M[j, i] = cj*sk + M[j, j] = sj*ss+cc + M[j, k] = sj*cs-sc + M[k, i] = -sj + M[k, j] = cj*si + M[k, k] = cj*ci + return M + + +def euler_from_matrix(matrix, axes='sxyz'): + """Return Euler angles from rotation matrix for specified axis sequence. + + axes : One of 24 axis sequences as string or encoded tuple + + Note that many Euler angle triplets can describe one matrix. + + >>> R0 = euler_matrix(1, 2, 3, 'syxz') + >>> al, be, ga = euler_from_matrix(R0, 'syxz') + >>> R1 = euler_matrix(al, be, ga, 'syxz') + >>> numpy.allclose(R0, R1) + True + >>> angles = (4*math.pi) * (numpy.random.random(3) - 0.5) + >>> for axes in _AXES2TUPLE.keys(): + ... R0 = euler_matrix(axes=axes, *angles) + ... 
R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes)) + ... if not numpy.allclose(R0, R1): print(axes, "failed") + + """ + try: + firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] + except (AttributeError, KeyError): + _TUPLE2AXES[axes] # validation + firstaxis, parity, repetition, frame = axes + + i = firstaxis + j = _NEXT_AXIS[i+parity] + k = _NEXT_AXIS[i-parity+1] + + M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3] + if repetition: + sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k]) + if sy > _EPS: + ax = math.atan2( M[i, j], M[i, k]) + ay = math.atan2( sy, M[i, i]) + az = math.atan2( M[j, i], -M[k, i]) + else: + ax = math.atan2(-M[j, k], M[j, j]) + ay = math.atan2( sy, M[i, i]) + az = 0.0 + else: + cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i]) + if cy > _EPS: + ax = math.atan2( M[k, j], M[k, k]) + ay = math.atan2(-M[k, i], cy) + az = math.atan2( M[j, i], M[i, i]) + else: + ax = math.atan2(-M[j, k], M[j, j]) + ay = math.atan2(-M[k, i], cy) + az = 0.0 + + if parity: + ax, ay, az = -ax, -ay, -az + if frame: + ax, az = az, ax + return ax, ay, az + + +def euler_from_quaternion(quaternion, axes='sxyz'): + """Return Euler angles from quaternion for specified axis sequence. + + >>> angles = euler_from_quaternion([0.99810947, 0.06146124, 0, 0]) + >>> numpy.allclose(angles, [0.123, 0, 0]) + True + + """ + return euler_from_matrix(quaternion_matrix(quaternion), axes) + + +def quaternion_from_euler(ai, aj, ak, axes='sxyz'): + """Return quaternion from Euler angles and axis sequence. + + ai, aj, ak : Euler's roll, pitch and yaw angles + axes : One of 24 axis sequences as string or encoded tuple + + >>> q = quaternion_from_euler(1, 2, 3, 'ryxz') + >>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435]) + True + + """ + try: + firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] + except (AttributeError, KeyError): + _TUPLE2AXES[axes] # validation + firstaxis, parity, repetition, frame = axes + + i = firstaxis + 1 + j = _NEXT_AXIS[i+parity-1] + 1 + k = _NEXT_AXIS[i-parity] + 1 + + if frame: + ai, ak = ak, ai + if parity: + aj = -aj + + ai /= 2.0 + aj /= 2.0 + ak /= 2.0 + ci = math.cos(ai) + si = math.sin(ai) + cj = math.cos(aj) + sj = math.sin(aj) + ck = math.cos(ak) + sk = math.sin(ak) + cc = ci*ck + cs = ci*sk + sc = si*ck + ss = si*sk + + q = numpy.empty((4, )) + if repetition: + q[0] = cj*(cc - ss) + q[i] = cj*(cs + sc) + q[j] = sj*(cc + ss) + q[k] = sj*(cs - sc) + else: + q[0] = cj*cc + sj*ss + q[i] = cj*sc - sj*cs + q[j] = cj*ss + sj*cc + q[k] = cj*cs - sj*sc + if parity: + q[j] *= -1.0 + + return q + + +def quaternion_about_axis(angle, axis): + """Return quaternion for rotation about axis. + + >>> q = quaternion_about_axis(0.123, [1, 0, 0]) + >>> numpy.allclose(q, [0.99810947, 0.06146124, 0, 0]) + True + + """ + q = numpy.array([0.0, axis[0], axis[1], axis[2]]) + qlen = vector_norm(q) + if qlen > _EPS: + q *= math.sin(angle/2.0) / qlen + q[0] = math.cos(angle/2.0) + return q + + +def quaternion_matrix(quaternion): + """Return homogeneous rotation matrix from quaternion. 
+ + >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0]) + >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0])) + True + >>> M = quaternion_matrix([1, 0, 0, 0]) + >>> numpy.allclose(M, numpy.identity(4)) + True + >>> M = quaternion_matrix([0, 1, 0, 0]) + >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1])) + True + + """ + q = numpy.array(quaternion, dtype=numpy.float64, copy=True) + n = numpy.dot(q, q) + if n < _EPS: + return numpy.identity(4) + q *= math.sqrt(2.0 / n) + q = numpy.outer(q, q) + return numpy.array([ + [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0], + [ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0], + [ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0], + [ 0.0, 0.0, 0.0, 1.0]]) + + +def quaternion_from_matrix(matrix, isprecise=False): + """Return quaternion from rotation matrix. + + If isprecise is True, the input matrix is assumed to be a precise rotation + matrix and a faster algorithm is used. + + >>> q = quaternion_from_matrix(numpy.identity(4), True) + >>> numpy.allclose(q, [1, 0, 0, 0]) + True + >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1])) + >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0]) + True + >>> R = rotation_matrix(0.123, (1, 2, 3)) + >>> q = quaternion_from_matrix(R, True) + >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786]) + True + >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0], + ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611]) + True + >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0], + ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603]) + True + >>> R = random_rotation_matrix() + >>> q = quaternion_from_matrix(R) + >>> is_same_transform(R, quaternion_matrix(q)) + True + >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0) + >>> numpy.allclose(quaternion_from_matrix(R, isprecise=False), + ... quaternion_from_matrix(R, isprecise=True)) + True + + """ + M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4] + if isprecise: + q = numpy.empty((4, )) + t = numpy.trace(M) + if t > M[3, 3]: + q[0] = t + q[3] = M[1, 0] - M[0, 1] + q[2] = M[0, 2] - M[2, 0] + q[1] = M[2, 1] - M[1, 2] + else: + i, j, k = 1, 2, 3 + if M[1, 1] > M[0, 0]: + i, j, k = 2, 3, 1 + if M[2, 2] > M[i, i]: + i, j, k = 3, 1, 2 + t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] + q[i] = t + q[j] = M[i, j] + M[j, i] + q[k] = M[k, i] + M[i, k] + q[3] = M[k, j] - M[j, k] + q *= 0.5 / math.sqrt(t * M[3, 3]) + else: + m00 = M[0, 0] + m01 = M[0, 1] + m02 = M[0, 2] + m10 = M[1, 0] + m11 = M[1, 1] + m12 = M[1, 2] + m20 = M[2, 0] + m21 = M[2, 1] + m22 = M[2, 2] + # symmetric matrix K + K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0], + [m01+m10, m11-m00-m22, 0.0, 0.0], + [m02+m20, m12+m21, m22-m00-m11, 0.0], + [m21-m12, m02-m20, m10-m01, m00+m11+m22]]) + K /= 3.0 + # quaternion is eigenvector of K that corresponds to largest eigenvalue + w, V = numpy.linalg.eigh(K) + q = V[[3, 0, 1, 2], numpy.argmax(w)] + if q[0] < 0.0: + numpy.negative(q, q) + return q + + +def quaternion_multiply(quaternion1, quaternion0): + """Return multiplication of two quaternions. 
+ + >>> q = quaternion_multiply([4, 1, -2, 3], [8, -5, 6, 7]) + >>> numpy.allclose(q, [28, -44, -14, 48]) + True + + """ + w0, x0, y0, z0 = quaternion0 + w1, x1, y1, z1 = quaternion1 + return numpy.array([-x1*x0 - y1*y0 - z1*z0 + w1*w0, + x1*w0 + y1*z0 - z1*y0 + w1*x0, + -x1*z0 + y1*w0 + z1*x0 + w1*y0, + x1*y0 - y1*x0 + z1*w0 + w1*z0], dtype=numpy.float64) + + +def quaternion_conjugate(quaternion): + """Return conjugate of quaternion. + + >>> q0 = random_quaternion() + >>> q1 = quaternion_conjugate(q0) + >>> q1[0] == q0[0] and all(q1[1:] == -q0[1:]) + True + + """ + q = numpy.array(quaternion, dtype=numpy.float64, copy=True) + numpy.negative(q[1:], q[1:]) + return q + + +def quaternion_inverse(quaternion): + """Return inverse of quaternion. + + >>> q0 = random_quaternion() + >>> q1 = quaternion_inverse(q0) + >>> numpy.allclose(quaternion_multiply(q0, q1), [1, 0, 0, 0]) + True + + """ + q = numpy.array(quaternion, dtype=numpy.float64, copy=True) + numpy.negative(q[1:], q[1:]) + return q / numpy.dot(q, q) + + +def quaternion_real(quaternion): + """Return real part of quaternion. + + >>> quaternion_real([3, 0, 1, 2]) + 3.0 + + """ + return float(quaternion[0]) + + +def quaternion_imag(quaternion): + """Return imaginary part of quaternion. + + >>> quaternion_imag([3, 0, 1, 2]) + array([ 0., 1., 2.]) + + """ + return numpy.array(quaternion[1:4], dtype=numpy.float64, copy=True) + + +def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True): + """Return spherical linear interpolation between two quaternions. + + >>> q0 = random_quaternion() + >>> q1 = random_quaternion() + >>> q = quaternion_slerp(q0, q1, 0) + >>> numpy.allclose(q, q0) + True + >>> q = quaternion_slerp(q0, q1, 1, 1) + >>> numpy.allclose(q, q1) + True + >>> q = quaternion_slerp(q0, q1, 0.5) + >>> angle = math.acos(numpy.dot(q0, q)) + >>> numpy.allclose(2, math.acos(numpy.dot(q0, q1)) / angle) or \ + numpy.allclose(2, math.acos(-numpy.dot(q0, q1)) / angle) + True + + """ + q0 = unit_vector(quat0[:4]) + q1 = unit_vector(quat1[:4]) + if fraction == 0.0: + return q0 + elif fraction == 1.0: + return q1 + d = numpy.dot(q0, q1) + if abs(abs(d) - 1.0) < _EPS: + return q0 + if shortestpath and d < 0.0: + # invert rotation + d = -d + numpy.negative(q1, q1) + angle = math.acos(d) + spin * math.pi + if abs(angle) < _EPS: + return q0 + isin = 1.0 / math.sin(angle) + q0 *= math.sin((1.0 - fraction) * angle) * isin + q1 *= math.sin(fraction * angle) * isin + q0 += q1 + return q0 + + +def random_quaternion(rand=None): + """Return uniform random unit quaternion. + + rand: array like or None + Three independent random variables that are uniformly distributed + between 0 and 1. + + >>> q = random_quaternion() + >>> numpy.allclose(1, vector_norm(q)) + True + >>> q = random_quaternion(numpy.random.random(3)) + >>> len(q.shape), q.shape[0]==4 + (1, True) + + """ + if rand is None: + rand = numpy.random.rand(3) + else: + assert len(rand) == 3 + r1 = numpy.sqrt(1.0 - rand[0]) + r2 = numpy.sqrt(rand[0]) + pi2 = math.pi * 2.0 + t1 = pi2 * rand[1] + t2 = pi2 * rand[2] + return numpy.array([numpy.cos(t2)*r2, numpy.sin(t1)*r1, + numpy.cos(t1)*r1, numpy.sin(t2)*r2]) + + +def random_rotation_matrix(rand=None): + """Return uniform random rotation matrix. + + rand: array like + Three independent random variables that are uniformly distributed + between 0 and 1 for each returned quaternion. 
+ + >>> R = random_rotation_matrix() + >>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4)) + True + + """ + return quaternion_matrix(random_quaternion(rand)) + + +class Arcball(object): + """Virtual Trackball Control. + + >>> ball = Arcball() + >>> ball = Arcball(initial=numpy.identity(4)) + >>> ball.place([320, 320], 320) + >>> ball.down([500, 250]) + >>> ball.drag([475, 275]) + >>> R = ball.matrix() + >>> numpy.allclose(numpy.sum(R), 3.90583455) + True + >>> ball = Arcball(initial=[1, 0, 0, 0]) + >>> ball.place([320, 320], 320) + >>> ball.setaxes([1, 1, 0], [-1, 1, 0]) + >>> ball.constrain = True + >>> ball.down([400, 200]) + >>> ball.drag([200, 400]) + >>> R = ball.matrix() + >>> numpy.allclose(numpy.sum(R), 0.2055924) + True + >>> ball.next() + + """ + def __init__(self, initial=None): + """Initialize virtual trackball control. + + initial : quaternion or rotation matrix + + """ + self._axis = None + self._axes = None + self._radius = 1.0 + self._center = [0.0, 0.0] + self._vdown = numpy.array([0.0, 0.0, 1.0]) + self._constrain = False + if initial is None: + self._qdown = numpy.array([1.0, 0.0, 0.0, 0.0]) + else: + initial = numpy.array(initial, dtype=numpy.float64) + if initial.shape == (4, 4): + self._qdown = quaternion_from_matrix(initial) + elif initial.shape == (4, ): + initial /= vector_norm(initial) + self._qdown = initial + else: + raise ValueError("initial not a quaternion or matrix") + self._qnow = self._qpre = self._qdown + + def place(self, center, radius): + """Place Arcball, e.g. when window size changes. + + center : sequence[2] + Window coordinates of trackball center. + radius : float + Radius of trackball in window coordinates. + + """ + self._radius = float(radius) + self._center[0] = center[0] + self._center[1] = center[1] + + def setaxes(self, *axes): + """Set axes to constrain rotations.""" + if axes is None: + self._axes = None + else: + self._axes = [unit_vector(axis) for axis in axes] + + @property + def constrain(self): + """Return state of constrain to axis mode.""" + return self._constrain + + @constrain.setter + def constrain(self, value): + """Set state of constrain to axis mode.""" + self._constrain = bool(value) + + def down(self, point): + """Set initial cursor window coordinates and pick constrain-axis.""" + self._vdown = arcball_map_to_sphere(point, self._center, self._radius) + self._qdown = self._qpre = self._qnow + if self._constrain and self._axes is not None: + self._axis = arcball_nearest_axis(self._vdown, self._axes) + self._vdown = arcball_constrain_to_axis(self._vdown, self._axis) + else: + self._axis = None + + def drag(self, point): + """Update current cursor window coordinates.""" + vnow = arcball_map_to_sphere(point, self._center, self._radius) + if self._axis is not None: + vnow = arcball_constrain_to_axis(vnow, self._axis) + self._qpre = self._qnow + t = numpy.cross(self._vdown, vnow) + if numpy.dot(t, t) < _EPS: + self._qnow = self._qdown + else: + q = [numpy.dot(self._vdown, vnow), t[0], t[1], t[2]] + self._qnow = quaternion_multiply(q, self._qdown) + + def next(self, acceleration=0.0): + """Continue rotation in direction of last drag.""" + q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False) + self._qpre, self._qnow = self._qnow, q + + def matrix(self): + """Return homogeneous rotation matrix.""" + return quaternion_matrix(self._qnow) + + +def arcball_map_to_sphere(point, center, radius): + """Return unit sphere coordinates from window coordinates.""" + v0 = (point[0] - center[0]) / radius + v1 = (center[1] - 
point[1]) / radius + n = v0*v0 + v1*v1 + if n > 1.0: + # position outside of sphere + n = math.sqrt(n) + return numpy.array([v0/n, v1/n, 0.0]) + else: + return numpy.array([v0, v1, math.sqrt(1.0 - n)]) + + +def arcball_constrain_to_axis(point, axis): + """Return sphere point perpendicular to axis.""" + v = numpy.array(point, dtype=numpy.float64, copy=True) + a = numpy.array(axis, dtype=numpy.float64, copy=True) + v -= a * numpy.dot(a, v) # on plane + n = vector_norm(v) + if n > _EPS: + if v[2] < 0.0: + numpy.negative(v, v) + v /= n + return v + if a[2] == 1.0: + return numpy.array([1.0, 0.0, 0.0]) + return unit_vector([-a[1], a[0], 0.0]) + + +def arcball_nearest_axis(point, axes): + """Return axis, which arc is nearest to point.""" + point = numpy.array(point, dtype=numpy.float64, copy=False) + nearest = None + mx = -1.0 + for axis in axes: + t = numpy.dot(arcball_constrain_to_axis(point, axis), point) + if t > mx: + nearest = axis + mx = t + return nearest + + +# epsilon for testing whether a number is close to zero +_EPS = numpy.finfo(float).eps * 4.0 + +# axis sequences for Euler angles +_NEXT_AXIS = [1, 2, 0, 1] + +# map axes strings to/from tuples of inner axis, parity, repetition, frame +_AXES2TUPLE = { + 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0), + 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0), + 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0), + 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0), + 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1), + 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1), + 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1), + 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)} + +_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items()) + + +def vector_norm(data, axis=None, out=None): + """Return length, i.e. Euclidean norm, of ndarray along axis. + + >>> v = numpy.random.random(3) + >>> n = vector_norm(v) + >>> numpy.allclose(n, numpy.linalg.norm(v)) + True + >>> v = numpy.random.rand(6, 5, 3) + >>> n = vector_norm(v, axis=-1) + >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2))) + True + >>> n = vector_norm(v, axis=1) + >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1))) + True + >>> v = numpy.random.rand(5, 4, 3) + >>> n = numpy.empty((5, 3)) + >>> vector_norm(v, axis=1, out=n) + >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1))) + True + >>> vector_norm([]) + 0.0 + >>> vector_norm([1]) + 1.0 + + """ + data = numpy.array(data, dtype=numpy.float64, copy=True) + if out is None: + if data.ndim == 1: + return math.sqrt(numpy.dot(data, data)) + data *= data + out = numpy.atleast_1d(numpy.sum(data, axis=axis)) + numpy.sqrt(out, out) + return out + else: + data *= data + numpy.sum(data, axis=axis, out=out) + numpy.sqrt(out, out) + + +def unit_vector(data, axis=None, out=None): + """Return ndarray normalized by length, i.e. Euclidean norm, along axis. 
+ + >>> v0 = numpy.random.random(3) + >>> v1 = unit_vector(v0) + >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0)) + True + >>> v0 = numpy.random.rand(5, 4, 3) + >>> v1 = unit_vector(v0, axis=-1) + >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2) + >>> numpy.allclose(v1, v2) + True + >>> v1 = unit_vector(v0, axis=1) + >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1) + >>> numpy.allclose(v1, v2) + True + >>> v1 = numpy.empty((5, 4, 3)) + >>> unit_vector(v0, axis=1, out=v1) + >>> numpy.allclose(v1, v2) + True + >>> list(unit_vector([])) + [] + >>> list(unit_vector([1])) + [1.0] + + """ + if out is None: + data = numpy.array(data, dtype=numpy.float64, copy=True) + if data.ndim == 1: + data /= math.sqrt(numpy.dot(data, data)) + return data + else: + if out is not data: + out[:] = numpy.array(data, copy=False) + data = out + length = numpy.atleast_1d(numpy.sum(data*data, axis)) + numpy.sqrt(length, length) + if axis is not None: + length = numpy.expand_dims(length, axis) + data /= length + if out is None: + return data + + +def random_vector(size): + """Return array of random doubles in the half-open interval [0.0, 1.0). + + >>> v = random_vector(10000) + >>> numpy.all(v >= 0) and numpy.all(v < 1) + True + >>> v0 = random_vector(10) + >>> v1 = random_vector(10) + >>> numpy.any(v0 == v1) + False + + """ + return numpy.random.random(size) + + +def vector_product(v0, v1, axis=0): + """Return vector perpendicular to vectors. + + >>> v = vector_product([2, 0, 0], [0, 3, 0]) + >>> numpy.allclose(v, [0, 0, 6]) + True + >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]] + >>> v1 = [[3], [0], [0]] + >>> v = vector_product(v0, v1) + >>> numpy.allclose(v, [[0, 0, 0, 0], [0, 0, 6, 6], [0, -6, 0, -6]]) + True + >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]] + >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]] + >>> v = vector_product(v0, v1, axis=1) + >>> numpy.allclose(v, [[0, 0, 6], [0, -6, 0], [6, 0, 0], [0, -6, 6]]) + True + + """ + return numpy.cross(v0, v1, axis=axis) + + +def angle_between_vectors(v0, v1, directed=True, axis=0): + """Return angle between vectors. + + If directed is False, the input vectors are interpreted as undirected axes, + i.e. the maximum angle is pi/2. + + >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3]) + >>> numpy.allclose(a, math.pi) + True + >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3], directed=False) + >>> numpy.allclose(a, 0) + True + >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]] + >>> v1 = [[3], [0], [0]] + >>> a = angle_between_vectors(v0, v1) + >>> numpy.allclose(a, [0, 1.5708, 1.5708, 0.95532]) + True + >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]] + >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]] + >>> a = angle_between_vectors(v0, v1, axis=1) + >>> numpy.allclose(a, [1.5708, 1.5708, 1.5708, 0.95532]) + True + + """ + v0 = numpy.array(v0, dtype=numpy.float64, copy=False) + v1 = numpy.array(v1, dtype=numpy.float64, copy=False) + dot = numpy.sum(v0 * v1, axis=axis) + dot /= vector_norm(v0, axis=axis) * vector_norm(v1, axis=axis) + return numpy.arccos(dot if directed else numpy.fabs(dot)) + + +def inverse_matrix(matrix): + """Return inverse of square transformation matrix. + + >>> M0 = random_rotation_matrix() + >>> M1 = inverse_matrix(M0.T) + >>> numpy.allclose(M1, numpy.linalg.inv(M0.T)) + True + >>> for size in range(1, 7): + ... M0 = numpy.random.rand(size, size) + ... M1 = inverse_matrix(M0) + ... 
if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size) + + """ + return numpy.linalg.inv(matrix) + + +def concatenate_matrices(*matrices): + """Return concatenation of series of transformation matrices. + + >>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5 + >>> numpy.allclose(M, concatenate_matrices(M)) + True + >>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T)) + True + + """ + M = numpy.identity(4) + for i in matrices: + M = numpy.dot(M, i) + return M + + +def is_same_transform(matrix0, matrix1): + """Return True if two matrices perform same transformation. + + >>> is_same_transform(numpy.identity(4), numpy.identity(4)) + True + >>> is_same_transform(numpy.identity(4), random_rotation_matrix()) + False + + """ + matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True) + matrix0 /= matrix0[3, 3] + matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True) + matrix1 /= matrix1[3, 3] + return numpy.allclose(matrix0, matrix1) + + +def _import_module(name, package=None, warn=True, prefix='_py_', ignore='_'): + """Try import all public attributes from module into global namespace. + + Existing attributes with name clashes are renamed with prefix. + Attributes starting with underscore are ignored by default. + + Return True on successful import. + + """ + import warnings + from importlib import import_module + try: + if not package: + module = import_module(name) + else: + module = import_module('.' + name, package=package) + except ImportError: + if warn: + #warnings.warn("failed to import module %s" % name) + pass + else: + for attr in dir(module): + if ignore and attr.startswith(ignore): + continue + if prefix: + if attr in globals(): + globals()[prefix + attr] = globals()[attr] + elif warn: + warnings.warn("no Python implementation of " + attr) + globals()[attr] = getattr(module, attr) + return True + + +_import_module('_transformations') + +if __name__ == "__main__": + import doctest + import random # used in doctests + numpy.set_printoptions(suppress=True, precision=5) + doctest.testmod() + diff --git a/imcui/third_party/SOLD2/notebooks/__init__.py b/imcui/third_party/SOLD2/notebooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/SOLD2/setup.py b/imcui/third_party/SOLD2/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..69f72fecdc54cf9b43a7fc55144470e83c5a862d --- /dev/null +++ b/imcui/third_party/SOLD2/setup.py @@ -0,0 +1,4 @@ +from setuptools import setup + + +setup(name='sold2', version="0.0", packages=['sold2']) diff --git a/imcui/third_party/SOLD2/sold2/config/__init__.py b/imcui/third_party/SOLD2/sold2/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/SOLD2/sold2/config/export_line_features.yaml b/imcui/third_party/SOLD2/sold2/config/export_line_features.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f19c7b6d684b7a826d6f2909b8c9f94528fdbf94 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/config/export_line_features.yaml @@ -0,0 +1,80 @@ +### [Model config] +model_cfg: + ### [Model parameters] + model_name: "lcnn_simple" + model_architecture: "simple" + # Backbone related config + backbone: "lcnn" + backbone_cfg: + input_channel: 1 # Use RGB images or grayscale images. 
+ depth: 4 + num_stacks: 2 + num_blocks: 1 + num_classes: 5 + # Junction decoder related config + junction_decoder: "superpoint_decoder" + junc_decoder_cfg: + # Heatmap decoder related config + heatmap_decoder: "pixel_shuffle" + heatmap_decoder_cfg: + # Descriptor decoder related config + descriptor_decoder: "superpoint_descriptor" + descriptor_decoder_cfg: + # Shared configurations + grid_size: 8 + keep_border_valid: True + # Threshold of junction detection + detection_thresh: 0.0153846 # 1/65 + max_num_junctions: 300 + # Threshold of heatmap detection + prob_thresh: 0.5 + + ### [Loss parameters] + weighting_policy: "dynamic" + # [Heatmap loss] + w_heatmap: 0. + w_heatmap_class: 1 + heatmap_loss_func: "cross_entropy" + heatmap_loss_cfg: + policy: "dynamic" + # [Junction loss] + w_junc: 0. + junction_loss_func: "superpoint" + junction_loss_cfg: + policy: "dynamic" + # [Descriptor loss] + w_desc: 0. + descriptor_loss_func: "regular_sampling" + descriptor_loss_cfg: + dist_threshold: 8 + grid_size: 4 + margin: 1 + policy: "dynamic" + +### [Line detector config] +line_detector_cfg: + detect_thresh: 0.5 + num_samples: 64 + sampling_method: "local_max" + inlier_thresh: 0.99 + use_candidate_suppression: True + nms_dist_tolerance: 3. + use_heatmap_refinement: True + heatmap_refine_cfg: + mode: "local" + ratio: 0.2 + valid_thresh: 0.001 + num_blocks: 20 + overlap_ratio: 0.5 + use_junction_refinement: True + junction_refine_cfg: + num_perturbs: 9 + perturb_interval: 0.25 + +### [Line matcher config] +line_matcher_cfg: + cross_check: True + num_samples: 5 + min_dist_pts: 8 + top_k_candidates: 10 + grid_size: 4 \ No newline at end of file diff --git a/imcui/third_party/SOLD2/sold2/config/holicity_dataset.yaml b/imcui/third_party/SOLD2/sold2/config/holicity_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72e9380dbf496dc4b4d6430d58534e0663c85f0e --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/config/holicity_dataset.yaml @@ -0,0 +1,76 @@ +### General dataset parameters +dataset_name: "holicity" +train_splits: ["2018-01"] # 5720 images +add_augmentation_to_all_splits: False +gray_scale: True +# Ground truth source ('official' or path to the exported h5 dataset.) 
+#gt_source_train: "" # Fill with your own export file +#gt_source_test: "" # Fill with your own export file +# Return type: (1) single (to train the detector only) +# or (2) paired_desc (to train the detector + descriptor) +return_type: "single" +random_seed: 0 + +### Descriptor training parameters +# Number of points extracted per line +max_num_samples: 10 +# Max number of training line points extracted in the whole image +max_pts: 1000 +# Min distance between two points on a line (in pixels) +min_dist_pts: 10 +# Small jittering of the sampled points during training +jittering: 0 + +### Data preprocessing configuration +preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + random_scaling: + enable: True + range: [0.7, 1.5] + photometric: + enable: True + primitives: ['random_brightness', 'random_contrast', + 'additive_speckle_noise', 'additive_gaussian_noise', + 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: {prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: True + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 + +### Homography adaptation configuration +homography_adaptation: + num_iter: 100 + valid_border_margin: 3 + min_counts: 30 + homographies: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + allow_artifacts: true + patch_ratio: 0.85 \ No newline at end of file diff --git a/imcui/third_party/SOLD2/sold2/config/merge_dataset.yaml b/imcui/third_party/SOLD2/sold2/config/merge_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f70465b71e507cbc9f258a8bbf45f41e435ee9b0 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/config/merge_dataset.yaml @@ -0,0 +1,54 @@ +dataset_name: "merge" +datasets: ["wireframe", "holicity"] +weights: [0.5, 0.5] +gt_source_train: ["", ""] # Fill with your own [wireframe, holicity] exported ground-truth +gt_source_test: ["", ""] # Fill with your own [wireframe, holicity] exported ground-truth +train_splits: ["", "2018-01"] +add_augmentation_to_all_splits: False +gray_scale: True +# Return type: (1) single (original version) (2) paired +return_type: "paired_desc" +# Number of points extracted per line +max_num_samples: 10 +# Max number of training line points extracted in the whole image +max_pts: 1000 +# Min distance between two points on a line (in pixels) +min_dist_pts: 10 +# Small jittering of the sampled points during training +jittering: 0 +# Random seed +random_seed: 0 +# Data preprocessing configuration.
+preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + photometric: + enable: True + primitives: [ + 'random_brightness', 'random_contrast', 'additive_speckle_noise', + 'additive_gaussian_noise', 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: {prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: True + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 diff --git a/imcui/third_party/SOLD2/sold2/config/project_config.py b/imcui/third_party/SOLD2/sold2/config/project_config.py new file mode 100644 index 0000000000000000000000000000000000000000..42ed00d1c1900e71568d1b06ff4f9d19a295232d --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/config/project_config.py @@ -0,0 +1,41 @@ +""" +Project configurations. +""" +import os + + +class Config(object): + """ Datasets and experiments folders for the whole project. """ + ##################### + ## Dataset setting ## + ##################### + DATASET_ROOT = os.getenv("DATASET_ROOT", "./datasets/") # TODO: path to your datasets folder + if not os.path.exists(DATASET_ROOT): + os.makedirs(DATASET_ROOT) + + # Synthetic shape dataset + synthetic_dataroot = os.path.join(DATASET_ROOT, "synthetic_shapes") + synthetic_cache_path = os.path.join(DATASET_ROOT, "synthetic_shapes") + if not os.path.exists(synthetic_dataroot): + os.makedirs(synthetic_dataroot) + + # Exported predictions dataset + export_dataroot = os.path.join(DATASET_ROOT, "export_datasets") + export_cache_path = os.path.join(DATASET_ROOT, "export_datasets") + if not os.path.exists(export_dataroot): + os.makedirs(export_dataroot) + + # Wireframe dataset + wireframe_dataroot = os.path.join(DATASET_ROOT, "wireframe") + wireframe_cache_path = os.path.join(DATASET_ROOT, "wireframe") + + # Holicity dataset + holicity_dataroot = os.path.join(DATASET_ROOT, "Holicity") + holicity_cache_path = os.path.join(DATASET_ROOT, "Holicity") + + ######################## + ## Experiment Setting ## + ######################## + EXP_PATH = os.getenv("EXP_PATH", "./experiments/") # TODO: path to your experiments folder + if not os.path.exists(EXP_PATH): + os.makedirs(EXP_PATH) diff --git a/imcui/third_party/SOLD2/sold2/config/synthetic_dataset.yaml b/imcui/third_party/SOLD2/sold2/config/synthetic_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d9fa44522b6c09500100dbc56a11bc8a24d56832 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/config/synthetic_dataset.yaml @@ -0,0 +1,48 @@ +### General dataset parameters +dataset_name: "synthetic_shape" +primitives: "all" +add_augmentation_to_all_splits: True +test_augmentation_seed: 200 +# Shape generation configuration +generation: + split_sizes: {'train': 20000, 'val': 2000, 'test': 400} + random_seed: 10 + image_size: [960, 1280] + min_len: 0.0985 + min_label_len: 0.099 + params: + generate_background: + min_kernel_size: 150 + max_kernel_size: 500 + min_rad_ratio: 0.02 + max_rad_ratio: 0.031 + draw_stripes: + transform_params: [0.1, 0.1] + draw_multiple_polygons: + kernel_boundaries: [50, 100] + +### Data preprocessing configuration. 
+preprocessing: + resize: [400, 400] + blur_size: 11 +augmentation: + photometric: + enable: True + primitives: 'all' + params: {} + random_order: True + homographic: + enable: True + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.8 + max_angle: 1.57 + allow_artifacts: true + translation_overflow: 0.05 + valid_border_margin: 0 diff --git a/imcui/third_party/SOLD2/sold2/config/train_detector.yaml b/imcui/third_party/SOLD2/sold2/config/train_detector.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c53c35a6464eb1c37a9ea71c939225f793543aec --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/config/train_detector.yaml @@ -0,0 +1,51 @@ +### [Model parameters] +model_name: "lcnn_simple" +model_architecture: "simple" +# Backbone related config +backbone: "lcnn" +backbone_cfg: + input_channel: 1 # Use RGB images or grayscale images. + depth: 4 + num_stacks: 2 + num_blocks: 1 + num_classes: 5 +# Junction decoder related config +junction_decoder: "superpoint_decoder" +junc_decoder_cfg: +# Heatmap decoder related config +heatmap_decoder: "pixel_shuffle" +heatmap_decoder_cfg: +# Shared configurations +grid_size: 8 +keep_border_valid: True +# Threshold of junction detection +detection_thresh: 0.0153846 # 1/65 +# Threshold of heatmap detection +prob_thresh: 0.5 + +### [Loss parameters] +weighting_policy: "dynamic" +# [Heatmap loss] +w_heatmap: 0. +w_heatmap_class: 1 +heatmap_loss_func: "cross_entropy" +heatmap_loss_cfg: + policy: "dynamic" +# [Junction loss] +w_junc: 0. +junction_loss_func: "superpoint" +junction_loss_cfg: + policy: "dynamic" + +### [Training parameters] +learning_rate: 0.0005 +epochs: 200 +train: + batch_size: 6 + num_workers: 8 +test: + batch_size: 6 + num_workers: 8 +disp_freq: 100 +summary_freq: 200 +max_ckpt: 150 \ No newline at end of file diff --git a/imcui/third_party/SOLD2/sold2/config/train_full_pipeline.yaml b/imcui/third_party/SOLD2/sold2/config/train_full_pipeline.yaml new file mode 100644 index 0000000000000000000000000000000000000000..233d898f47110c14beabbe63ee82044d506cc15a --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/config/train_full_pipeline.yaml @@ -0,0 +1,62 @@ +### [Model parameters] +model_name: "lcnn_simple" +model_architecture: "simple" +# Backbone related config +backbone: "lcnn" +backbone_cfg: + input_channel: 1 # Use RGB images or grayscale images. + depth: 4 + num_stacks: 2 + num_blocks: 1 + num_classes: 5 +# Junction decoder related config +junction_decoder: "superpoint_decoder" +junc_decoder_cfg: +# Heatmap decoder related config +heatmap_decoder: "pixel_shuffle" +heatmap_decoder_cfg: +# Descriptor decoder related config +descriptor_decoder: "superpoint_descriptor" +descriptor_decoder_cfg: +# Shared configurations +grid_size: 8 +keep_border_valid: True +# Threshold of junction detection +detection_thresh: 0.0153846 # 1/65 +# Threshold of heatmap detection +prob_thresh: 0.5 + +### [Loss parameters] +weighting_policy: "dynamic" +# [Heatmap loss] +w_heatmap: 0. +w_heatmap_class: 1 +heatmap_loss_func: "cross_entropy" +heatmap_loss_cfg: + policy: "dynamic" +# [Junction loss] +w_junc: 0. +junction_loss_func: "superpoint" +junction_loss_cfg: + policy: "dynamic" +# [Descriptor loss] +w_desc: 0. 
+descriptor_loss_func: "regular_sampling" +descriptor_loss_cfg: + dist_threshold: 8 + grid_size: 4 + margin: 1 + policy: "dynamic" + +### [Training parameters] +learning_rate: 0.0005 +epochs: 130 +train: + batch_size: 4 + num_workers: 8 +test: + batch_size: 4 + num_workers: 8 +disp_freq: 100 +summary_freq: 200 +max_ckpt: 130 \ No newline at end of file diff --git a/imcui/third_party/SOLD2/sold2/config/wireframe_dataset.yaml b/imcui/third_party/SOLD2/sold2/config/wireframe_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15abd3dbd6462dca21ac331a802b86a8ef050bff --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/config/wireframe_dataset.yaml @@ -0,0 +1,75 @@ +### General dataset parameters +dataset_name: "wireframe" +add_augmentation_to_all_splits: False +gray_scale: True +# Ground truth source ('official' or path to the exported h5 dataset.) +# gt_source_train: "" # Fill with your own export file +# gt_source_test: "" # Fill with your own export file +# Return type: (1) single (to train the detector only) +# or (2) paired_desc (to train the detector + descriptor) +return_type: "single" +random_seed: 0 + +### Descriptor training parameters +# Number of points extracted per line +max_num_samples: 10 +# Max number of training line points extracted in the whole image +max_pts: 1000 +# Min distance between two points on a line (in pixels) +min_dist_pts: 10 +# Small jittering of the sampled points during training +jittering: 0 + +### Data preprocessing configuration +preprocessing: + resize: [512, 512] + blur_size: 11 +augmentation: + random_scaling: + enable: True + range: [0.7, 1.5] + photometric: + enable: True + primitives: ['random_brightness', 'random_contrast', + 'additive_speckle_noise', 'additive_gaussian_noise', + 'additive_shade', 'motion_blur' ] + params: + random_brightness: {brightness: 0.2} + random_contrast: {contrast: [0.3, 1.5]} + additive_gaussian_noise: {stddev_range: [0, 10]} + additive_speckle_noise: {prob_range: [0, 0.0035]} + additive_shade: + transparency_range: [-0.5, 0.5] + kernel_size_range: [100, 150] + motion_blur: {max_kernel_size: 3} + random_order: True + homographic: + enable: True + params: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + patch_ratio: 0.85 + max_angle: 1.57 + allow_artifacts: true + valid_border_margin: 3 + +### Homography adaptation configuration +homography_adaptation: + num_iter: 100 + valid_border_margin: 3 + min_counts: 30 + homographies: + translation: true + rotation: true + scaling: true + perspective: true + scaling_amplitude: 0.2 + perspective_amplitude_x: 0.2 + perspective_amplitude_y: 0.2 + allow_artifacts: true + patch_ratio: 0.85 \ No newline at end of file diff --git a/imcui/third_party/SOLD2/sold2/dataset/__init__.py b/imcui/third_party/SOLD2/sold2/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/SOLD2/sold2/dataset/dataset_util.py b/imcui/third_party/SOLD2/sold2/dataset/dataset_util.py new file mode 100644 index 0000000000000000000000000000000000000000..50439ef3e2958d82719da0f6d10f4a7d98322f9a --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/dataset/dataset_util.py @@ -0,0 +1,60 @@ +""" +The interface of initializing different datasets. 
+""" +from .synthetic_dataset import SyntheticShapes +from .wireframe_dataset import WireframeDataset +from .holicity_dataset import HolicityDataset +from .merge_dataset import MergeDataset + + +def get_dataset(mode="train", dataset_cfg=None): + """ Initialize different dataset based on a configuration. """ + # Check dataset config is given + if dataset_cfg is None: + raise ValueError("[Error] The dataset config is required!") + + # Synthetic dataset + if dataset_cfg["dataset_name"] == "synthetic_shape": + dataset = SyntheticShapes( + mode, dataset_cfg + ) + + # Get the collate_fn + from .synthetic_dataset import synthetic_collate_fn + collate_fn = synthetic_collate_fn + + # Wireframe dataset + elif dataset_cfg["dataset_name"] == "wireframe": + dataset = WireframeDataset( + mode, dataset_cfg + ) + + # Get the collate_fn + from .wireframe_dataset import wireframe_collate_fn + collate_fn = wireframe_collate_fn + + # Holicity dataset + elif dataset_cfg["dataset_name"] == "holicity": + dataset = HolicityDataset( + mode, dataset_cfg + ) + + # Get the collate_fn + from .holicity_dataset import holicity_collate_fn + collate_fn = holicity_collate_fn + + # Dataset merging several datasets in one + elif dataset_cfg["dataset_name"] == "merge": + dataset = MergeDataset( + mode, dataset_cfg + ) + + # Get the collate_fn + from .holicity_dataset import holicity_collate_fn + collate_fn = holicity_collate_fn + + else: + raise ValueError( + "[Error] The dataset '%s' is not supported" % dataset_cfg["dataset_name"]) + + return dataset, collate_fn diff --git a/imcui/third_party/SOLD2/sold2/dataset/holicity_dataset.py b/imcui/third_party/SOLD2/sold2/dataset/holicity_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e4437f37bda366983052de902a41467ca01412bd --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/dataset/holicity_dataset.py @@ -0,0 +1,797 @@ +""" +File to process and load the Holicity dataset. +""" +import os +import math +import copy +import PIL +import numpy as np +import h5py +import cv2 +import pickle +from skimage.io import imread +from skimage import color +import torch +import torch.utils.data.dataloader as torch_loader +from torch.utils.data import Dataset +from torchvision import transforms + +from ..config.project_config import Config as cfg +from .transforms import photometric_transforms as photoaug +from .transforms import homographic_transforms as homoaug +from .transforms.utils import random_scaling +from .synthetic_util import get_line_heatmap +from ..misc.geometry_utils import warp_points, mask_points +from ..misc.train_utils import parse_h5_data + + +def holicity_collate_fn(batch): + """ Customized collate_fn. 
""" + batch_keys = ["image", "junction_map", "valid_mask", "heatmap", + "heatmap_pos", "heatmap_neg", "homography", + "line_points", "line_indices"] + list_keys = ["junctions", "line_map", "line_map_pos", + "line_map_neg", "file_key"] + + outputs = {} + for data_key in batch[0].keys(): + batch_match = sum([_ in data_key for _ in batch_keys]) + list_match = sum([_ in data_key for _ in list_keys]) + # print(batch_match, list_match) + if batch_match > 0 and list_match == 0: + outputs[data_key] = torch_loader.default_collate( + [b[data_key] for b in batch]) + elif batch_match == 0 and list_match > 0: + outputs[data_key] = [b[data_key] for b in batch] + elif batch_match == 0 and list_match == 0: + continue + else: + raise ValueError( + "[Error] A key matches batch keys and list keys simultaneously.") + + return outputs + + +class HolicityDataset(Dataset): + def __init__(self, mode="train", config=None): + super(HolicityDataset, self).__init__() + if not mode in ["train", "test"]: + raise ValueError( + "[Error] Unknown mode for Holicity dataset. Only 'train' and 'test'.") + self.mode = mode + + if config is None: + self.config = self.get_default_config() + else: + self.config = config + # Also get the default config + self.default_config = self.get_default_config() + + # Get cache setting + self.dataset_name = self.get_dataset_name() + self.cache_name = self.get_cache_name() + self.cache_path = cfg.holicity_cache_path + + # Get the ground truth source if it exists + self.gt_source = None + if "gt_source_%s"%(self.mode) in self.config: + self.gt_source = self.config.get("gt_source_%s"%(self.mode)) + self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source) + # Check the full path exists + if not os.path.exists(self.gt_source): + raise ValueError( + "[Error] The specified ground truth source does not exist.") + + # Get the filename dataset + print("[Info] Initializing Holicity dataset...") + self.filename_dataset, self.datapoints = self.construct_dataset() + + # Get dataset length + self.dataset_length = len(self.datapoints) + + # Print some info + print("[Info] Successfully initialized dataset") + print("\t Name: Holicity") + print("\t Mode: %s" %(self.mode)) + print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode), + "None"))) + print("\t Counts: %d" %(self.dataset_length)) + print("----------------------------------------") + + ####################################### + ## Dataset construction related APIs ## + ####################################### + def construct_dataset(self): + """ Construct the dataset (from scratch or from cache). """ + # Check if the filename cache exists + # If cache exists, load from cache + if self.check_dataset_cache(): + print("\t Found filename cache %s at %s"%(self.cache_name, + self.cache_path)) + print("\t Load filename cache...") + filename_dataset, datapoints = self.get_filename_dataset_from_cache() + # If not, initialize dataset from scratch + else: + print("\t Can't find filename cache ...") + print("\t Create filename dataset from scratch...") + filename_dataset, datapoints = self.get_filename_dataset() + print("\t Create filename dataset cache...") + self.create_filename_dataset_cache(filename_dataset, datapoints) + + return filename_dataset, datapoints + + def create_filename_dataset_cache(self, filename_dataset, datapoints): + """ Create filename dataset cache for faster initialization. 
""" + # Check cache path exists + if not os.path.exists(self.cache_path): + os.makedirs(self.cache_path) + + cache_file_path = os.path.join(self.cache_path, self.cache_name) + data = { + "filename_dataset": filename_dataset, + "datapoints": datapoints + } + with open(cache_file_path, "wb") as f: + pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) + + def get_filename_dataset_from_cache(self): + """ Get filename dataset from cache. """ + # Load from pkl cache + cache_file_path = os.path.join(self.cache_path, self.cache_name) + with open(cache_file_path, "rb") as f: + data = pickle.load(f) + + return data["filename_dataset"], data["datapoints"] + + def get_filename_dataset(self): + """ Get the path to the dataset. """ + if self.mode == "train": + # Contains 5720 or 11872 images + dataset_path = [os.path.join(cfg.holicity_dataroot, p) + for p in self.config["train_splits"]] + else: + # Test mode - Contains 520 images + dataset_path = [os.path.join(cfg.holicity_dataroot, "2018-03")] + + # Get paths to all image files + image_paths = [] + for folder in dataset_path: + image_paths += [os.path.join(folder, img) + for img in os.listdir(folder) + if os.path.splitext(img)[-1] == ".jpg"] + image_paths = sorted(image_paths) + + # Verify all the images exist + for idx in range(len(image_paths)): + image_path = image_paths[idx] + if not (os.path.exists(image_path)): + raise ValueError( + "[Error] The image does not exist. %s"%(image_path)) + + # Construct the filename dataset + num_pad = int(math.ceil(math.log10(len(image_paths))) + 1) + filename_dataset = {} + for idx in range(len(image_paths)): + # Get the file key + key = self.get_padded_filename(num_pad, idx) + + filename_dataset[key] = {"image": image_paths[idx]} + + # Get the datapoints + datapoints = list(sorted(filename_dataset.keys())) + + return filename_dataset, datapoints + + def get_dataset_name(self): + """ Get dataset name from dataset config / default config. """ + dataset_name = self.config.get("dataset_name", + self.default_config["dataset_name"]) + dataset_name = dataset_name + "_%s" % self.mode + return dataset_name + + def get_cache_name(self): + """ Get cache name from dataset config / default config. """ + dataset_name = self.config.get("dataset_name", + self.default_config["dataset_name"]) + dataset_name = dataset_name + "_%s" % self.mode + # Compose cache name + cache_name = dataset_name + "_cache.pkl" + return cache_name + + def check_dataset_cache(self): + """ Check if dataset cache exists. """ + cache_file_path = os.path.join(self.cache_path, self.cache_name) + if os.path.exists(cache_file_path): + return True + else: + return False + + @staticmethod + def get_padded_filename(num_pad, idx): + """ Get the padded filename using adaptive padding. """ + file_len = len("%d" % (idx)) + filename = "0" * (num_pad - file_len) + "%d" % (idx) + return filename + + def get_default_config(self): + """ Get the default configuration. """ + return { + "dataset_name": "holicity", + "train_split": "2018-01", + "add_augmentation_to_all_splits": False, + "preprocessing": { + "resize": [512, 512], + "blur_size": 11 + }, + "augmentation":{ + "photometric":{ + "enable": False + }, + "homographic":{ + "enable": False + }, + }, + } + + ############################################ + ## Pytorch and preprocessing related APIs ## + ############################################ + @staticmethod + def get_data_from_path(data_path): + """ Get data from the information from filename dataset. 
""" + output = {} + + # Get image data + image_path = data_path["image"] + image = imread(image_path) + output["image"] = image + + return output + + @staticmethod + def convert_line_map(lcnn_line_map, num_junctions): + """ Convert the line_pos or line_neg + (represented by two junction indexes) to our line map. """ + # Initialize empty line map + line_map = np.zeros([num_junctions, num_junctions]) + + # Iterate through all the lines + for idx in range(lcnn_line_map.shape[0]): + index1 = lcnn_line_map[idx, 0] + index2 = lcnn_line_map[idx, 1] + + line_map[index1, index2] = 1 + line_map[index2, index1] = 1 + + return line_map + + @staticmethod + def junc_to_junc_map(junctions, image_size): + """ Convert junction points to junction maps. """ + junctions = np.round(junctions).astype(np.int) + # Clip the boundary by image size + junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1) + junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1) + + # Create junction map + junc_map = np.zeros([image_size[0], image_size[1]]) + junc_map[junctions[:, 0], junctions[:, 1]] = 1 + + return junc_map[..., None].astype(np.int) + + def parse_transforms(self, names, all_transforms): + """ Parse the transform. """ + trans = all_transforms if (names == 'all') \ + else (names if isinstance(names, list) else [names]) + assert set(trans) <= set(all_transforms) + return trans + + def get_photo_transform(self): + """ Get list of photometric transforms (according to the config). """ + # Get the photometric transform config + photo_config = self.config["augmentation"]["photometric"] + if not photo_config["enable"]: + raise ValueError( + "[Error] Photometric augmentation is not enabled.") + + # Parse photometric transforms + trans_lst = self.parse_transforms(photo_config["primitives"], + photoaug.available_augmentations) + trans_config_lst = [photo_config["params"].get(p, {}) + for p in trans_lst] + + # List of photometric augmentation + photometric_trans_lst = [ + getattr(photoaug, trans)(**conf) \ + for (trans, conf) in zip(trans_lst, trans_config_lst) + ] + + return photometric_trans_lst + + def get_homo_transform(self): + """ Get homographic transforms (according to the config). """ + # Get homographic transforms for image + homo_config = self.config["augmentation"]["homographic"]["params"] + if not self.config["augmentation"]["homographic"]["enable"]: + raise ValueError( + "[Error] Homographic augmentation is not enabled") + + # Parse the homographic transforms + image_shape = self.config["preprocessing"]["resize"] + + # Compute the min_label_len from config + try: + min_label_tmp = self.config["generation"]["min_label_len"] + except: + min_label_tmp = None + + # float label len => fraction + if isinstance(min_label_tmp, float): # Skip if not provided + min_label_len = min_label_tmp * min(image_shape) + # int label len => length in pixel + elif isinstance(min_label_tmp, int): + scale_ratio = (self.config["preprocessing"]["resize"] + / self.config["generation"]["image_size"][0]) + min_label_len = (self.config["generation"]["min_label_len"] + * scale_ratio) + # if none => no restriction + else: + min_label_len = 0 + + # Initialize the transform + homographic_trans = homoaug.homography_transform( + image_shape, homo_config, 0, min_label_len) + + return homographic_trans + + def get_line_points(self, junctions, line_map, H1=None, H2=None, + img_size=None, warp=False): + """ Sample evenly points along each line segments + and keep track of line idx. 
""" + if np.sum(line_map) == 0: + # No segment detected in the image + line_indices = np.zeros(self.config["max_pts"], dtype=int) + line_points = np.zeros((self.config["max_pts"], 2), dtype=float) + return line_points, line_indices + + # Extract all pairs of connected junctions + junc_indices = np.array( + [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i]) + line_segments = np.stack([junctions[junc_indices[:, 0]], + junctions[junc_indices[:, 1]]], axis=1) + # line_segments is (num_lines, 2, 2) + line_lengths = np.linalg.norm( + line_segments[:, 0] - line_segments[:, 1], axis=1) + + # Sample the points separated by at least min_dist_pts along each line + # The number of samples depends on the length of the line + num_samples = np.minimum(line_lengths // self.config["min_dist_pts"], + self.config["max_num_samples"]) + line_points = [] + line_indices = [] + cur_line_idx = 1 + for n in np.arange(2, self.config["max_num_samples"] + 1): + # Consider all lines where we can fit up to n points + cur_line_seg = line_segments[num_samples == n] + line_points_x = np.linspace(cur_line_seg[:, 0, 0], + cur_line_seg[:, 1, 0], + n, axis=-1).flatten() + line_points_y = np.linspace(cur_line_seg[:, 0, 1], + cur_line_seg[:, 1, 1], + n, axis=-1).flatten() + jitter = self.config.get("jittering", 0) + if jitter: + # Add a small random jittering of all points along the line + angles = np.arctan2( + cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0], + cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n) + jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter + line_points_x += jitter_hyp * np.sin(angles) + line_points_y += jitter_hyp * np.cos(angles) + line_points.append(np.stack([line_points_x, line_points_y], axis=-1)) + # Keep track of the line indices for each sampled point + num_cur_lines = len(cur_line_seg) + line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines) + line_indices.append(line_idx.repeat(n)) + cur_line_idx += num_cur_lines + line_points = np.concatenate(line_points, + axis=0)[:self.config["max_pts"]] + line_indices = np.concatenate(line_indices, + axis=0)[:self.config["max_pts"]] + + # Warp the points if need be, and filter unvalid ones + # If the other view is also warped + if warp and H2 is not None: + warp_points2 = warp_points(line_points, H2) + line_points = warp_points(line_points, H1) + mask = mask_points(line_points, img_size) + mask2 = mask_points(warp_points2, img_size) + mask = mask * mask2 + # If the other view is not warped + elif warp and H2 is None: + line_points = warp_points(line_points, H1) + mask = mask_points(line_points, img_size) + else: + if H1 is not None: + raise ValueError("[Error] Wrong combination of homographies.") + # Remove points that would be outside of img_size if warped by H + warped_points = warp_points(line_points, H1) + mask = mask_points(warped_points, img_size) + line_points = line_points[mask] + line_indices = line_indices[mask] + + # Pad the line points to a fixed length + # Index of 0 means padded line + line_indices = np.concatenate([line_indices, np.zeros( + self.config["max_pts"] - len(line_indices))], axis=0) + line_points = np.concatenate( + [line_points, + np.zeros((self.config["max_pts"] - len(line_points), 2), + dtype=float)], axis=0) + + return line_points, line_indices + + def export_preprocessing(self, data, numpy=False): + """ Preprocess the exported data. 
""" + # Fetch the corresponding entries + image = data["image"] + image_size = image.shape[:2] + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + image = photoaug.normalize_image()(image) + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + if not numpy: + return {"image": to_tensor(image)} + else: + return {"image": image} + + def train_preprocessing_exported( + self, data, numpy=False, disable_homoaug=False, desc_training=False, + H1=None, H1_scale=None, H2=None, scale=1., h_crop=None, w_crop=None): + """ Train preprocessing for the exported labels. """ + data = copy.deepcopy(data) + # Fetch the corresponding entries + image = data["image"] + junctions = data["junctions"] + line_map = data["line_map"] + image_size = image.shape[:2] + + # Define the random crop for scaling if necessary + if h_crop is None or w_crop is None: + h_crop, w_crop = 0, 0 + if scale > 1: + H, W = self.config["preprocessing"]["resize"] + H_scale, W_scale = round(H * scale), round(W * scale) + if H_scale > H: + h_crop = np.random.randint(H_scale - H) + if W_scale > W: + w_crop = np.random.randint(W_scale - W) + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # # In HW format + # junctions = (junctions * np.array( + # self.config['preprocessing']['resize'], np.float) + # / np.array(size_old, np.float)) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + image_size = image.shape[:2] + heatmap = get_line_heatmap(junctions_xy, line_map, image_size) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + # Check if we need to apply augmentations + # In training mode => yes. 
+ # In homography adaptation mode (export mode) => No + if self.config["augmentation"]["photometric"]["enable"]: + photo_trans_lst = self.get_photo_transform() + ### Image transform ### + np.random.shuffle(photo_trans_lst) + image_transform = transforms.Compose( + photo_trans_lst + [photoaug.normalize_image()]) + else: + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Perform the random scaling + if scale != 1.: + image, junctions, line_map, valid_mask = random_scaling( + image, junctions, line_map, scale, + h_crop=h_crop, w_crop=w_crop) + else: + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + # Initialize the empty output dict + outputs = {} + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + + # Check homographic augmentation + warp = (self.config["augmentation"]["homographic"]["enable"] + and disable_homoaug == False) + if warp: + homo_trans = self.get_homo_transform() + # Perform homographic transform + if H1 is None: + homo_outputs = homo_trans(image, junctions, line_map, + valid_mask=valid_mask) + else: + homo_outputs = homo_trans( + image, junctions, line_map, homo=H1, scale=H1_scale, + valid_mask=valid_mask) + homography_mat = homo_outputs["homo"] + + # Give the warp of the other view + if H1 is None: + H1 = homo_outputs["homo"] + + # Sample points along each line segments for the descriptor + if desc_training: + line_points, line_indices = self.get_line_points( + junctions, line_map, H1=H1, H2=H2, + img_size=image_size, warp=warp) + + # Record the warped results + if warp: + junctions = homo_outputs["junctions"] # Should be HW format + image = homo_outputs["warped_image"] + line_map = homo_outputs["line_map"] + valid_mask = homo_outputs["valid_mask"] # Same for pos and neg + heatmap = homo_outputs["warped_heatmap"] + + # Optionally put warping information first. + if not numpy: + outputs["homography_mat"] = to_tensor( + homography_mat).to(torch.float32)[0, ...] + else: + outputs["homography_mat"] = homography_mat.astype(np.float32) + + junction_map = self.junc_to_junc_map(junctions, image_size) + + if not numpy: + outputs.update({ + "image": to_tensor(image), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map": to_tensor(line_map).to(torch.int32)[0, ...], + "heatmap": to_tensor(heatmap).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + }) + if desc_training: + outputs.update({ + "line_points": to_tensor( + line_points).to(torch.float32)[0], + "line_indices": torch.tensor(line_indices, + dtype=torch.int) + }) + else: + outputs.update({ + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + "line_map": line_map.astype(np.int32), + "heatmap": heatmap.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32) + }) + if desc_training: + outputs.update({ + "line_points": line_points.astype(np.float32), + "line_indices": line_indices.astype(int) + }) + + return outputs + + def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.): + """ Train preprocessing for paired data for the exported labels + for descriptor training. 
""" + outputs = {} + + # Define the random crop for scaling if necessary + h_crop, w_crop = 0, 0 + if scale > 1: + H, W = self.config["preprocessing"]["resize"] + H_scale, W_scale = round(H * scale), round(W * scale) + if H_scale > H: + h_crop = np.random.randint(H_scale - H) + if W_scale > W: + w_crop = np.random.randint(W_scale - W) + + # Sample ref homography first + homo_config = self.config["augmentation"]["homographic"]["params"] + image_shape = self.config["preprocessing"]["resize"] + ref_H, ref_scale = homoaug.sample_homography(image_shape, + **homo_config) + + # Data for target view (All augmentation) + target_data = self.train_preprocessing_exported( + data, numpy=numpy, desc_training=True, H1=None, H2=ref_H, + scale=scale, h_crop=h_crop, w_crop=w_crop) + + # Data for reference view (No homographical augmentation) + ref_data = self.train_preprocessing_exported( + data, numpy=numpy, desc_training=True, H1=ref_H, + H1_scale=ref_scale, H2=target_data['homography_mat'].numpy(), + scale=scale, h_crop=h_crop, w_crop=w_crop) + + # Spread ref data + for key, val in ref_data.items(): + outputs["ref_" + key] = val + + # Spread target data + for key, val in target_data.items(): + outputs["target_" + key] = val + + return outputs + + def test_preprocessing_exported(self, data, numpy=False): + """ Test preprocessing for the exported labels. """ + data = copy.deepcopy(data) + # Fetch the corresponding entries + image = data["image"] + junctions = data["junctions"] + line_map = data["line_map"] + image_size = image.shape[:2] + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # # In HW format + # junctions = (junctions * np.array( + # self.config['preprocessing']['resize'], np.float) + # / np.array(size_old, np.float)) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + # Still need to normalize image + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + image_size = image.shape[:2] + heatmap = get_line_heatmap(junctions_xy, line_map, image_size) + + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + junction_map = self.junc_to_junc_map(junctions, image_size) + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + if not numpy: + outputs = { + "image": to_tensor(image), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map": to_tensor(line_map).to(torch.int32)[0, ...], + "heatmap": to_tensor(heatmap).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + } + else: + outputs = { + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + "line_map": line_map.astype(np.int32), + "heatmap": heatmap.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32) + } + + return outputs + + def __len__(self): + return self.dataset_length + + def get_data_from_key(self, file_key): + """ Get data from 
file_key. """ + # Check key exists + if not file_key in self.filename_dataset.keys(): + raise ValueError( + "[Error] the specified key is not in the dataset.") + + # Get the data paths + data_path = self.filename_dataset[file_key] + # Read in the image and npz labels + data = self.get_data_from_path(data_path) + + # Perform transform and augmentation + if (self.mode == "train" + or self.config["add_augmentation_to_all_splits"]): + data = self.train_preprocessing(data, numpy=True) + else: + data = self.test_preprocessing(data, numpy=True) + + # Add file key to the output + data["file_key"] = file_key + + return data + + def __getitem__(self, idx): + """Return data + file_key: str, keys used to retrieve data from the filename dataset. + image: torch.float, C*H*W range 0~1, + junctions: torch.float, N*2, + junction_map: torch.int32, 1*H*W range 0 or 1, + line_map: torch.int32, N*N range 0 or 1, + heatmap: torch.int32, 1*H*W range 0 or 1, + valid_mask: torch.int32, 1*H*W range 0 or 1 + """ + # Get the corresponding datapoint and contents from filename dataset + file_key = self.datapoints[idx] + data_path = self.filename_dataset[file_key] + # Read in the image and npz labels + data = self.get_data_from_path(data_path) + + if self.gt_source: + with h5py.File(self.gt_source, "r") as f: + exported_label = parse_h5_data(f[file_key]) + + data["junctions"] = exported_label["junctions"] + data["line_map"] = exported_label["line_map"] + + # Perform transform and augmentation + return_type = self.config.get("return_type", "single") + if self.gt_source is None: + # For export only + data = self.export_preprocessing(data) + elif (self.mode == "train" + or self.config["add_augmentation_to_all_splits"]): + # Perform random scaling first + if self.config["augmentation"]["random_scaling"]["enable"]: + scale_range = self.config["augmentation"]["random_scaling"]["range"] + # Decide the scaling + scale = np.random.uniform(min(scale_range), max(scale_range)) + else: + scale = 1. + if self.mode == "train" and return_type == "paired_desc": + data = self.preprocessing_exported_paired_desc(data, + scale=scale) + else: + data = self.train_preprocessing_exported(data, scale=scale) + else: + if return_type == "paired_desc": + data = self.preprocessing_exported_paired_desc(data) + else: + data = self.test_preprocessing_exported(data) + + # Add file key to the output + data["file_key"] = file_key + + return data + diff --git a/imcui/third_party/SOLD2/sold2/dataset/merge_dataset.py b/imcui/third_party/SOLD2/sold2/dataset/merge_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..178d3822d56639a49a99f68e392330e388fa8fc3 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/dataset/merge_dataset.py @@ -0,0 +1,37 @@ +""" Compose multiple datasets in a single loader. 
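# A minimal sketch of how the __getitem__ above pulls exported ground-truth
# labels: it opens the gt_source HDF5 file, reads the group keyed by file_key,
# and overwrites "junctions" / "line_map" in the raw data dict. read_h5_group
# below is only a stand-in for misc.train_utils.parse_h5_data, whose exact
# behaviour is defined elsewhere in the repo.
import h5py
import numpy as np

def read_h5_group(group):
    # Materialize every dataset in the HDF5 group into a numpy array.
    return {name: np.array(dset) for name, dset in group.items()}

def load_exported_label(gt_source, file_key):
    with h5py.File(gt_source, "r") as f:
        label = read_h5_group(f[file_key])
    return label["junctions"], label["line_map"]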
""" + +import numpy as np +from copy import deepcopy +from torch.utils.data import Dataset + +from .wireframe_dataset import WireframeDataset +from .holicity_dataset import HolicityDataset + + +class MergeDataset(Dataset): + def __init__(self, mode, config=None): + super(MergeDataset, self).__init__() + # Initialize the datasets + self._datasets = [] + spec_config = deepcopy(config) + for i, d in enumerate(config['datasets']): + spec_config['dataset_name'] = d + spec_config['gt_source_train'] = config['gt_source_train'][i] + spec_config['gt_source_test'] = config['gt_source_test'][i] + if d == "wireframe": + self._datasets.append(WireframeDataset(mode, spec_config)) + elif d == "holicity": + spec_config['train_split'] = config['train_splits'][i] + self._datasets.append(HolicityDataset(mode, spec_config)) + else: + raise ValueError("Unknown dataset: " + d) + + self._weights = config['weights'] + + def __getitem__(self, item): + dataset = self._datasets[np.random.choice( + range(len(self._datasets)), p=self._weights)] + return dataset[np.random.randint(len(dataset))] + + def __len__(self): + return np.sum([len(d) for d in self._datasets]) diff --git a/imcui/third_party/SOLD2/sold2/dataset/synthetic_dataset.py b/imcui/third_party/SOLD2/sold2/dataset/synthetic_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cf5f11e5407e65887f4995291156f7cc361843d1 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/dataset/synthetic_dataset.py @@ -0,0 +1,712 @@ +""" +This file implements the synthetic shape dataset object for pytorch +""" +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import + +import os +import math +import h5py +import pickle +import torch +import numpy as np +import cv2 +from tqdm import tqdm +from torchvision import transforms +from torch.utils.data import Dataset +import torch.utils.data.dataloader as torch_loader + +from ..config.project_config import Config as cfg +from . import synthetic_util +from .transforms import photometric_transforms as photoaug +from .transforms import homographic_transforms as homoaug +from ..misc.train_utils import parse_h5_data + + +def synthetic_collate_fn(batch): + """ Customized collate_fn. """ + batch_keys = ["image", "junction_map", "heatmap", + "valid_mask", "homography"] + list_keys = ["junctions", "line_map", "file_key"] + + outputs = {} + for data_key in batch[0].keys(): + batch_match = sum([_ in data_key for _ in batch_keys]) + list_match = sum([_ in data_key for _ in list_keys]) + # print(batch_match, list_match) + if batch_match > 0 and list_match == 0: + outputs[data_key] = torch_loader.default_collate([b[data_key] + for b in batch]) + elif batch_match == 0 and list_match > 0: + outputs[data_key] = [b[data_key] for b in batch] + elif batch_match == 0 and list_match == 0: + continue + else: + raise ValueError( + "[Error] A key matches batch keys and list keys simultaneously.") + + return outputs + + +class SyntheticShapes(Dataset): + """ Dataset of synthetic shapes. 
""" + # Initialize the dataset + def __init__(self, mode="train", config=None): + super(SyntheticShapes, self).__init__() + if not mode in ["train", "val", "test"]: + raise ValueError( + "[Error] Supported dataset modes are 'train', 'val', and 'test'.") + self.mode = mode + + # Get configuration + if config is None: + self.config = self.get_default_config() + else: + self.config = config + + # Set all available primitives + self.available_primitives = [ + 'draw_lines', + 'draw_polygon', + 'draw_multiple_polygons', + 'draw_star', + 'draw_checkerboard_multiseg', + 'draw_stripes_multiseg', + 'draw_cube', + 'gaussian_noise' + ] + + # Some cache setting + self.dataset_name = self.get_dataset_name() + self.cache_name = self.get_cache_name() + self.cache_path = cfg.synthetic_cache_path + + # Check if export dataset exists + print("===============================================") + self.filename_dataset, self.datapoints = self.construct_dataset() + self.print_dataset_info() + + # Initialize h5 file handle + self.dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5") + + # Fix the random seed for torch and numpy in testing mode + if ((self.mode == "val" or self.mode == "test") + and self.config["add_augmentation_to_all_splits"]): + seed = self.config.get("test_augmentation_seed", 200) + np.random.seed(seed) + torch.manual_seed(seed) + # For CuDNN + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + ########################################## + ## Dataset construction related methods ## + ########################################## + def construct_dataset(self): + """ Dataset constructor. """ + # Check if the filename cache exists + # If cache exists, load from cache + if self._check_dataset_cache(): + print("[Info]: Found filename cache at ...") + print("\t Load filename cache...") + filename_dataset, datapoints = self.get_filename_dataset_from_cache() + print("\t Check if all file exists...") + # If all file exists, continue + if self._check_file_existence(filename_dataset): + print("\t All files exist!") + # If not, need to re-export the synthetic dataset + else: + print("\t Some files are missing. Re-export the synthetic shape dataset.") + self.export_synthetic_shapes() + print("\t Initialize filename dataset") + filename_dataset, datapoints = self.get_filename_dataset() + print("\t Create filename dataset cache...") + self.create_filename_dataset_cache(filename_dataset, + datapoints) + + # If not, initialize dataset from scratch + else: + print("[Info]: Can't find filename cache ...") + print("\t First check export dataset exists.") + # If export dataset exists, then just update the filename_dataset + if self._check_export_dataset(): + print("\t Synthetic dataset exists. Initialize the dataset ...") + + # If export dataset does not exist, export from scratch + else: + print("\t Synthetic dataset does not exist. Export the synthetic dataset.") + self.export_synthetic_shapes() + print("\t Initialize filename dataset") + + filename_dataset, datapoints = self.get_filename_dataset() + print("\t Create filename dataset cache...") + self.create_filename_dataset_cache(filename_dataset, datapoints) + + return filename_dataset, datapoints + + def get_cache_name(self): + """ Get cache name from dataset config / default config. 
""" + if self.config["dataset_name"] is None: + dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode + else: + dataset_name = self.config["dataset_name"] + "_%s" % self.mode + # Compose cache name + cache_name = dataset_name + "_cache.pkl" + + return cache_name + + def get_dataset_name(self): + """Get dataset name from dataset config / default config. """ + if self.config["dataset_name"] is None: + dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode + else: + dataset_name = self.config["dataset_name"] + "_%s" % self.mode + + return dataset_name + + def get_filename_dataset_from_cache(self): + """ Get filename dataset from cache. """ + # Load from the pkl cache + cache_file_path = os.path.join(self.cache_path, self.cache_name) + with open(cache_file_path, "rb") as f: + data = pickle.load(f) + + return data["filename_dataset"], data["datapoints"] + + def get_filename_dataset(self): + """ Get filename dataset from scratch. """ + # Path to the exported dataset + dataset_path = os.path.join(cfg.synthetic_dataroot, + self.dataset_name + ".h5") + + filename_dataset = {} + datapoints = [] + # Open the h5 dataset + with h5py.File(dataset_path, "r") as f: + # Iterate through all the primitives + for prim_name in f.keys(): + filenames = sorted(f[prim_name].keys()) + filenames_full = [os.path.join(prim_name, _) + for _ in filenames] + + filename_dataset[prim_name] = filenames_full + datapoints += filenames_full + + return filename_dataset, datapoints + + def create_filename_dataset_cache(self, filename_dataset, datapoints): + """ Create filename dataset cache for faster initialization. """ + # Check cache path exists + if not os.path.exists(self.cache_path): + os.makedirs(self.cache_path) + + cache_file_path = os.path.join(self.cache_path, self.cache_name) + data = { + "filename_dataset": filename_dataset, + "datapoints": datapoints + } + with open(cache_file_path, "wb") as f: + pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) + + def export_synthetic_shapes(self): + """ Export synthetic shapes to disk. """ + # Set the global random state for data generation + synthetic_util.set_random_state(np.random.RandomState( + self.config["generation"]["random_seed"])) + + # Define the export path + dataset_path = os.path.join(cfg.synthetic_dataroot, + self.dataset_name + ".h5") + + # Open h5py file + with h5py.File(dataset_path, "w", libver="latest") as f: + # Iterate through all types of shape + primitives = self.parse_drawing_primitives( + self.config["primitives"]) + split_size = self.config["generation"]["split_sizes"][self.mode] + for prim in primitives: + # Create h5 group + group = f.create_group(prim) + # Export single primitive + self.export_single_primitive(prim, split_size, group) + + f.swmr_mode = True + + def export_single_primitive(self, primitive, split_size, group): + """ Export single primitive. """ + # Check if the primitive is valid or not + if primitive not in self.available_primitives: + raise ValueError( + "[Error]: %s is not a supported primitive" % primitive) + # Set the random seed + synthetic_util.set_random_state(np.random.RandomState( + self.config["generation"]["random_seed"])) + + # Generate shapes + print("\t Generating %s ..." 
% primitive) + for idx in tqdm(range(split_size), ascii=True): + # Generate background image + image = synthetic_util.generate_background( + self.config['generation']['image_size'], + **self.config['generation']['params']['generate_background']) + + # Generate points + drawing_func = getattr(synthetic_util, primitive) + kwarg = self.config["generation"]["params"].get(primitive, {}) + + # Get min_len and min_label_len + min_len = self.config["generation"]["min_len"] + min_label_len = self.config["generation"]["min_label_len"] + + # Some only take min_label_len, and gaussian noises take nothing + if primitive in ["draw_lines", "draw_polygon", + "draw_multiple_polygons", "draw_star"]: + data = drawing_func(image, min_len=min_len, + min_label_len=min_label_len, **kwarg) + elif primitive in ["draw_checkerboard_multiseg", + "draw_stripes_multiseg", "draw_cube"]: + data = drawing_func(image, min_label_len=min_label_len, + **kwarg) + else: + data = drawing_func(image, **kwarg) + + # Convert the data + if data["points"] is not None: + points = np.flip(data["points"], axis=1).astype(np.float) + line_map = data["line_map"].astype(np.int32) + else: + points = np.zeros([0, 2]).astype(np.float) + line_map = np.zeros([0, 0]).astype(np.int32) + + # Post-processing + blur_size = self.config["preprocessing"]["blur_size"] + image = cv2.GaussianBlur(image, (blur_size, blur_size), 0) + + # Resize the image and the point location. + points = (points + * np.array(self.config['preprocessing']['resize'], + np.float) + / np.array(self.config['generation']['image_size'], + np.float)) + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # Generate the line heatmap after post-processing + junctions = np.flip(np.round(points).astype(np.int32), axis=1) + heatmap = (synthetic_util.get_line_heatmap( + junctions, line_map, + size=image.shape) * 255.).astype(np.uint8) + + # Record the data in group + num_pad = math.ceil(math.log10(split_size)) + 1 + file_key_name = self.get_padded_filename(num_pad, idx) + file_group = group.create_group(file_key_name) + + # Store data + file_group.create_dataset("points", data=points, + compression="gzip") + file_group.create_dataset("image", data=image, + compression="gzip") + file_group.create_dataset("line_map", data=line_map, + compression="gzip") + file_group.create_dataset("heatmap", data=heatmap, + compression="gzip") + + def get_default_config(self): + """ Get default configuration of the dataset. """ + # Initialize the default configuration + self.default_config = { + "dataset_name": "synthetic_shape", + "primitives": "all", + "add_augmentation_to_all_splits": False, + # Shape generation configuration + "generation": { + "split_sizes": {'train': 10000, 'val': 400, 'test': 500}, + "random_seed": 10, + "image_size": [960, 1280], + "min_len": 0.09, + "min_label_len": 0.1, + 'params': { + 'generate_background': { + 'min_kernel_size': 150, 'max_kernel_size': 500, + 'min_rad_ratio': 0.02, 'max_rad_ratio': 0.031}, + 'draw_stripes': {'transform_params': (0.1, 0.1)}, + 'draw_multiple_polygons': {'kernel_boundaries': (50, 100)} + }, + }, + # Date preprocessing configuration. 
+ "preprocessing": { + "resize": [240, 320], + "blur_size": 11 + }, + 'augmentation': { + 'photometric': { + 'enable': False, + 'primitives': 'all', + 'params': {}, + 'random_order': True, + }, + 'homographic': { + 'enable': False, + 'params': {}, + 'valid_border_margin': 0, + }, + } + } + + return self.default_config + + def parse_drawing_primitives(self, names): + """ Parse the primitives in config to list of primitive names. """ + if names == "all": + p = self.available_primitives + else: + if isinstance(names, list): + p = names + else: + p = [names] + + assert set(p) <= set(self.available_primitives) + + return p + + @staticmethod + def get_padded_filename(num_pad, idx): + """ Get the padded filename using adaptive padding. """ + file_len = len("%d" % (idx)) + filename = "0" * (num_pad - file_len) + "%d" % (idx) + + return filename + + def print_dataset_info(self): + """ Print dataset info. """ + print("\t ---------Summary------------------") + print("\t Dataset mode: \t\t %s" % self.mode) + print("\t Number of primitive: \t %d" % len(self.filename_dataset.keys())) + print("\t Number of data: \t %d" % len(self.datapoints)) + print("\t ----------------------------------") + + ######################### + ## Pytorch related API ## + ######################### + def get_data_from_datapoint(self, datapoint, reader=None): + """ Get data given the datapoint + (keyname of the h5 dataset e.g. "draw_lines/0000.h5"). """ + # Check if the datapoint is valid + if not datapoint in self.datapoints: + raise ValueError( + "[Error] The specified datapoint is not in available datapoints.") + + # Get data from h5 dataset + if reader is None: + raise ValueError( + "[Error] The reader must be provided in __getitem__.") + else: + data = reader[datapoint] + + return parse_h5_data(data) + + def get_data_from_signature(self, primitive_name, index): + """ Get data given the primitive name and index ("draw_lines", 10) """ + # Check the primitive name and index + self._check_primitive_and_index(primitive_name, index) + + # Get the datapoint from filename dataset + datapoint = self.filename_dataset[primitive_name][index] + + return self.get_data_from_datapoint(datapoint) + + def parse_transforms(self, names, all_transforms): + trans = all_transforms if (names == 'all') \ + else (names if isinstance(names, list) else [names]) + assert set(trans) <= set(all_transforms) + return trans + + def get_photo_transform(self): + """ Get list of photometric transforms (according to the config). """ + # Get the photometric transform config + photo_config = self.config["augmentation"]["photometric"] + if not photo_config["enable"]: + raise ValueError( + "[Error] Photometric augmentation is not enabled.") + + # Parse photometric transforms + trans_lst = self.parse_transforms(photo_config["primitives"], + photoaug.available_augmentations) + trans_config_lst = [photo_config["params"].get(p, {}) + for p in trans_lst] + + # List of photometric augmentation + photometric_trans_lst = [ + getattr(photoaug, trans)(**conf) \ + for (trans, conf) in zip(trans_lst, trans_config_lst) + ] + + return photometric_trans_lst + + def get_homo_transform(self): + """ Get homographic transforms (according to the config). 
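# get_padded_filename above zero-pads the sample index so every key of a split
# has the same width and sorts lexicographically; the width is
# ceil(log10(split_size)) + 1. A quick equivalent of that string arithmetic:
import math

def padded_filename(split_size, idx):
    num_pad = math.ceil(math.log10(split_size)) + 1
    return str(idx).zfill(num_pad)

assert padded_filename(10000, 7) == "00007"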
""" + # Get homographic transforms for image + homo_config = self.config["augmentation"]["homographic"]["params"] + if not self.config["augmentation"]["homographic"]["enable"]: + raise ValueError( + "[Error] Homographic augmentation is not enabled") + + # Parse the homographic transforms + # ToDo: use the shape from the config + image_shape = self.config["preprocessing"]["resize"] + + # Compute the min_label_len from config + try: + min_label_tmp = self.config["generation"]["min_label_len"] + except: + min_label_tmp = None + + # float label len => fraction + if isinstance(min_label_tmp, float): # Skip if not provided + min_label_len = min_label_tmp * min(image_shape) + # int label len => length in pixel + elif isinstance(min_label_tmp, int): + scale_ratio = (self.config["preprocessing"]["resize"] + / self.config["generation"]["image_size"][0]) + min_label_len = (self.config["generation"]["min_label_len"] + * scale_ratio) + # if none => no restriction + else: + min_label_len = 0 + + # Initialize the transform + homographic_trans = homoaug.homography_transform( + image_shape, homo_config, 0, min_label_len) + + return homographic_trans + + @staticmethod + def junc_to_junc_map(junctions, image_size): + """ Convert junction points to junction maps. """ + junctions = np.round(junctions).astype(np.int) + # Clip the boundary by image size + junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1) + junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1) + + # Create junction map + junc_map = np.zeros([image_size[0], image_size[1]]) + junc_map[junctions[:, 0], junctions[:, 1]] = 1 + + return junc_map[..., None].astype(np.int) + + def train_preprocessing(self, data, disable_homoaug=False): + """ Training preprocessing. """ + # Fetch corresponding entries + image = data["image"] + junctions = data["points"] + line_map = data["line_map"] + heatmap = data["heatmap"] + image_size = image.shape[:2] + + # Resize the image before the photometric and homographic transforms + # Check if we need to do the resizing + if not(list(image.shape) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape) + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + junctions = ( + junctions + * np.array(self.config['preprocessing']['resize'], np.float) + / np.array(size_old, np.float)) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), + axis=1) + heatmap = synthetic_util.get_line_heatmap(junctions_xy, line_map, + size=image.shape) + heatmap = (heatmap * 255.).astype(np.uint8) + + # Update image size + image_size = image.shape[:2] + + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + # Check if we need to apply augmentations + # In training mode => yes. 
+ # In homography adaptation mode (export mode) => No + # Check photometric augmentation + if self.config["augmentation"]["photometric"]["enable"]: + photo_trans_lst = self.get_photo_transform() + ### Image transform ### + np.random.shuffle(photo_trans_lst) + image_transform = transforms.Compose( + photo_trans_lst + [photoaug.normalize_image()]) + else: + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Initialize the empty output dict + outputs = {} + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + # Check homographic augmentation + if (self.config["augmentation"]["homographic"]["enable"] + and disable_homoaug == False): + homo_trans = self.get_homo_transform() + # Perform homographic transform + homo_outputs = homo_trans(image, junctions, line_map) + + # Record the warped results + junctions = homo_outputs["junctions"] # Should be HW format + image = homo_outputs["warped_image"] + line_map = homo_outputs["line_map"] + heatmap = homo_outputs["warped_heatmap"] + valid_mask = homo_outputs["valid_mask"] # Same for pos and neg + homography_mat = homo_outputs["homo"] + + # Optionally put warpping information first. + outputs["homography_mat"] = to_tensor( + homography_mat).to(torch.float32)[0, ...] + + junction_map = self.junc_to_junc_map(junctions, image_size) + + outputs.update({ + "image": to_tensor(image), + "junctions": to_tensor(np.ascontiguousarray( + junctions).copy()).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map": to_tensor(line_map).to(torch.int32)[0, ...], + "heatmap": to_tensor(heatmap).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32), + }) + + return outputs + + def test_preprocessing(self, data): + """ Test preprocessing. """ + # Fetch corresponding entries + image = data["image"] + points = data["points"] + line_map = data["line_map"] + heatmap = data["heatmap"] + image_size = image.shape[:2] + + # Resize the image before the photometric and homographic transforms + if not (list(image.shape) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. 
+ size_old = list(image.shape) + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + points = (points + * np.array(self.config['preprocessing']['resize'], + np.float) + / np.array(size_old, np.float)) + + # Generate the line heatmap after post-processing + junctions = np.flip(np.round(points).astype(np.int32), axis=1) + heatmap = synthetic_util.get_line_heatmap(junctions, line_map, + size=image.shape) + heatmap = (heatmap * 255.).astype(np.uint8) + + # Update image size + image_size = image.shape[:2] + + ### image transform ### + image_transform = photoaug.normalize_image() + image = image_transform(image) + + ### joint transform ### + junction_map = self.junc_to_junc_map(points, image_size) + to_tensor = transforms.ToTensor() + image = to_tensor(image) + junctions = to_tensor(points) + junction_map = to_tensor(junction_map).to(torch.int) + line_map = to_tensor(line_map) + heatmap = to_tensor(heatmap) + valid_mask = to_tensor(np.ones(image_size)).to(torch.int32) + + return { + "image": image, + "junctions": junctions, + "junction_map": junction_map, + "line_map": line_map, + "heatmap": heatmap, + "valid_mask": valid_mask + } + + def __getitem__(self, index): + datapoint = self.datapoints[index] + + # Initialize reader and use it + with h5py.File(self.dataset_path, "r", swmr=True) as reader: + data = self.get_data_from_datapoint(datapoint, reader) + + # Apply different transforms in different mod. + if (self.mode == "train" + or self.config["add_augmentation_to_all_splits"]): + return_type = self.config.get("return_type", "single") + data = self.train_preprocessing(data) + else: + data = self.test_preprocessing(data) + + return data + + def __len__(self): + return len(self.datapoints) + + ######################## + ## Some other methods ## + ######################## + def _check_dataset_cache(self): + """ Check if dataset cache exists. """ + cache_file_path = os.path.join(self.cache_path, self.cache_name) + if os.path.exists(cache_file_path): + return True + else: + return False + + def _check_export_dataset(self): + """ Check if exported dataset exists. """ + dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name) + if os.path.exists(dataset_path) and len(os.listdir(dataset_path)) > 0: + return True + else: + return False + + def _check_file_existence(self, filename_dataset): + """ Check if all exported file exists. """ + # Path to the exported dataset + dataset_path = os.path.join(cfg.synthetic_dataroot, + self.dataset_name + ".h5") + + flag = True + # Open the h5 dataset + with h5py.File(dataset_path, "r") as f: + # Iterate through all the primitives + for prim_name in f.keys(): + if (len(filename_dataset[prim_name]) + != len(f[prim_name].keys())): + flag = False + + return flag + + def _check_primitive_and_index(self, primitive, index): + """ Check if the primitve and index are valid. 
""" + # Check primitives + if not primitive in self.available_primitives: + raise ValueError( + "[Error] The primitive is not in available primitives.") + + prim_len = len(self.filename_dataset[primitive]) + # Check the index + if not index < prim_len: + raise ValueError( + "[Error] The index exceeds the total file counts %d for %s" + % (prim_len, primitive)) diff --git a/imcui/third_party/SOLD2/sold2/dataset/synthetic_util.py b/imcui/third_party/SOLD2/sold2/dataset/synthetic_util.py new file mode 100644 index 0000000000000000000000000000000000000000..af009e0ce7e91391e31d7069064ae6121aa84cc0 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/dataset/synthetic_util.py @@ -0,0 +1,1232 @@ +""" +Code adapted from https://github.com/rpautrat/SuperPoint +Module used to generate geometrical synthetic shapes +""" +import math +import cv2 as cv +import numpy as np +import shapely.geometry +from itertools import combinations + +random_state = np.random.RandomState(None) + + +def set_random_state(state): + global random_state + random_state = state + + +def get_random_color(background_color): + """ Output a random scalar in grayscale with a least a small contrast + with the background color. """ + color = random_state.randint(256) + if abs(color - background_color) < 30: # not enough contrast + color = (color + 128) % 256 + return color + + +def get_different_color(previous_colors, min_dist=50, max_count=20): + """ Output a color that contrasts with the previous colors. + Parameters: + previous_colors: np.array of the previous colors + min_dist: the difference between the new color and + the previous colors must be at least min_dist + max_count: maximal number of iterations + """ + color = random_state.randint(256) + count = 0 + while np.any(np.abs(previous_colors - color) < min_dist) and count < max_count: + count += 1 + color = random_state.randint(256) + return color + + +def add_salt_and_pepper(img): + """ Add salt and pepper noise to an image. """ + noise = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8) + cv.randu(noise, 0, 255) + black = noise < 30 + white = noise > 225 + img[white > 0] = 255 + img[black > 0] = 0 + cv.blur(img, (5, 5), img) + return np.empty((0, 2), dtype=np.int) + + +def generate_background(size=(960, 1280), nb_blobs=100, min_rad_ratio=0.01, + max_rad_ratio=0.05, min_kernel_size=50, + max_kernel_size=300): + """ Generate a customized background image. 
+ Parameters: + size: size of the image + nb_blobs: number of circles to draw + min_rad_ratio: the radius of blobs is at least min_rad_size * max(size) + max_rad_ratio: the radius of blobs is at most max_rad_size * max(size) + min_kernel_size: minimal size of the kernel + max_kernel_size: maximal size of the kernel + """ + img = np.zeros(size, dtype=np.uint8) + dim = max(size) + cv.randu(img, 0, 255) + cv.threshold(img, random_state.randint(256), 255, cv.THRESH_BINARY, img) + background_color = int(np.mean(img)) + blobs = np.concatenate( + [random_state.randint(0, size[1], size=(nb_blobs, 1)), + random_state.randint(0, size[0], size=(nb_blobs, 1))], axis=1) + for i in range(nb_blobs): + col = get_random_color(background_color) + cv.circle(img, (blobs[i][0], blobs[i][1]), + np.random.randint(int(dim * min_rad_ratio), + int(dim * max_rad_ratio)), + col, -1) + kernel_size = random_state.randint(min_kernel_size, max_kernel_size) + cv.blur(img, (kernel_size, kernel_size), img) + return img + + +def generate_custom_background(size, background_color, nb_blobs=3000, + kernel_boundaries=(50, 100)): + """ Generate a customized background to fill the shapes. + Parameters: + background_color: average color of the background image + nb_blobs: number of circles to draw + kernel_boundaries: interval of the possible sizes of the kernel + """ + img = np.zeros(size, dtype=np.uint8) + img = img + get_random_color(background_color) + blobs = np.concatenate( + [np.random.randint(0, size[1], size=(nb_blobs, 1)), + np.random.randint(0, size[0], size=(nb_blobs, 1))], axis=1) + for i in range(nb_blobs): + col = get_random_color(background_color) + cv.circle(img, (blobs[i][0], blobs[i][1]), + np.random.randint(20), col, -1) + kernel_size = np.random.randint(kernel_boundaries[0], + kernel_boundaries[1]) + cv.blur(img, (kernel_size, kernel_size), img) + return img + + +def final_blur(img, kernel_size=(5, 5)): + """ Gaussian blur applied to an image. + Parameters: + kernel_size: size of the kernel + """ + cv.GaussianBlur(img, kernel_size, 0, img) + + +def ccw(A, B, C, dim): + """ Check if the points are listed in counter-clockwise order. """ + if dim == 2: # only 2 dimensions + return((C[:, 1] - A[:, 1]) * (B[:, 0] - A[:, 0]) + > (B[:, 1] - A[:, 1]) * (C[:, 0] - A[:, 0])) + else: # dim should be equal to 3 + return((C[:, 1, :] - A[:, 1, :]) + * (B[:, 0, :] - A[:, 0, :]) + > (B[:, 1, :] - A[:, 1, :]) + * (C[:, 0, :] - A[:, 0, :])) + + +def intersect(A, B, C, D, dim): + """ Return true if line segments AB and CD intersect """ + return np.any((ccw(A, C, D, dim) != ccw(B, C, D, dim)) & + (ccw(A, B, C, dim) != ccw(A, B, D, dim))) + + +def keep_points_inside(points, size): + """ Keep only the points whose coordinates are inside the dimensions of + the image of size 'size' """ + mask = (points[:, 0] >= 0) & (points[:, 0] < size[1]) &\ + (points[:, 1] >= 0) & (points[:, 1] < size[0]) + return points[mask, :] + + +def get_unique_junctions(segments, min_label_len): + """ Get unique junction points from line segments. 
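# ccw / intersect above implement the classic orientation test for segment
# crossing, vectorized over an array of existing segments; the drawing
# primitives use it to reject a candidate line that would cross anything
# already drawn. A small usage sketch with toy coordinates (the import path is
# an assumption; adjust it to wherever synthetic_util is importable from):
import numpy as np
from sold2.dataset import synthetic_util

existing = np.array([[0, 0, 10, 10]])               # one segment (0,0)-(10,10)
p1, p2 = np.array([[0, 10]]), np.array([[10, 0]])   # candidate crossing it
crosses = synthetic_util.intersect(existing[:, 0:2], existing[:, 2:4], p1, p2, 2)
print(crosses)  # True: the candidate intersects an existing segment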
""" + # Get all junctions from segments + junctions_all = np.concatenate((segments[:, :2], segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + # Get all unique junction points + else: + junc_points = np.unique(junctions_all, axis=0) + # Generate line map from points and segments + line_map = get_line_map(junc_points, segments) + + return junc_points, line_map + + +def get_line_map(points: np.ndarray, segments: np.ndarray) -> np.ndarray: + """ Get line map given the points and segment sets. """ + # create empty line map + num_point = points.shape[0] + line_map = np.zeros([num_point, num_point]) + + # Iterate through every segment + for idx in range(segments.shape[0]): + # Get the junctions from a single segement + seg = segments[idx, :] + junction1 = seg[:2] + junction2 = seg[2:] + + # Get index + idx_junction1 = np.where((points == junction1).sum(axis=1) == 2)[0] + idx_junction2 = np.where((points == junction2).sum(axis=1) == 2)[0] + + # label the corresponding entries + line_map[idx_junction1, idx_junction2] = 1 + line_map[idx_junction2, idx_junction1] = 1 + + return line_map + + +def get_line_heatmap(junctions, line_map, size=[480, 640], thickness=1): + """ Get line heat map from junctions and line map. """ + # Make sure that the thickness is 1 + if not isinstance(thickness, int): + thickness = int(thickness) + + # If the junction points are not int => round them and convert to int + if not junctions.dtype == np.int: + junctions = (np.round(junctions)).astype(np.int) + + # Initialize empty map + heat_map = np.zeros(size) + + if junctions.shape[0] > 0: # If empty, just return zero map + # Iterate through all the junctions + for idx in range(junctions.shape[0]): + # if no connectivity, just skip it + if line_map[idx, :].sum() == 0: + continue + # Plot the line segment + else: + # Iterate through all the connected junctions + for idx2 in np.where(line_map[idx, :] == 1)[0]: + point1 = junctions[idx, :] + point2 = junctions[idx2, :] + + # Draw line + cv.line(heat_map, tuple(point1), tuple(point2), 1., thickness) + + return heat_map + + +def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32): + """ Draw random lines and output the positions of the pair of junctions + and line associativities. 
+ Parameters: + nb_lines: maximal number of lines + """ + # Set line number and points placeholder + num_lines = random_state.randint(1, nb_lines) + segments = np.empty((0, 4), dtype=np.int) + points = np.empty((0, 2), dtype=np.int) + background_color = int(np.mean(img)) + min_dim = min(img.shape) + + # Convert length constrain to pixel if given float number + if isinstance(min_len, float) and min_len <= 1.: + min_len = int(min_dim * min_len) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + + # Generate lines one by one + for i in range(num_lines): + x1 = random_state.randint(img.shape[1]) + y1 = random_state.randint(img.shape[0]) + p1 = np.array([[x1, y1]]) + x2 = random_state.randint(img.shape[1]) + y2 = random_state.randint(img.shape[0]) + p2 = np.array([[x2, y2]]) + + # Check the length of the line + line_length = np.sqrt(np.sum((p1 - p2) ** 2)) + if line_length < min_len: + continue + + # Check that there is no overlap + if intersect(segments[:, 0:2], segments[:, 2:4], p1, p2, 2): + continue + + col = get_random_color(background_color) + thickness = random_state.randint(min_dim * 0.01, min_dim * 0.02) + cv.line(img, (x1, y1), (x2, y2), col, thickness) + + # Only record the segments longer than min_label_len + seg_len = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) + if seg_len >= min_label_len: + segments = np.concatenate([segments, + np.array([[x1, y1, x2, y2]])], axis=0) + points = np.concatenate([points, + np.array([[x1, y1], [x2, y2]])], axis=0) + + # If no line is drawn, recursively call the function + if points.shape[0] == 0: + return draw_lines(img, nb_lines, min_len, min_label_len) + + # Get the line associativity map + line_map = get_line_map(points, segments) + + return { + "points": points, + "line_map": line_map + } + + +def check_segment_len(segments, min_len=32): + """ Check if one of the segments is too short (True means too short). """ + point1_vec = segments[:, :2] + point2_vec = segments[:, 2:] + diff = point1_vec - point2_vec + + dist = np.sqrt(np.sum(diff ** 2, axis=1)) + if np.any(dist < min_len): + return True + else: + return False + + +def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64): + """ Draw a polygon with a random number of corners and return the position + of the junctions + line map. 
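# draw_lines above (and the other primitives below) accept min_len and
# min_label_len either as a pixel count (int) or as a fraction of the smaller
# image dimension (float <= 1). A tiny helper expressing that convention; the
# name to_pixels is hypothetical, the conversion mirrors the in-function code:
def to_pixels(value, image_shape):
    min_dim = min(image_shape[:2])
    if isinstance(value, float) and value <= 1.:
        return int(min_dim * value)
    return value

assert to_pixels(0.1, (240, 320)) == 24
assert to_pixels(32, (240, 320)) == 32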
+ Parameters: + max_sides: maximal number of sides + 1 + """ + num_corners = random_state.randint(3, max_sides) + min_dim = min(img.shape[0], img.shape[1]) + rad = max(random_state.rand() * min_dim / 2, min_dim / 10) + # Center of a circle + x = random_state.randint(rad, img.shape[1] - rad) + y = random_state.randint(rad, img.shape[0] - rad) + + # Convert length constrain to pixel if given float number + if isinstance(min_len, float) and min_len <= 1.: + min_len = int(min_dim * min_len) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + + # Sample num_corners points inside the circle + slices = np.linspace(0, 2 * math.pi, num_corners + 1) + angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i]) + for i in range(num_corners)] + points = np.array( + [[int(x + max(random_state.rand(), 0.4) * rad * math.cos(a)), + int(y + max(random_state.rand(), 0.4) * rad * math.sin(a))] + for a in angles]) + + # Filter the points that are too close or that have an angle too flat + norms = [np.linalg.norm(points[(i-1) % num_corners, :] + - points[i, :]) for i in range(num_corners)] + mask = np.array(norms) > 0.01 + points = points[mask, :] + num_corners = points.shape[0] + corner_angles = [angle_between_vectors(points[(i-1) % num_corners, :] - + points[i, :], + points[(i+1) % num_corners, :] - + points[i, :]) + for i in range(num_corners)] + mask = np.array(corner_angles) < (2 * math.pi / 3) + points = points[mask, :] + num_corners = points.shape[0] + + # Get junction pairs from points + segments = np.zeros([0, 4]) + # Used to record all the segments no matter we are going to label it or not. + segments_raw = np.zeros([0, 4]) + for idx in range(num_corners): + if idx == (num_corners - 1): + p1 = points[idx] + p2 = points[0] + else: + p1 = points[idx] + p2 = points[idx + 1] + + segment = np.concatenate((p1, p2), axis=0) + # Only record the segments longer than min_label_len + seg_len = np.sqrt(np.sum((p1 - p2) ** 2)) + if seg_len >= min_label_len: + segments = np.concatenate((segments, segment[None, ...]), axis=0) + segments_raw = np.concatenate((segments_raw, segment[None, ...]), + axis=0) + + # If not enough corner, just regenerate one + if (num_corners < 3) or check_segment_len(segments_raw, min_len): + return draw_polygon(img, max_sides, min_len, min_label_len) + + # Get junctions from segments + junctions_all = np.concatenate((segments[:, :2], segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + else: + junc_points = np.unique(junctions_all, axis=0) + + # Get the line map + line_map = get_line_map(junc_points, segments) + + corners = points.reshape((-1, 1, 2)) + col = get_random_color(int(np.mean(img))) + cv.fillPoly(img, [corners], col) + + return { + "points": junc_points, + "line_map": line_map + } + + +def overlap(center, rad, centers, rads): + """ Check that the circle with (center, rad) + doesn't overlap with the other circles. """ + flag = False + for i in range(len(rads)): + if np.linalg.norm(center - centers[i]) < rad + rads[i]: + flag = True + break + return flag + + +def angle_between_vectors(v1, v2): + """ Compute the angle (in rad) between the two vectors v1 and v2. 
""" + v1_u = v1 / np.linalg.norm(v1) + v2_u = v2 / np.linalg.norm(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + +def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32, + min_label_len=64, safe_margin=5, **extra): + """ Draw multiple polygons with a random number of corners + and return the junction points + line map. + Parameters: + max_sides: maximal number of sides + 1 + nb_polygons: maximal number of polygons + """ + segments = np.empty((0, 4), dtype=np.int) + label_segments = np.empty((0, 4), dtype=np.int) + centers = [] + rads = [] + points = np.empty((0, 2), dtype=np.int) + background_color = int(np.mean(img)) + + min_dim = min(img.shape[0], img.shape[1]) + # Convert length constrain to pixel if given float number + if isinstance(min_len, float) and min_len <= 1.: + min_len = int(min_dim * min_len) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + if isinstance(safe_margin, float) and safe_margin <= 1.: + safe_margin = int(min_dim * safe_margin) + + # Sequentially generate polygons + for i in range(nb_polygons): + num_corners = random_state.randint(3, max_sides) + min_dim = min(img.shape[0], img.shape[1]) + + # Also add the real radius + rad = max(random_state.rand() * min_dim / 2, min_dim / 9) + rad_real = rad - safe_margin + + # Center of a circle + x = random_state.randint(rad, img.shape[1] - rad) + y = random_state.randint(rad, img.shape[0] - rad) + + # Sample num_corners points inside the circle + slices = np.linspace(0, 2 * math.pi, num_corners + 1) + angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i]) + for i in range(num_corners)] + + # Sample outer points and inner points + new_points = [] + new_points_real = [] + for a in angles: + x_offset = max(random_state.rand(), 0.4) + y_offset = max(random_state.rand(), 0.4) + new_points.append([int(x + x_offset * rad * math.cos(a)), + int(y + y_offset * rad * math.sin(a))]) + new_points_real.append( + [int(x + x_offset * rad_real * math.cos(a)), + int(y + y_offset * rad_real * math.sin(a))]) + new_points = np.array(new_points) + new_points_real = np.array(new_points_real) + + # Filter the points that are too close or that have an angle too flat + norms = [np.linalg.norm(new_points[(i-1) % num_corners, :] + - new_points[i, :]) + for i in range(num_corners)] + mask = np.array(norms) > 0.01 + new_points = new_points[mask, :] + new_points_real = new_points_real[mask, :] + + num_corners = new_points.shape[0] + corner_angles = [ + angle_between_vectors(new_points[(i-1) % num_corners, :] - + new_points[i, :], + new_points[(i+1) % num_corners, :] - + new_points[i, :]) + for i in range(num_corners)] + mask = np.array(corner_angles) < (2 * math.pi / 3) + new_points = new_points[mask, :] + new_points_real = new_points_real[mask, :] + num_corners = new_points.shape[0] + + # Not enough corners + if num_corners < 3: + continue + + # Segments for checking overlap (outer circle) + new_segments = np.zeros((1, 4, num_corners)) + new_segments[:, 0, :] = [new_points[i][0] for i in range(num_corners)] + new_segments[:, 1, :] = [new_points[i][1] for i in range(num_corners)] + new_segments[:, 2, :] = [new_points[(i+1) % num_corners][0] + for i in range(num_corners)] + new_segments[:, 3, :] = [new_points[(i+1) % num_corners][1] + for i in range(num_corners)] + + # Segments to record (inner circle) + new_segments_real = np.zeros((1, 4, num_corners)) + new_segments_real[:, 0, :] = [new_points_real[i][0] + for i in range(num_corners)] + 
new_segments_real[:, 1, :] = [new_points_real[i][1] + for i in range(num_corners)] + new_segments_real[:, 2, :] = [ + new_points_real[(i + 1) % num_corners][0] + for i in range(num_corners)] + new_segments_real[:, 3, :] = [ + new_points_real[(i + 1) % num_corners][1] + for i in range(num_corners)] + + # Check that the polygon will not overlap with pre-existing shapes + if intersect(segments[:, 0:2, None], segments[:, 2:4, None], + new_segments[:, 0:2, :], new_segments[:, 2:4, :], + 3) or overlap(np.array([x, y]), rad, centers, rads): + continue + + # Check that the the edges of the polygon is not too short + if check_segment_len(new_segments_real, min_len): + continue + + # If the polygon is valid, append it to the polygon set + centers.append(np.array([x, y])) + rads.append(rad) + new_segments = np.reshape(np.swapaxes(new_segments, 0, 2), (-1, 4)) + segments = np.concatenate([segments, new_segments], axis=0) + + # Only record the segments longer than min_label_len + new_segments_real = np.reshape(np.swapaxes(new_segments_real, 0, 2), + (-1, 4)) + points1 = new_segments_real[:, :2] + points2 = new_segments_real[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + new_label_segment = new_segments_real[seg_len >= min_label_len, :] + label_segments = np.concatenate([label_segments, new_label_segment], + axis=0) + + # Color the polygon with a custom background + corners = new_points_real.reshape((-1, 1, 2)) + mask = np.zeros(img.shape, np.uint8) + custom_background = generate_custom_background( + img.shape, background_color, **extra) + + cv.fillPoly(mask, [corners], 255) + locs = np.where(mask != 0) + img[locs[0], locs[1]] = custom_background[locs[0], locs[1]] + points = np.concatenate([points, new_points], axis=0) + + # Get all junctions from label segments + junctions_all = np.concatenate( + (label_segments[:, :2], label_segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + else: + junc_points = np.unique(junctions_all, axis=0) + + # Generate line map from points and segments + line_map = get_line_map(junc_points, label_segments) + + return { + "points": junc_points, + "line_map": line_map + } + + +def draw_ellipses(img, nb_ellipses=20): + """ Draw several ellipses. + Parameters: + nb_ellipses: maximal number of ellipses + """ + centers = np.empty((0, 2), dtype=np.int) + rads = np.empty((0, 1), dtype=np.int) + min_dim = min(img.shape[0], img.shape[1]) / 4 + background_color = int(np.mean(img)) + for i in range(nb_ellipses): + ax = int(max(random_state.rand() * min_dim, min_dim / 5)) + ay = int(max(random_state.rand() * min_dim, min_dim / 5)) + max_rad = max(ax, ay) + x = random_state.randint(max_rad, img.shape[1] - max_rad) # center + y = random_state.randint(max_rad, img.shape[0] - max_rad) + new_center = np.array([[x, y]]) + + # Check that the ellipsis will not overlap with pre-existing shapes + diff = centers - new_center + if np.any(max_rad > (np.sqrt(np.sum(diff * diff, axis=1)) - rads)): + continue + centers = np.concatenate([centers, new_center], axis=0) + rads = np.concatenate([rads, np.array([[max_rad]])], axis=0) + + col = get_random_color(background_color) + angle = random_state.rand() * 90 + cv.ellipse(img, (x, y), (ax, ay), angle, 0, 360, col, -1) + return np.empty((0, 2), dtype=np.int) + + +def draw_star(img, nb_branches=6, min_len=32, min_label_len=64): + """ Draw a star and return the junction points + line map. 
+ Parameters: + nb_branches: number of branches of the star + """ + num_branches = random_state.randint(3, nb_branches) + min_dim = min(img.shape[0], img.shape[1]) + # Convert length constrain to pixel if given float number + if isinstance(min_len, float) and min_len <= 1.: + min_len = int(min_dim * min_len) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + + thickness = random_state.randint(min_dim * 0.01, min_dim * 0.025) + rad = max(random_state.rand() * min_dim / 2, min_dim / 5) + x = random_state.randint(rad, img.shape[1] - rad) + y = random_state.randint(rad, img.shape[0] - rad) + # Sample num_branches points inside the circle + slices = np.linspace(0, 2 * math.pi, num_branches + 1) + angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i]) + for i in range(num_branches)] + points = np.array( + [[int(x + max(random_state.rand(), 0.3) * rad * math.cos(a)), + int(y + max(random_state.rand(), 0.3) * rad * math.sin(a))] + for a in angles]) + points = np.concatenate(([[x, y]], points), axis=0) + + # Generate segments and check the length + segments = np.array([[x, y, _[0], _[1]] for _ in points[1:, :]]) + if check_segment_len(segments, min_len): + return draw_star(img, nb_branches, min_len, min_label_len) + + # Only record the segments longer than min_label_len + points1 = segments[:, :2] + points2 = segments[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + label_segments = segments[seg_len >= min_label_len, :] + + # Get all junctions from label segments + junctions_all = np.concatenate( + (label_segments[:, :2], label_segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + # Get all unique junction points + else: + junc_points = np.unique(junctions_all, axis=0) + # Generate line map from points and segments + line_map = get_line_map(junc_points, label_segments) + + background_color = int(np.mean(img)) + for i in range(1, num_branches + 1): + col = get_random_color(background_color) + cv.line(img, (points[0][0], points[0][1]), + (points[i][0], points[i][1]), + col, thickness) + return { + "points": junc_points, + "line_map": line_map + } + + +def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7, + transform_params=(0.05, 0.15), + min_label_len=64, seed=None): + """ Draw a checkerboard and output the junctions + line segments + Parameters: + max_rows: maximal number of rows + 1 + max_cols: maximal number of cols + 1 + transform_params: set the range of the parameters of the transformations + """ + if seed is None: + global random_state + else: + random_state = np.random.RandomState(seed) + + background_color = int(np.mean(img)) + + min_dim = min(img.shape) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + # Create the grid + rows = random_state.randint(3, max_rows) # number of rows + cols = random_state.randint(3, max_cols) # number of cols + s = min((img.shape[1] - 1) // cols, (img.shape[0] - 1) // rows) + x_coord = np.tile(range(cols + 1), + rows + 1).reshape(((rows + 1) * (cols + 1), 1)) + y_coord = np.repeat(range(rows + 1), + cols + 1).reshape(((rows + 1) * (cols + 1), 1)) + # points are the grid coordinates + points = s * np.concatenate([x_coord, y_coord], axis=1) + + # Warp the grid using an affine transformation and an homography + alpha_affine = np.max(img.shape) * ( + transform_params[0] + random_state.rand() * transform_params[1]) + center_square = np.float32(img.shape) 
// 2 + min_dim = min(img.shape) + square_size = min_dim // 3 + pts1 = np.float32([center_square + square_size, + [center_square[0] + square_size, + center_square[1] - square_size], + center_square - square_size, + [center_square[0] - square_size, + center_square[1] + square_size]]) + pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, + size=pts1.shape).astype(np.float32) + affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3]) + pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2, + size=pts1.shape).astype(np.float32) + perspective_transform = cv.getPerspectiveTransform(pts1, pts2) + + # Apply the affine transformation + points = np.transpose(np.concatenate( + (points, np.ones(((rows + 1) * (cols + 1), 1))), axis=1)) + warped_points = np.transpose(np.dot(affine_transform, points)) + + # Apply the homography + warped_col0 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[0, :2]), axis=1), + perspective_transform[0, 2]) + warped_col1 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[1, :2]), axis=1), + perspective_transform[1, 2]) + warped_col2 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[2, :2]), axis=1), + perspective_transform[2, 2]) + warped_col0 = np.divide(warped_col0, warped_col2) + warped_col1 = np.divide(warped_col1, warped_col2) + warped_points = np.concatenate( + [warped_col0[:, None], warped_col1[:, None]], axis=1) + warped_points_float = warped_points.copy() + warped_points = warped_points.astype(int) + + # Fill the rectangles + colors = np.zeros((rows * cols,), np.int32) + for i in range(rows): + for j in range(cols): + # Get a color that contrast with the neighboring cells + if i == 0 and j == 0: + col = get_random_color(background_color) + else: + neighboring_colors = [] + if i != 0: + neighboring_colors.append(colors[(i - 1) * cols + j]) + if j != 0: + neighboring_colors.append(colors[i * cols + j - 1]) + col = get_different_color(np.array(neighboring_colors)) + colors[i * cols + j] = col + + # Fill the cell + cv.fillConvexPoly(img, np.array( + [(warped_points[i * (cols + 1) + j, 0], + warped_points[i * (cols + 1) + j, 1]), + (warped_points[i * (cols + 1) + j + 1, 0], + warped_points[i * (cols + 1) + j + 1, 1]), + (warped_points[(i + 1) * (cols + 1) + j + 1, 0], + warped_points[(i + 1) * (cols + 1) + j + 1, 1]), + (warped_points[(i + 1) * (cols + 1) + j, 0], + warped_points[(i + 1) * (cols + 1) + j, 1])]), col) + + label_segments = np.empty([0, 4], dtype=np.int) + # Iterate through rows + for row_idx in range(rows + 1): + # Include all the combination of the junctions + # Iterate through all the combination of junction index in that row + multi_seg_lst = [ + np.array([warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1]])[None, ...] + for (id1, id2) in combinations(range( + row_idx * (cols + 1), (row_idx + 1) * (cols + 1), 1), 2)] + multi_seg = np.concatenate(multi_seg_lst, axis=0) + label_segments = np.concatenate((label_segments, multi_seg), axis=0) + + # Iterate through columns + for col_idx in range(cols + 1): # for 5 columns, we will have 5 + 1 edges + # Include all the combination of the junctions + # Iterate throuhg all the combination of junction index in that column + multi_seg_lst = [ + np.array([warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1]])[None, ...] 
+ for (id1, id2) in combinations(range( + col_idx, col_idx + ((rows + 1) * (cols + 1)), cols + 1), 2)] + multi_seg = np.concatenate(multi_seg_lst, axis=0) + label_segments = np.concatenate((label_segments, multi_seg), axis=0) + + label_segments_filtered = np.zeros([0, 4]) + # Define image boundary polygon (in x y manner) + image_poly = shapely.geometry.Polygon( + [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1], + [0, img.shape[0] - 1]]) + for idx in range(label_segments.shape[0]): + # Get the line segment + seg_raw = label_segments[idx, :] + seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]]) + + # The line segment is just inside the image. + if seg.intersection(image_poly) == seg: + label_segments_filtered = np.concatenate( + (label_segments_filtered, seg_raw[None, ...]), axis=0) + + # Intersect with the image. + elif seg.intersects(image_poly): + # Check intersection + try: + p = np.array(seg.intersection( + image_poly).coords).reshape([-1, 4]) + # If intersect with eact one point + except: + continue + segment = p + label_segments_filtered = np.concatenate( + (label_segments_filtered, segment), axis=0) + + else: + continue + + label_segments = np.round(label_segments_filtered).astype(np.int) + + # Only record the segments longer than min_label_len + points1 = label_segments[:, :2] + points2 = label_segments[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + label_segments = label_segments[seg_len >= min_label_len, :] + + # Get all junctions from label segments + junc_points, line_map = get_unique_junctions(label_segments, + min_label_len) + + # Draw lines on the boundaries of the board at random + nb_rows = random_state.randint(2, rows + 2) + nb_cols = random_state.randint(2, cols + 2) + thickness = random_state.randint(min_dim * 0.01, min_dim * 0.015) + for _ in range(nb_rows): + row_idx = random_state.randint(rows + 1) + col_idx1 = random_state.randint(cols + 1) + col_idx2 = random_state.randint(cols + 1) + col = get_random_color(background_color) + cv.line(img, (warped_points[row_idx * (cols + 1) + col_idx1, 0], + warped_points[row_idx * (cols + 1) + col_idx1, 1]), + (warped_points[row_idx * (cols + 1) + col_idx2, 0], + warped_points[row_idx * (cols + 1) + col_idx2, 1]), + col, thickness) + for _ in range(nb_cols): + col_idx = random_state.randint(cols + 1) + row_idx1 = random_state.randint(rows + 1) + row_idx2 = random_state.randint(rows + 1) + col = get_random_color(background_color) + cv.line(img, (warped_points[row_idx1 * (cols + 1) + col_idx, 0], + warped_points[row_idx1 * (cols + 1) + col_idx, 1]), + (warped_points[row_idx2 * (cols + 1) + col_idx, 0], + warped_points[row_idx2 * (cols + 1) + col_idx, 1]), + col, thickness) + + # Keep only the points inside the image + points = keep_points_inside(warped_points, img.shape[:2]) + return { + "points": junc_points, + "line_map": line_map + } + + +def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64, + transform_params=(0.05, 0.15), seed=None): + """ Draw stripes in a distorted rectangle + and output the junctions points + line map. 
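+    Length thresholds (min_len, min_label_len) given as floats <= 1 are
+    interpreted as fractions of the smaller image dimension and converted
+    to pixels.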
+ Parameters: + max_nb_cols: maximal number of stripes to be drawn + min_width_ratio: the minimal width of a stripe is + min_width_ratio * smallest dimension of the image + transform_params: set the range of the parameters of the transformations + """ + # Set the optional random seed (most for debugging) + if seed is None: + global random_state + else: + random_state = np.random.RandomState(seed) + + background_color = int(np.mean(img)) + # Create the grid + board_size = (int(img.shape[0] * (1 + random_state.rand())), + int(img.shape[1] * (1 + random_state.rand()))) + + # Number of cols + col = random_state.randint(5, max_nb_cols) + cols = np.concatenate([board_size[1] * random_state.rand(col - 1), + np.array([0, board_size[1] - 1])], axis=0) + cols = np.unique(cols.astype(int)) + + # Remove the indices that are too close + min_dim = min(img.shape) + + # Convert length constrain to pixel if given float number + if isinstance(min_len, float) and min_len <= 1.: + min_len = int(min_dim * min_len) + if isinstance(min_label_len, float) and min_label_len <= 1.: + min_label_len = int(min_dim * min_label_len) + + cols = cols[(np.concatenate([cols[1:], + np.array([board_size[1] + min_len])], + axis=0) - cols) >= min_len] + # Update the number of cols + col = cols.shape[0] - 1 + cols = np.reshape(cols, (col + 1, 1)) + cols1 = np.concatenate([cols, np.zeros((col + 1, 1), np.int32)], axis=1) + cols2 = np.concatenate( + [cols, (board_size[0] - 1) * np.ones((col + 1, 1), np.int32)], axis=1) + points = np.concatenate([cols1, cols2], axis=0) + + # Warp the grid using an affine transformation and a homography + alpha_affine = np.max(img.shape) * ( + transform_params[0] + random_state.rand() * transform_params[1]) + center_square = np.float32(img.shape) // 2 + square_size = min(img.shape) // 3 + pts1 = np.float32([center_square + square_size, + [center_square[0]+square_size, + center_square[1]-square_size], + center_square - square_size, + [center_square[0]-square_size, + center_square[1]+square_size]]) + pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, + size=pts1.shape).astype(np.float32) + affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3]) + pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2, + size=pts1.shape).astype(np.float32) + perspective_transform = cv.getPerspectiveTransform(pts1, pts2) + + # Apply the affine transformation + points = np.transpose(np.concatenate((points, + np.ones((2 * (col + 1), 1))), + axis=1)) + warped_points = np.transpose(np.dot(affine_transform, points)) + + # Apply the homography + warped_col0 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[0, :2]), axis=1), + perspective_transform[0, 2]) + warped_col1 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[1, :2]), axis=1), + perspective_transform[1, 2]) + warped_col2 = np.add(np.sum(np.multiply( + warped_points, perspective_transform[2, :2]), axis=1), + perspective_transform[2, 2]) + warped_col0 = np.divide(warped_col0, warped_col2) + warped_col1 = np.divide(warped_col1, warped_col2) + warped_points = np.concatenate( + [warped_col0[:, None], warped_col1[:, None]], axis=1) + warped_points_float = warped_points.copy() + warped_points = warped_points.astype(int) + + # Fill the rectangles and get the segments + color = get_random_color(background_color) + # segments_debug = np.zeros([0, 4]) + for i in range(col): + # Fill the color + color = (color + 128 + random_state.randint(-30, 30)) % 256 + cv.fillConvexPoly(img, np.array([(warped_points[i, 0], + 
warped_points[i, 1]), + (warped_points[i+1, 0], + warped_points[i+1, 1]), + (warped_points[i+col+2, 0], + warped_points[i+col+2, 1]), + (warped_points[i+col+1, 0], + warped_points[i+col+1, 1])]), + color) + + segments = np.zeros([0, 4]) + row = 1 # in stripes case + # Iterate through rows + for row_idx in range(row + 1): + # Include all the combination of the junctions + # Iterate through all the combination of junction index in that row + multi_seg_lst = [np.array( + [warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1]])[None, ...] + for (id1, id2) in combinations(range( + row_idx * (col + 1), (row_idx + 1) * (col + 1), 1), 2)] + multi_seg = np.concatenate(multi_seg_lst, axis=0) + segments = np.concatenate((segments, multi_seg), axis=0) + + # Iterate through columns + for col_idx in range(col + 1): # for 5 columns, we will have 5 + 1 edges. + # Include all the combination of the junctions + # Iterate throuhg all the combination of junction index in that column + multi_seg_lst = [np.array( + [warped_points_float[id1, 0], + warped_points_float[id1, 1], + warped_points_float[id2, 0], + warped_points_float[id2, 1]])[None, ...] + for (id1, id2) in combinations(range( + col_idx, col_idx + (row * col) + 2, col + 1), 2)] + multi_seg = np.concatenate(multi_seg_lst, axis=0) + segments = np.concatenate((segments, multi_seg), axis=0) + + # Select and refine the segments + segments_new = np.zeros([0, 4]) + # Define image boundary polygon (in x y manner) + image_poly = shapely.geometry.Polygon( + [[0, 0], [img.shape[1]-1, 0], [img.shape[1]-1, img.shape[0]-1], + [0, img.shape[0]-1]]) + for idx in range(segments.shape[0]): + # Get the line segment + seg_raw = segments[idx, :] + seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]]) + + # The line segment is just inside the image. + if seg.intersection(image_poly) == seg: + segments_new = np.concatenate( + (segments_new, seg_raw[None, ...]), axis=0) + + # Intersect with the image. + elif seg.intersects(image_poly): + # Check intersection + try: + p = np.array( + seg.intersection(image_poly).coords).reshape([-1, 4]) + # If intersect at exact one point, just continue. 
+ except: + continue + segment = p + segments_new = np.concatenate((segments_new, segment), axis=0) + + else: + continue + + segments = (np.round(segments_new)).astype(np.int) + + # Only record the segments longer than min_label_len + points1 = segments[:, :2] + points2 = segments[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + label_segments = segments[seg_len >= min_label_len, :] + + # Get all junctions from label segments + junctions_all = np.concatenate( + (label_segments[:, :2], label_segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + # Get all unique junction points + else: + junc_points = np.unique(junctions_all, axis=0) + # Generate line map from points and segments + line_map = get_line_map(junc_points, label_segments) + + # Draw lines on the boundaries of the stripes at random + nb_rows = random_state.randint(2, 5) + nb_cols = random_state.randint(2, col + 2) + thickness = random_state.randint(min_dim * 0.01, min_dim * 0.011) + for _ in range(nb_rows): + row_idx = random_state.choice([0, col + 1]) + col_idx1 = random_state.randint(col + 1) + col_idx2 = random_state.randint(col + 1) + color = get_random_color(background_color) + cv.line(img, (warped_points[row_idx + col_idx1, 0], + warped_points[row_idx + col_idx1, 1]), + (warped_points[row_idx + col_idx2, 0], + warped_points[row_idx + col_idx2, 1]), + color, thickness) + + for _ in range(nb_cols): + col_idx = random_state.randint(col + 1) + color = get_random_color(background_color) + cv.line(img, (warped_points[col_idx, 0], + warped_points[col_idx, 1]), + (warped_points[col_idx + col + 1, 0], + warped_points[col_idx + col + 1, 1]), + color, thickness) + + # Keep only the points inside the image + # points = keep_points_inside(warped_points, img.shape[:2]) + return { + "points": junc_points, + "line_map": line_map + } + + +def draw_cube(img, min_size_ratio=0.2, min_label_len=64, + scale_interval=(0.4, 0.6), trans_interval=(0.5, 0.2)): + """ Draw a 2D projection of a cube and output the visible juntions. + Parameters: + min_size_ratio: min(img.shape) * min_size_ratio is the smallest + achievable cube side size + scale_interval: the scale is between scale_interval[0] and + scale_interval[0]+scale_interval[1] + trans_interval: the translation is between img.shape*trans_interval[0] + and img.shape*(trans_interval[0] + trans_interval[1]) + """ + # Generate a cube and apply to it an affine transformation + # The order matters! + # The indices of two adjacent vertices differ only of one bit (Gray code) + background_color = int(np.mean(img)) + min_dim = min(img.shape[:2]) + min_side = min_dim * min_size_ratio + lx = min_side + random_state.rand() * 2 * min_dim / 3 # dims of the cube + ly = min_side + random_state.rand() * 2 * min_dim / 3 + lz = min_side + random_state.rand() * 2 * min_dim / 3 + cube = np.array([[0, 0, 0], + [lx, 0, 0], + [0, ly, 0], + [lx, ly, 0], + [0, 0, lz], + [lx, 0, lz], + [0, ly, lz], + [lx, ly, lz]]) + rot_angles = random_state.rand(3) * 3 * math.pi / 10. + math.pi / 10. 
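+    # rot_angles lie in [pi/10, 2*pi/5); the three matrices below rotate the
+    # cube about the z, x and y axes before the random scaling and the
+    # in-plane translation are applied.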
+ rotation_1 = np.array([[math.cos(rot_angles[0]), + -math.sin(rot_angles[0]), 0], + [math.sin(rot_angles[0]), + math.cos(rot_angles[0]), 0], + [0, 0, 1]]) + rotation_2 = np.array([[1, 0, 0], + [0, math.cos(rot_angles[1]), + -math.sin(rot_angles[1])], + [0, math.sin(rot_angles[1]), + math.cos(rot_angles[1])]]) + rotation_3 = np.array([[math.cos(rot_angles[2]), 0, + -math.sin(rot_angles[2])], + [0, 1, 0], + [math.sin(rot_angles[2]), 0, + math.cos(rot_angles[2])]]) + scaling = np.array([[scale_interval[0] + + random_state.rand() * scale_interval[1], 0, 0], + [0, scale_interval[0] + + random_state.rand() * scale_interval[1], 0], + [0, 0, scale_interval[0] + + random_state.rand() * scale_interval[1]]]) + trans = np.array([img.shape[1] * trans_interval[0] + + random_state.randint(-img.shape[1] * trans_interval[1], + img.shape[1] * trans_interval[1]), + img.shape[0] * trans_interval[0] + + random_state.randint(-img.shape[0] * trans_interval[1], + img.shape[0] * trans_interval[1]), + 0]) + cube = trans + np.transpose( + np.dot(scaling, np.dot(rotation_1, + np.dot(rotation_2, np.dot(rotation_3, np.transpose(cube)))))) + + # The hidden corner is 0 by construction + # The front one is 7 + cube = cube[:, :2] # project on the plane z=0 + cube = cube.astype(int) + points = cube[1:, :] # get rid of the hidden corner + + # Get the three visible faces + faces = np.array([[7, 3, 1, 5], [7, 5, 4, 6], [7, 6, 2, 3]]) + + # Get all visible line segments + segments = np.zeros([0, 4]) + # Iterate through all the faces + for face_idx in range(faces.shape[0]): + face = faces[face_idx, :] + # Brute-forcely expand all the segments + segment = np.array( + [np.concatenate((cube[face[0]], cube[face[1]]), axis=0), + np.concatenate((cube[face[1]], cube[face[2]]), axis=0), + np.concatenate((cube[face[2]], cube[face[3]]), axis=0), + np.concatenate((cube[face[3]], cube[face[0]]), axis=0)]) + segments = np.concatenate((segments, segment), axis=0) + + # Select and refine the segments + segments_new = np.zeros([0, 4]) + # Define image boundary polygon (in x y manner) + image_poly = shapely.geometry.Polygon( + [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1], + [0, img.shape[0] - 1]]) + for idx in range(segments.shape[0]): + # Get the line segment + seg_raw = segments[idx, :] + seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]]) + + # The line segment is just inside the image. + if seg.intersection(image_poly) == seg: + segments_new = np.concatenate( + (segments_new, seg_raw[None, ...]), axis=0) + + # Intersect with the image. 
+ elif seg.intersects(image_poly): + try: + p = np.array( + seg.intersection(image_poly).coords).reshape([-1, 4]) + except: + continue + segment = p + segments_new = np.concatenate((segments_new, segment), axis=0) + + else: + continue + + segments = (np.round(segments_new)).astype(np.int) + + # Only record the segments longer than min_label_len + points1 = segments[:, :2] + points2 = segments[:, 2:] + seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1)) + label_segments = segments[seg_len >= min_label_len, :] + + # Get all junctions from label segments + junctions_all = np.concatenate( + (label_segments[:, :2], label_segments[:, 2:]), axis=0) + if junctions_all.shape[0] == 0: + junc_points = None + line_map = None + + # Get all unique junction points + else: + junc_points = np.unique(junctions_all, axis=0) + # Generate line map from points and segments + line_map = get_line_map(junc_points, label_segments) + + # Fill the faces and draw the contours + col_face = get_random_color(background_color) + for i in [0, 1, 2]: + cv.fillPoly(img, [cube[faces[i]].reshape((-1, 1, 2))], + col_face) + thickness = random_state.randint(min_dim * 0.003, min_dim * 0.015) + for i in [0, 1, 2]: + for j in [0, 1, 2, 3]: + col_edge = (col_face + 128 + + random_state.randint(-64, 64))\ + % 256 # color that constrats with the face color + cv.line(img, (cube[faces[i][j], 0], cube[faces[i][j], 1]), + (cube[faces[i][(j + 1) % 4], 0], + cube[faces[i][(j + 1) % 4], 1]), + col_edge, thickness) + + return { + "points": junc_points, + "line_map": line_map + } + + +def gaussian_noise(img): + """ Apply random noise to the image. """ + cv.randu(img, 0, 255) + return { + "points": None, + "line_map": None + } diff --git a/imcui/third_party/SOLD2/sold2/dataset/transforms/__init__.py b/imcui/third_party/SOLD2/sold2/dataset/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py b/imcui/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d9338abb169f7a86f3c6e702a031e1c0de86c339 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py @@ -0,0 +1,350 @@ +""" +This file implements the homographic transforms for data augmentation. +Code adapted from https://github.com/rpautrat/SuperPoint +""" +import numpy as np +from math import pi + +from ..synthetic_util import get_line_map, get_line_heatmap +import cv2 +import copy +import shapely.geometry + + +def sample_homography( + shape, perspective=True, scaling=True, rotation=True, + translation=True, n_scales=5, n_angles=25, scaling_amplitude=0.1, + perspective_amplitude_x=0.1, perspective_amplitude_y=0.1, + patch_ratio=0.5, max_angle=pi/2, allow_artifacts=False, + translation_overflow=0.): + """ + Computes the homography transformation between a random patch in the + original image and a warped projection with the same image size. + As in `tf.contrib.image.transform`, it maps the output point + (warped patch) to a transformed input point (original patch). + The original patch, initialized with a simple half-size centered crop, + is iteratively projected, scaled, rotated and translated. + + Arguments: + shape: A rank-2 `Tensor` specifying the height and width of the original image. + perspective: A boolean that enables the perspective and affine transformations. 
+ scaling: A boolean that enables the random scaling of the patch. + rotation: A boolean that enables the random rotation of the patch. + translation: A boolean that enables the random translation of the patch. + n_scales: The number of tentative scales that are sampled when scaling. + n_angles: The number of tentatives angles that are sampled when rotating. + scaling_amplitude: Controls the amount of scale. + perspective_amplitude_x: Controls the perspective effect in x direction. + perspective_amplitude_y: Controls the perspective effect in y direction. + patch_ratio: Controls the size of the patches used to create the homography. + max_angle: Maximum angle used in rotations. + allow_artifacts: A boolean that enables artifacts when applying the homography. + translation_overflow: Amount of border artifacts caused by translation. + + Returns: + homo_mat: A numpy array of shape `[1, 3, 3]` corresponding to the + homography transform. + selected_scale: The selected scaling factor. + """ + # Convert shape to ndarry + if not isinstance(shape, np.ndarray): + shape = np.array(shape) + + # Corners of the output image + pts1 = np.array([[0., 0.], [0., 1.], [1., 1.], [1., 0.]]) + # Corners of the input patch + margin = (1 - patch_ratio) / 2 + pts2 = margin + np.array([[0, 0], [0, patch_ratio], + [patch_ratio, patch_ratio], [patch_ratio, 0]]) + + # Random perspective and affine perturbations + if perspective: + if not allow_artifacts: + perspective_amplitude_x = min(perspective_amplitude_x, margin) + perspective_amplitude_y = min(perspective_amplitude_y, margin) + + # normal distribution with mean=0, std=perspective_amplitude_y/2 + perspective_displacement = np.random.normal( + 0., perspective_amplitude_y/2, [1]) + h_displacement_left = np.random.normal( + 0., perspective_amplitude_x/2, [1]) + h_displacement_right = np.random.normal( + 0., perspective_amplitude_x/2, [1]) + pts2 += np.stack([np.concatenate([h_displacement_left, + perspective_displacement], 0), + np.concatenate([h_displacement_left, + -perspective_displacement], 0), + np.concatenate([h_displacement_right, + perspective_displacement], 0), + np.concatenate([h_displacement_right, + -perspective_displacement], 0)]) + + # Random scaling: sample several scales, check collision with borders, + # randomly pick a valid one + if scaling: + scales = np.concatenate( + [[1.], np.random.normal(1, scaling_amplitude/2, [n_scales])], 0) + center = np.mean(pts2, axis=0, keepdims=True) + scaled = (pts2 - center)[None, ...] * scales[..., None, None] + center + # all scales are valid except scale=1 + if allow_artifacts: + valid = np.array(range(n_scales)) + # Chech the valid scale + else: + valid = np.where(np.all((scaled >= 0.) + & (scaled < 1.), (1, 2)))[0] + # No valid scale found => recursively call + if valid.shape[0] == 0: + return sample_homography( + shape, perspective, scaling, rotation, translation, + n_scales, n_angles, scaling_amplitude, + perspective_amplitude_x, perspective_amplitude_y, + patch_ratio, max_angle, allow_artifacts, translation_overflow) + + idx = valid[np.random.uniform(0., valid.shape[0], ()).astype(np.int32)] + pts2 = scaled[idx] + + # Additionally save and return the selected scale. 
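+        # Note that selected_scale is only assigned in this scaling branch,
+        # so the final return statement assumes scaling is enabled.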
+ selected_scale = scales[idx] + + # Random translation + if translation: + t_min, t_max = np.min(pts2, axis=0), np.min(1 - pts2, axis=0) + if allow_artifacts: + t_min += translation_overflow + t_max += translation_overflow + pts2 += (np.stack([np.random.uniform(-t_min[0], t_max[0], ()), + np.random.uniform(-t_min[1], + t_max[1], ())]))[None, ...] + + # Random rotation: sample several rotations, check collision with borders, + # randomly pick a valid one + if rotation: + angles = np.linspace(-max_angle, max_angle, n_angles) + # in case no rotation is valid + angles = np.concatenate([[0.], angles], axis=0) + center = np.mean(pts2, axis=0, keepdims=True) + rot_mat = np.reshape(np.stack( + [np.cos(angles), -np.sin(angles), + np.sin(angles), np.cos(angles)], axis=1), [-1, 2, 2]) + rotated = np.matmul( + np.tile((pts2 - center)[None, ...], [n_angles+1, 1, 1]), + rot_mat) + center + if allow_artifacts: + # All angles are valid, except angle=0 + valid = np.array(range(n_angles)) + else: + valid = np.where(np.all((rotated >= 0.) + & (rotated < 1.), axis=(1, 2)))[0] + + if valid.shape[0] == 0: + return sample_homography( + shape, perspective, scaling, rotation, translation, + n_scales, n_angles, scaling_amplitude, + perspective_amplitude_x, perspective_amplitude_y, + patch_ratio, max_angle, allow_artifacts, translation_overflow) + + idx = valid[np.random.uniform(0., valid.shape[0], + ()).astype(np.int32)] + pts2 = rotated[idx] + + # Rescale to actual size + shape = shape[::-1].astype(np.float32) # different convention [y, x] + pts1 *= shape[None, ...] + pts2 *= shape[None, ...] + + def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] + + def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]] + + a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) + for f in (ax, ay)], axis=0) + p_mat = np.transpose(np.stack([[pts2[i][j] for i in range(4) + for j in range(2)]], axis=0)) + homo_vec, _, _, _ = np.linalg.lstsq(a_mat, p_mat, rcond=None) + + # Compose the homography vector back to matrix + homo_mat = np.concatenate([ + homo_vec[0:3, 0][None, ...], homo_vec[3:6, 0][None, ...], + np.concatenate((homo_vec[6], homo_vec[7], [1]), + axis=0)[None, ...]], axis=0) + + return homo_mat, selected_scale + + +def convert_to_line_segments(junctions, line_map): + """ Convert junctions and line map to line segments. 
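+    Each connected pair (i, j) of the symmetric line_map yields one row
+    [h1, w1, h2, w2]; visited entries are zeroed out so that every segment
+    is emitted only once.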
""" + # Copy the line map + line_map_tmp = copy.copy(line_map) + + line_segments = np.zeros([0, 4]) + for idx in range(junctions.shape[0]): + # If no connectivity, just skip it + if line_map_tmp[idx, :].sum() == 0: + continue + # Record the line segment + else: + for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]: + p1 = junctions[idx, :] + p2 = junctions[idx2, :] + line_segments = np.concatenate( + (line_segments, + np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), + axis=0) + # Update line_map + line_map_tmp[idx, idx2] = 0 + line_map_tmp[idx2, idx] = 0 + + return line_segments + + +def compute_valid_mask(image_size, homography, + border_margin, valid_mask=None): + # Warp the mask + if valid_mask is None: + initial_mask = np.ones(image_size) + else: + initial_mask = valid_mask + mask = cv2.warpPerspective( + initial_mask, homography, (image_size[1], image_size[0]), + flags=cv2.INTER_NEAREST) + + # Optionally perform erosion + if border_margin > 0: + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (border_margin*2, )*2) + mask = cv2.erode(mask, kernel) + + # Perform dilation if border_margin is negative + if border_margin < 0: + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (abs(int(border_margin))*2, )*2) + mask = cv2.dilate(mask, kernel) + + return mask + + +def warp_line_segment(line_segments, homography, image_size): + """ Warp the line segments using a homography. """ + # Separate the line segements into 2N points to apply matrix operation + num_segments = line_segments.shape[0] + + junctions = np.concatenate( + (line_segments[:, :2], # The first junction of each segment. + line_segments[:, 2:]), # The second junction of each segment. + axis=0) + # Convert to homogeneous coordinates + # Flip the junctions before converting to homogeneous (xy format) + junctions = np.flip(junctions, axis=1) + junctions = np.concatenate((junctions, np.ones([2*num_segments, 1])), + axis=1) + warped_junctions = np.matmul(homography, junctions.T).T + + # Convert back to segments + warped_junctions = warped_junctions[:, :2] / warped_junctions[:, 2:] + # (Convert back to hw format) + warped_junctions = np.flip(warped_junctions, axis=1) + warped_segments = np.concatenate( + (warped_junctions[:num_segments, :], + warped_junctions[num_segments:, :]), + axis=1 + ) + + # Check the intersections with the boundary + warped_segments_new = np.zeros([0, 4]) + image_poly = shapely.geometry.Polygon( + [[0, 0], [image_size[1]-1, 0], [image_size[1]-1, image_size[0]-1], + [0, image_size[0]-1]]) + for idx in range(warped_segments.shape[0]): + # Get the line segment + seg_raw = warped_segments[idx, :] # in HW format. + # Convert to shapely line (flip to xy format) + seg = shapely.geometry.LineString([np.flip(seg_raw[:2]), + np.flip(seg_raw[2:])]) + + # The line segment is just inside the image. + if seg.intersection(image_poly) == seg: + warped_segments_new = np.concatenate((warped_segments_new, + seg_raw[None, ...]), axis=0) + + # Intersect with the image. + elif seg.intersects(image_poly): + # Check intersection + try: + p = np.array( + seg.intersection(image_poly).coords).reshape([-1, 4]) + # If intersect at exact one point, just continue. + except: + continue + segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:], + axis=0)])[None, ...] 
+ warped_segments_new = np.concatenate( + (warped_segments_new, segment), axis=0) + + else: + continue + + warped_segments = (np.round(warped_segments_new)).astype(np.int) + return warped_segments + + +class homography_transform(object): + """ # Homography transformations. """ + def __init__(self, image_size, homograpy_config, + border_margin=0, min_label_len=20): + self.homo_config = homograpy_config + self.image_size = image_size + self.target_size = (self.image_size[1], self.image_size[0]) + self.border_margin = border_margin + if (min_label_len < 1) and isinstance(min_label_len, float): + raise ValueError("[Error] min_label_len should be in pixels.") + self.min_label_len = min_label_len + + def __call__(self, input_image, junctions, line_map, + valid_mask=None, homo=None, scale=None): + # Sample one random homography or use the given one + if homo is None or scale is None: + homo, scale = sample_homography(self.image_size, + **self.homo_config) + + # Warp the image + warped_image = cv2.warpPerspective( + input_image, homo, self.target_size, flags=cv2.INTER_LINEAR) + + valid_mask = compute_valid_mask(self.image_size, homo, + self.border_margin, valid_mask) + + # Convert junctions and line_map back to line segments + line_segments = convert_to_line_segments(junctions, line_map) + + # Warp the segments and check the length. + # Adjust the min_label_length + warped_segments = warp_line_segment(line_segments, homo, + self.image_size) + + # Convert back to junctions and line_map + junctions_new = np.concatenate((warped_segments[:, :2], + warped_segments[:, 2:]), axis=0) + if junctions_new.shape[0] == 0: + junctions_new = np.zeros([0, 2]) + line_map = np.zeros([0, 0]) + warped_heatmap = np.zeros(self.image_size) + else: + junctions_new = np.unique(junctions_new, axis=0) + + # Generate line map from points and segments + line_map = get_line_map(junctions_new, + warped_segments).astype(np.int) + # Compute the heatmap + warped_heatmap = get_line_heatmap(np.flip(junctions_new, axis=1), + line_map, self.image_size) + + return { + "junctions": junctions_new, + "warped_image": warped_image, + "valid_mask": valid_mask, + "line_map": line_map, + "warped_heatmap": warped_heatmap, + "homo": homo, + "scale": scale + } diff --git a/imcui/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py b/imcui/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..8fa44bf0efa93a47e5f8012988058f1cbd49324f --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py @@ -0,0 +1,185 @@ +""" +Common photometric transforms for data augmentation. +""" +import numpy as np +from PIL import Image +from torchvision import transforms as transforms +import cv2 + + +# List all the available augmentations +available_augmentations = [ + 'additive_gaussian_noise', + 'additive_speckle_noise', + 'random_brightness', + 'random_contrast', + 'additive_shade', + 'motion_blur' +] + + +class additive_gaussian_noise(object): + """ Additive gaussian noise. """ + def __init__(self, stddev_range=None): + # If std is not given, use the default setting + if stddev_range is None: + self.stddev_range = [5, 95] + else: + self.stddev_range = stddev_range + + def __call__(self, input_image): + # Get the noise stddev + stddev = np.random.uniform(self.stddev_range[0], self.stddev_range[1]) + noise = np.random.normal(0., stddev, size=input_image.shape) + noisy_image = (input_image + noise).clip(0., 255.) 
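+        # The noisy image is float-valued in [0, 255]; later transforms cast
+        # it back to uint8 or normalize it to [0, 1] as needed.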
+ + return noisy_image + + +class additive_speckle_noise(object): + """ Additive speckle noise. """ + def __init__(self, prob_range=None): + # If prob range is not given, use the default setting + if prob_range is None: + self.prob_range = [0.0, 0.005] + else: + self.prob_range = prob_range + + def __call__(self, input_image): + # Sample + prob = np.random.uniform(self.prob_range[0], self.prob_range[1]) + sample = np.random.uniform(0., 1., size=input_image.shape) + + # Get the mask + mask0 = sample <= prob + mask1 = sample >= (1 - prob) + + # Mask the image (here we assume the image ranges from 0~255 + noisy = input_image.copy() + noisy[mask0] = 0. + noisy[mask1] = 255. + + return noisy + + +class random_brightness(object): + """ Brightness change. """ + def __init__(self, brightness=None): + # If the brightness is not given, use the default setting + if brightness is None: + self.brightness = 0.5 + else: + self.brightness = brightness + + # Initialize the transformer + self.transform = transforms.ColorJitter(brightness=self.brightness) + + def __call__(self, input_image): + # Convert to PIL image + if isinstance(input_image, np.ndarray): + input_image = Image.fromarray(input_image.astype(np.uint8)) + + return np.array(self.transform(input_image)) + + +class random_contrast(object): + """ Additive contrast. """ + def __init__(self, contrast=None): + # If the brightness is not given, use the default setting + if contrast is None: + self.contrast = 0.5 + else: + self.contrast = contrast + + # Initialize the transformer + self.transform = transforms.ColorJitter(contrast=self.contrast) + + def __call__(self, input_image): + # Convert to PIL image + if isinstance(input_image, np.ndarray): + input_image = Image.fromarray(input_image.astype(np.uint8)) + + return np.array(self.transform(input_image)) + + +class additive_shade(object): + """ Additive shade. """ + def __init__(self, nb_ellipses=20, transparency_range=None, + kernel_size_range=None): + self.nb_ellipses = nb_ellipses + if transparency_range is None: + self.transparency_range = [-0.5, 0.8] + else: + self.transparency_range = transparency_range + + if kernel_size_range is None: + self.kernel_size_range = [250, 350] + else: + self.kernel_size_range = kernel_size_range + + def __call__(self, input_image): + # ToDo: if we should convert to numpy array first. + min_dim = min(input_image.shape[:2]) / 4 + mask = np.zeros(input_image.shape[:2], np.uint8) + for i in range(self.nb_ellipses): + ax = int(max(np.random.rand() * min_dim, min_dim / 5)) + ay = int(max(np.random.rand() * min_dim, min_dim / 5)) + max_rad = max(ax, ay) + x = np.random.randint(max_rad, input_image.shape[1] - max_rad) + y = np.random.randint(max_rad, input_image.shape[0] - max_rad) + angle = np.random.rand() * 90 + cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1) + + transparency = np.random.uniform(*self.transparency_range) + kernel_size = np.random.randint(*self.kernel_size_range) + + # kernel_size has to be odd + if (kernel_size % 2) == 0: + kernel_size += 1 + mask = cv2.GaussianBlur(mask.astype(np.float32), + (kernel_size, kernel_size), 0) + shaded = (input_image[..., None] + * (1 - transparency * mask[..., np.newaxis]/255.)) + shaded = np.clip(shaded, 0, 255) + + return np.reshape(shaded, input_image.shape) + + +class motion_blur(object): + """ Motion blur. 
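+    Convolves the image with a Gaussian-weighted 1-D line kernel along a
+    random direction (horizontal, vertical or diagonal), with a random odd
+    kernel size derived from max_kernel_size.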
""" + def __init__(self, max_kernel_size=10): + self.max_kernel_size = max_kernel_size + + def __call__(self, input_image): + # Either vertical, horizontal or diagonal blur + mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up']) + ksize = np.random.randint( + 0, int(round((self.max_kernel_size + 1) / 2))) * 2 + 1 + center = int((ksize - 1) / 2) + kernel = np.zeros((ksize, ksize)) + if mode == 'h': + kernel[center, :] = 1. + elif mode == 'v': + kernel[:, center] = 1. + elif mode == 'diag_down': + kernel = np.eye(ksize) + elif mode == 'diag_up': + kernel = np.flip(np.eye(ksize), 0) + var = ksize * ksize / 16. + grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1) + gaussian = np.exp(-(np.square(grid - center) + + np.square(grid.T - center)) / (2. * var)) + kernel *= gaussian + kernel /= np.sum(kernel) + blurred = cv2.filter2D(input_image, -1, kernel) + + return np.reshape(blurred, input_image.shape) + + +class normalize_image(object): + """ Image normalization to the range [0, 1]. """ + def __init__(self): + self.normalize_value = 255 + + def __call__(self, input_image): + return (input_image / self.normalize_value).astype(np.float32) diff --git a/imcui/third_party/SOLD2/sold2/dataset/transforms/utils.py b/imcui/third_party/SOLD2/sold2/dataset/transforms/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1ed09e5b32e2ae2f3577e0e8e5491495e7b05b --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/dataset/transforms/utils.py @@ -0,0 +1,121 @@ +""" +Some useful functions for dataset pre-processing +""" +import cv2 +import numpy as np +import shapely.geometry as sg + +from ..synthetic_util import get_line_map +from . import homographic_transforms as homoaug + + +def random_scaling(image, junctions, line_map, scale=1., h_crop=0, w_crop=0): + H, W = image.shape[:2] + H_scale, W_scale = round(H * scale), round(W * scale) + + # Nothing to do if the scale is too close to 1 + if H_scale == H and W_scale == W: + return (image, junctions, line_map, np.ones([H, W], dtype=np.int)) + + # Zoom-in => resize and random crop + if scale >= 1.: + image_big = cv2.resize(image, (W_scale, H_scale), + interpolation=cv2.INTER_LINEAR) + # Crop the image + image = image_big[h_crop:h_crop+H, w_crop:w_crop+W, ...] + valid_mask = np.ones([H, W], dtype=np.int) + + # Process junctions + junctions, line_map = process_junctions_and_line_map( + h_crop, w_crop, H, W, H_scale, W_scale, + junctions, line_map, "zoom-in") + # Zoom-out => resize and pad + else: + image_shape_raw = image.shape + image_small = cv2.resize(image, (W_scale, H_scale), + interpolation=cv2.INTER_AREA) + # Decide the pasting location + h_start = round((H - H_scale) / 2) + w_start = round((W - W_scale) / 2) + # Paste the image to the middle + image = np.zeros(image_shape_raw, dtype=np.float) + image[h_start:h_start+H_scale, + w_start:w_start+W_scale, ...] 
= image_small + valid_mask = np.zeros([H, W], dtype=np.int) + valid_mask[h_start:h_start+H_scale, w_start:w_start+W_scale] = 1 + + # Process the junctions + junctions, line_map = process_junctions_and_line_map( + h_start, w_start, H, W, H_scale, W_scale, + junctions, line_map, "zoom-out") + + return image, junctions, line_map, valid_mask + + +def process_junctions_and_line_map(h_start, w_start, H, W, H_scale, W_scale, + junctions, line_map, mode="zoom-in"): + if mode == "zoom-in": + junctions[:, 0] = junctions[:, 0] * H_scale / H + junctions[:, 1] = junctions[:, 1] * W_scale / W + line_segments = homoaug.convert_to_line_segments(junctions, line_map) + # Crop segments to the new boundaries + line_segments_new = np.zeros([0, 4]) + image_poly = sg.Polygon( + [[w_start, h_start], + [w_start+W, h_start], + [w_start+W, h_start+H], + [w_start, h_start+H] + ]) + for idx in range(line_segments.shape[0]): + # Get the line segment + seg_raw = line_segments[idx, :] # in HW format. + # Convert to shapely line (flip to xy format) + seg = sg.LineString([np.flip(seg_raw[:2]), + np.flip(seg_raw[2:])]) + # The line segment is just inside the image. + if seg.intersection(image_poly) == seg: + line_segments_new = np.concatenate( + (line_segments_new, seg_raw[None, ...]), axis=0) + # Intersect with the image. + elif seg.intersects(image_poly): + # Check intersection + try: + p = np.array( + seg.intersection(image_poly).coords).reshape([-1, 4]) + # If intersect at exact one point, just continue. + except: + continue + segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:], + axis=0)])[None, ...] + line_segments_new = np.concatenate( + (line_segments_new, segment), axis=0) + else: + continue + line_segments_new = (np.round(line_segments_new)).astype(np.int) + # Filter segments with 0 length + segment_lens = np.linalg.norm( + line_segments_new[:, :2] - line_segments_new[:, 2:], axis=-1) + seg_mask = segment_lens != 0 + line_segments_new = line_segments_new[seg_mask, :] + # Convert back to junctions and line_map + junctions_new = np.concatenate( + (line_segments_new[:, :2], line_segments_new[:, 2:]), axis=0) + if junctions_new.shape[0] == 0: + junctions_new = np.zeros([0, 2]) + line_map = np.zeros([0, 0]) + else: + junctions_new = np.unique(junctions_new, axis=0) + # Generate line map from points and segments + line_map = get_line_map(junctions_new, + line_segments_new).astype(np.int) + junctions_new[:, 0] -= h_start + junctions_new[:, 1] -= w_start + junctions = junctions_new + elif mode == "zoom-out": + # Process the junctions + junctions[:, 0] = (junctions[:, 0] * H_scale / H) + h_start + junctions[:, 1] = (junctions[:, 1] * W_scale / W) + w_start + else: + raise ValueError("[Error] unknown mode...") + + return junctions, line_map diff --git a/imcui/third_party/SOLD2/sold2/dataset/wireframe_dataset.py b/imcui/third_party/SOLD2/sold2/dataset/wireframe_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ed5bb910bed1b89934ddaaec3bcddf111ea0faef --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/dataset/wireframe_dataset.py @@ -0,0 +1,1000 @@ +""" +This file implements the wireframe dataset object for pytorch. 
+Some parts of the code are adapted from https://github.com/zhou13/lcnn +""" +import os +import math +import copy +from skimage.io import imread +from skimage import color +import PIL +import numpy as np +import h5py +import cv2 +import pickle +import torch +import torch.utils.data.dataloader as torch_loader +from torch.utils.data import Dataset +from torchvision import transforms + +from ..config.project_config import Config as cfg +from .transforms import photometric_transforms as photoaug +from .transforms import homographic_transforms as homoaug +from .transforms.utils import random_scaling +from .synthetic_util import get_line_heatmap +from ..misc.train_utils import parse_h5_data +from ..misc.geometry_utils import warp_points, mask_points + + +def wireframe_collate_fn(batch): + """ Customized collate_fn for wireframe dataset. """ + batch_keys = ["image", "junction_map", "valid_mask", "heatmap", + "heatmap_pos", "heatmap_neg", "homography", + "line_points", "line_indices"] + list_keys = ["junctions", "line_map", "line_map_pos", + "line_map_neg", "file_key"] + + outputs = {} + for data_key in batch[0].keys(): + batch_match = sum([_ in data_key for _ in batch_keys]) + list_match = sum([_ in data_key for _ in list_keys]) + # print(batch_match, list_match) + if batch_match > 0 and list_match == 0: + outputs[data_key] = torch_loader.default_collate( + [b[data_key] for b in batch]) + elif batch_match == 0 and list_match > 0: + outputs[data_key] = [b[data_key] for b in batch] + elif batch_match == 0 and list_match == 0: + continue + else: + raise ValueError( + "[Error] A key matches batch keys and list keys simultaneously.") + + return outputs + + +class WireframeDataset(Dataset): + def __init__(self, mode="train", config=None): + super(WireframeDataset, self).__init__() + if not mode in ["train", "test"]: + raise ValueError( + "[Error] Unknown mode for Wireframe dataset. Only 'train' and 'test'.") + self.mode = mode + + if config is None: + self.config = self.get_default_config() + else: + self.config = config + # Also get the default config + self.default_config = self.get_default_config() + + # Get cache setting + self.dataset_name = self.get_dataset_name() + self.cache_name = self.get_cache_name() + self.cache_path = cfg.wireframe_cache_path + + # Get the ground truth source + self.gt_source = self.config.get("gt_source_%s"%(self.mode), + "official") + if not self.gt_source == "official": + # Convert gt_source to full path + self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source) + # Check the full path exists + if not os.path.exists(self.gt_source): + raise ValueError( + "[Error] The specified ground truth source does not exist.") + + + # Get the filename dataset + print("[Info] Initializing wireframe dataset...") + self.filename_dataset, self.datapoints = self.construct_dataset() + + # Get dataset length + self.dataset_length = len(self.datapoints) + + # Print some info + print("[Info] Successfully initialized dataset") + print("\t Name: wireframe") + print("\t Mode: %s" %(self.mode)) + print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode), + "official"))) + print("\t Counts: %d" %(self.dataset_length)) + print("----------------------------------------") + + ####################################### + ## Dataset construction related APIs ## + ####################################### + def construct_dataset(self): + """ Construct the dataset (from scratch or from cache). 
""" + # Check if the filename cache exists + # If cache exists, load from cache + if self._check_dataset_cache(): + print("\t Found filename cache %s at %s"%(self.cache_name, + self.cache_path)) + print("\t Load filename cache...") + filename_dataset, datapoints = self.get_filename_dataset_from_cache() + # If not, initialize dataset from scratch + else: + print("\t Can't find filename cache ...") + print("\t Create filename dataset from scratch...") + filename_dataset, datapoints = self.get_filename_dataset() + print("\t Create filename dataset cache...") + self.create_filename_dataset_cache(filename_dataset, datapoints) + + return filename_dataset, datapoints + + def create_filename_dataset_cache(self, filename_dataset, datapoints): + """ Create filename dataset cache for faster initialization. """ + # Check cache path exists + if not os.path.exists(self.cache_path): + os.makedirs(self.cache_path) + + cache_file_path = os.path.join(self.cache_path, self.cache_name) + data = { + "filename_dataset": filename_dataset, + "datapoints": datapoints + } + with open(cache_file_path, "wb") as f: + pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) + + def get_filename_dataset_from_cache(self): + """ Get filename dataset from cache. """ + # Load from pkl cache + cache_file_path = os.path.join(self.cache_path, self.cache_name) + with open(cache_file_path, "rb") as f: + data = pickle.load(f) + + return data["filename_dataset"], data["datapoints"] + + def get_filename_dataset(self): + # Get the path to the dataset + if self.mode == "train": + dataset_path = os.path.join(cfg.wireframe_dataroot, "train") + elif self.mode == "test": + dataset_path = os.path.join(cfg.wireframe_dataroot, "valid") + + # Get paths to all image files + image_paths = sorted([os.path.join(dataset_path, _) + for _ in os.listdir(dataset_path)\ + if os.path.splitext(_)[-1] == ".png"]) + # Get the shared prefix + prefix_paths = [_.split(".png")[0] for _ in image_paths] + + # Get the label paths (different procedure for different split) + if self.mode == "train": + label_paths = [_ + "_label.npz" for _ in prefix_paths] + else: + label_paths = [_ + "_label.npz" for _ in prefix_paths] + mat_paths = [p[:-2] + "_line.mat" for p in prefix_paths] + + # Verify all the images and labels exist + for idx in range(len(image_paths)): + image_path = image_paths[idx] + label_path = label_paths[idx] + if (not (os.path.exists(image_path) + and os.path.exists(label_path))): + raise ValueError( + "[Error] The image and label do not exist. %s"%(image_path)) + # Further verify mat paths for test split + if self.mode == "test": + mat_path = mat_paths[idx] + if not os.path.exists(mat_path): + raise ValueError( + "[Error] The mat file does not exist. %s"%(mat_path)) + + # Construct the filename dataset + num_pad = int(math.ceil(math.log10(len(image_paths))) + 1) + filename_dataset = {} + for idx in range(len(image_paths)): + # Get the file key + key = self.get_padded_filename(num_pad, idx) + + filename_dataset[key] = { + "image": image_paths[idx], + "label": label_paths[idx] + } + + # Get the datapoints + datapoints = list(sorted(filename_dataset.keys())) + + return filename_dataset, datapoints + + def get_dataset_name(self): + """ Get dataset name from dataset config / default config. 
""" + if self.config["dataset_name"] is None: + dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode + else: + dataset_name = self.config["dataset_name"] + "_%s" % self.mode + + return dataset_name + + def get_cache_name(self): + """ Get cache name from dataset config / default config. """ + if self.config["dataset_name"] is None: + dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode + else: + dataset_name = self.config["dataset_name"] + "_%s" % self.mode + # Compose cache name + cache_name = dataset_name + "_cache.pkl" + + return cache_name + + @staticmethod + def get_padded_filename(num_pad, idx): + """ Get the padded filename using adaptive padding. """ + file_len = len("%d" % (idx)) + filename = "0" * (num_pad - file_len) + "%d" % (idx) + + return filename + + def get_default_config(self): + """ Get the default configuration. """ + return { + "dataset_name": "wireframe", + "add_augmentation_to_all_splits": False, + "preprocessing": { + "resize": [240, 320], + "blur_size": 11 + }, + "augmentation":{ + "photometric":{ + "enable": False + }, + "homographic":{ + "enable": False + }, + }, + } + + + ############################################ + ## Pytorch and preprocessing related APIs ## + ############################################ + # Get data from the information from filename dataset + @staticmethod + def get_data_from_path(data_path): + output = {} + + # Get image data + image_path = data_path["image"] + image = imread(image_path) + output["image"] = image + + # Get the npz label + """ Data entries in the npz file + jmap: [J, H, W] Junction heat map (H and W are 4x smaller) + joff: [J, 2, H, W] Junction offset within each pixel (Not sure about offsets) + lmap: [H, W] Line heat map with anti-aliasing (H and W are 4x smaller) + junc: [Na, 3] Junction coordinates (coordinates from 0~128 => 4x smaller.) + Lpos: [M, 2] Positive lines represented with junction indices + Lneg: [M, 2] Negative lines represented with junction indices + lpos: [Np, 2, 3] Positive lines represented with junction coordinates + lneg: [Nn, 2, 3] Negative lines represented with junction coordinates + """ + label_path = data_path["label"] + label = np.load(label_path) + for key in list(label.keys()): + output[key] = label[key] + + # If there's "line_mat" entry. + # TODO: How to process mat data + if data_path.get("line_mat") is not None: + raise NotImplementedError + + return output + + @staticmethod + def convert_line_map(lcnn_line_map, num_junctions): + """ Convert the line_pos or line_neg + (represented by two junction indexes) to our line map. """ + # Initialize empty line map + line_map = np.zeros([num_junctions, num_junctions]) + + # Iterate through all the lines + for idx in range(lcnn_line_map.shape[0]): + index1 = lcnn_line_map[idx, 0] + index2 = lcnn_line_map[idx, 1] + + line_map[index1, index2] = 1 + line_map[index2, index1] = 1 + + return line_map + + @staticmethod + def junc_to_junc_map(junctions, image_size): + """ Convert junction points to junction maps. """ + junctions = np.round(junctions).astype(np.int) + # Clip the boundary by image size + junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1) + junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1) + + # Create junction map + junc_map = np.zeros([image_size[0], image_size[1]]) + junc_map[junctions[:, 0], junctions[:, 1]] = 1 + + return junc_map[..., None].astype(np.int) + + def parse_transforms(self, names, all_transforms): + """ Parse the transform. 
""" + trans = all_transforms if (names == 'all') \ + else (names if isinstance(names, list) else [names]) + assert set(trans) <= set(all_transforms) + return trans + + def get_photo_transform(self): + """ Get list of photometric transforms (according to the config). """ + # Get the photometric transform config + photo_config = self.config["augmentation"]["photometric"] + if not photo_config["enable"]: + raise ValueError( + "[Error] Photometric augmentation is not enabled.") + + # Parse photometric transforms + trans_lst = self.parse_transforms(photo_config["primitives"], + photoaug.available_augmentations) + trans_config_lst = [photo_config["params"].get(p, {}) + for p in trans_lst] + + # List of photometric augmentation + photometric_trans_lst = [ + getattr(photoaug, trans)(**conf) \ + for (trans, conf) in zip(trans_lst, trans_config_lst) + ] + + return photometric_trans_lst + + def get_homo_transform(self): + """ Get homographic transforms (according to the config). """ + # Get homographic transforms for image + homo_config = self.config["augmentation"]["homographic"]["params"] + if not self.config["augmentation"]["homographic"]["enable"]: + raise ValueError( + "[Error] Homographic augmentation is not enabled.") + + # Parse the homographic transforms + image_shape = self.config["preprocessing"]["resize"] + + # Compute the min_label_len from config + try: + min_label_tmp = self.config["generation"]["min_label_len"] + except: + min_label_tmp = None + + # float label len => fraction + if isinstance(min_label_tmp, float): # Skip if not provided + min_label_len = min_label_tmp * min(image_shape) + # int label len => length in pixel + elif isinstance(min_label_tmp, int): + scale_ratio = (self.config["preprocessing"]["resize"] + / self.config["generation"]["image_size"][0]) + min_label_len = (self.config["generation"]["min_label_len"] + * scale_ratio) + # if none => no restriction + else: + min_label_len = 0 + + # Initialize the transform + homographic_trans = homoaug.homography_transform( + image_shape, homo_config, 0, min_label_len) + + return homographic_trans + + def get_line_points(self, junctions, line_map, H1=None, H2=None, + img_size=None, warp=False): + """ Sample evenly points along each line segments + and keep track of line idx. 
""" + if np.sum(line_map) == 0: + # No segment detected in the image + line_indices = np.zeros(self.config["max_pts"], dtype=int) + line_points = np.zeros((self.config["max_pts"], 2), dtype=float) + return line_points, line_indices + + # Extract all pairs of connected junctions + junc_indices = np.array( + [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i]) + line_segments = np.stack([junctions[junc_indices[:, 0]], + junctions[junc_indices[:, 1]]], axis=1) + # line_segments is (num_lines, 2, 2) + line_lengths = np.linalg.norm( + line_segments[:, 0] - line_segments[:, 1], axis=1) + + # Sample the points separated by at least min_dist_pts along each line + # The number of samples depends on the length of the line + num_samples = np.minimum(line_lengths // self.config["min_dist_pts"], + self.config["max_num_samples"]) + line_points = [] + line_indices = [] + cur_line_idx = 1 + for n in np.arange(2, self.config["max_num_samples"] + 1): + # Consider all lines where we can fit up to n points + cur_line_seg = line_segments[num_samples == n] + line_points_x = np.linspace(cur_line_seg[:, 0, 0], + cur_line_seg[:, 1, 0], + n, axis=-1).flatten() + line_points_y = np.linspace(cur_line_seg[:, 0, 1], + cur_line_seg[:, 1, 1], + n, axis=-1).flatten() + jitter = self.config.get("jittering", 0) + if jitter: + # Add a small random jittering of all points along the line + angles = np.arctan2( + cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0], + cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n) + jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter + line_points_x += jitter_hyp * np.sin(angles) + line_points_y += jitter_hyp * np.cos(angles) + line_points.append(np.stack([line_points_x, line_points_y], axis=-1)) + # Keep track of the line indices for each sampled point + num_cur_lines = len(cur_line_seg) + line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines) + line_indices.append(line_idx.repeat(n)) + cur_line_idx += num_cur_lines + line_points = np.concatenate(line_points, + axis=0)[:self.config["max_pts"]] + line_indices = np.concatenate(line_indices, + axis=0)[:self.config["max_pts"]] + + # Warp the points if need be, and filter unvalid ones + # If the other view is also warped + if warp and H2 is not None: + warp_points2 = warp_points(line_points, H2) + line_points = warp_points(line_points, H1) + mask = mask_points(line_points, img_size) + mask2 = mask_points(warp_points2, img_size) + mask = mask * mask2 + # If the other view is not warped + elif warp and H2 is None: + line_points = warp_points(line_points, H1) + mask = mask_points(line_points, img_size) + else: + if H1 is not None: + raise ValueError("[Error] Wrong combination of homographies.") + # Remove points that would be outside of img_size if warped by H + warped_points = warp_points(line_points, H1) + mask = mask_points(warped_points, img_size) + line_points = line_points[mask] + line_indices = line_indices[mask] + + # Pad the line points to a fixed length + # Index of 0 means padded line + line_indices = np.concatenate([line_indices, np.zeros( + self.config["max_pts"] - len(line_indices))], axis=0) + line_points = np.concatenate( + [line_points, + np.zeros((self.config["max_pts"] - len(line_points), 2), + dtype=float)], axis=0) + + return line_points, line_indices + + def train_preprocessing(self, data, numpy=False): + """ Train preprocessing for GT data. 
""" + # Fetch the corresponding entries + image = data["image"] + junctions = data["junc"][:, :2] + line_pos = data["Lpos"] + line_neg = data["Lneg"] + image_size = image.shape[:2] + # Convert junctions to pixel coordinates (from 128x128) + junctions[:, 0] *= image_size[0] / 128 + junctions[:, 1] *= image_size[1] / 128 + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # In HW format + junctions = (junctions * np.array( + self.config['preprocessing']['resize'], np.float) + / np.array(size_old, np.float)) + + # Convert to positive line map and negative line map (our format) + num_junctions = junctions.shape[0] + line_map_pos = self.convert_line_map(line_pos, num_junctions) + line_map_neg = self.convert_line_map(line_neg, num_junctions) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + # Update image size + image_size = image.shape[:2] + heatmap_pos = get_line_heatmap(junctions_xy, line_map_pos, image_size) + heatmap_neg = get_line_heatmap(junctions_xy, line_map_neg, image_size) + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + # Check if we need to apply augmentations + # In training mode => yes. + # In homography adaptation mode (export mode) => No + if self.config["augmentation"]["photometric"]["enable"]: + photo_trans_lst = self.get_photo_transform() + ### Image transform ### + np.random.shuffle(photo_trans_lst) + image_transform = transforms.Compose( + photo_trans_lst + [photoaug.normalize_image()]) + else: + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Check homographic augmentation + if self.config["augmentation"]["homographic"]["enable"]: + homo_trans = self.get_homo_transform() + # Perform homographic transform + outputs_pos = homo_trans(image, junctions, line_map_pos) + outputs_neg = homo_trans(image, junctions, line_map_neg) + + # record the warped results + junctions = outputs_pos["junctions"] # Should be HW format + image = outputs_pos["warped_image"] + line_map_pos = outputs_pos["line_map"] + line_map_neg = outputs_neg["line_map"] + heatmap_pos = outputs_pos["warped_heatmap"] + heatmap_neg = outputs_neg["warped_heatmap"] + valid_mask = outputs_pos["valid_mask"] # Same for pos and neg + + junction_map = self.junc_to_junc_map(junctions, image_size) + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + if not numpy: + return { + "image": to_tensor(image), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map_pos": to_tensor( + line_map_pos).to(torch.int32)[0, ...], + "line_map_neg": to_tensor( + line_map_neg).to(torch.int32)[0, ...], + "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32), + "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + } + else: + return { + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + 
"line_map_pos": line_map_pos.astype(np.int32), + "line_map_neg": line_map_neg.astype(np.int32), + "heatmap_pos": heatmap_pos.astype(np.int32), + "heatmap_neg": heatmap_neg.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32) + } + + def train_preprocessing_exported( + self, data, numpy=False, disable_homoaug=False, + desc_training=False, H1=None, H1_scale=None, H2=None, scale=1., + h_crop=None, w_crop=None): + """ Train preprocessing for the exported labels. """ + data = copy.deepcopy(data) + # Fetch the corresponding entries + image = data["image"] + junctions = data["junctions"] + line_map = data["line_map"] + image_size = image.shape[:2] + + # Define the random crop for scaling if necessary + if h_crop is None or w_crop is None: + h_crop, w_crop = 0, 0 + if scale > 1: + H, W = self.config["preprocessing"]["resize"] + H_scale, W_scale = round(H * scale), round(W * scale) + if H_scale > H: + h_crop = np.random.randint(H_scale - H) + if W_scale > W: + w_crop = np.random.randint(W_scale - W) + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. + size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # # In HW format + # junctions = (junctions * np.array( + # self.config['preprocessing']['resize'], np.float) + # / np.array(size_old, np.float)) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + image_size = image.shape[:2] + heatmap = get_line_heatmap(junctions_xy, line_map, image_size) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + # Check if we need to apply augmentations + # In training mode => yes. 
+ # In homography adaptation mode (export mode) => No + if self.config["augmentation"]["photometric"]["enable"]: + photo_trans_lst = self.get_photo_transform() + ### Image transform ### + np.random.shuffle(photo_trans_lst) + image_transform = transforms.Compose( + photo_trans_lst + [photoaug.normalize_image()]) + else: + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Perform the random scaling + if scale != 1.: + image, junctions, line_map, valid_mask = random_scaling( + image, junctions, line_map, scale, + h_crop=h_crop, w_crop=w_crop) + else: + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + # Initialize the empty output dict + outputs = {} + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + + # Check homographic augmentation + warp = (self.config["augmentation"]["homographic"]["enable"] + and disable_homoaug == False) + if warp: + homo_trans = self.get_homo_transform() + # Perform homographic transform + if H1 is None: + homo_outputs = homo_trans( + image, junctions, line_map, valid_mask=valid_mask) + else: + homo_outputs = homo_trans( + image, junctions, line_map, homo=H1, scale=H1_scale, + valid_mask=valid_mask) + homography_mat = homo_outputs["homo"] + + # Give the warp of the other view + if H1 is None: + H1 = homo_outputs["homo"] + + # Sample points along each line segments for the descriptor + if desc_training: + line_points, line_indices = self.get_line_points( + junctions, line_map, H1=H1, H2=H2, + img_size=image_size, warp=warp) + + # Record the warped results + if warp: + junctions = homo_outputs["junctions"] # Should be HW format + image = homo_outputs["warped_image"] + line_map = homo_outputs["line_map"] + valid_mask = homo_outputs["valid_mask"] # Same for pos and neg + heatmap = homo_outputs["warped_heatmap"] + + # Optionally put warping information first. + if not numpy: + outputs["homography_mat"] = to_tensor( + homography_mat).to(torch.float32)[0, ...] + else: + outputs["homography_mat"] = homography_mat.astype(np.float32) + + junction_map = self.junc_to_junc_map(junctions, image_size) + + if not numpy: + outputs.update({ + "image": to_tensor(image).to(torch.float32), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map": to_tensor(line_map).to(torch.int32)[0, ...], + "heatmap": to_tensor(heatmap).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + }) + if desc_training: + outputs.update({ + "line_points": to_tensor( + line_points).to(torch.float32)[0], + "line_indices": torch.tensor(line_indices, + dtype=torch.int) + }) + else: + outputs.update({ + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + "line_map": line_map.astype(np.int32), + "heatmap": heatmap.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32) + }) + if desc_training: + outputs.update({ + "line_points": line_points.astype(np.float32), + "line_indices": line_indices.astype(int) + }) + + return outputs + + def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.): + """ Train preprocessing for paired data for the exported labels + for descriptor training. 
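+ Returns a single dict where the reference view entries are prefixed with "ref_" and the target view entries with "target_"; both views share the same random scale and crop and differ by a sampled homography.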
""" + outputs = {} + + # Define the random crop for scaling if necessary + h_crop, w_crop = 0, 0 + if scale > 1: + H, W = self.config["preprocessing"]["resize"] + H_scale, W_scale = round(H * scale), round(W * scale) + if H_scale > H: + h_crop = np.random.randint(H_scale - H) + if W_scale > W: + w_crop = np.random.randint(W_scale - W) + + # Sample ref homography first + homo_config = self.config["augmentation"]["homographic"]["params"] + image_shape = self.config["preprocessing"]["resize"] + ref_H, ref_scale = homoaug.sample_homography(image_shape, + **homo_config) + + # Data for target view (All augmentation) + target_data = self.train_preprocessing_exported( + data, numpy=numpy, desc_training=True, H1=None, H2=ref_H, + scale=scale, h_crop=h_crop, w_crop=w_crop) + + # Data for reference view (No homographical augmentation) + ref_data = self.train_preprocessing_exported( + data, numpy=numpy, desc_training=True, H1=ref_H, + H1_scale=ref_scale, H2=target_data["homography_mat"].numpy(), + scale=scale, h_crop=h_crop, w_crop=w_crop) + + # Spread ref data + for key, val in ref_data.items(): + outputs["ref_" + key] = val + + # Spread target data + for key, val in target_data.items(): + outputs["target_" + key] = val + + return outputs + + def test_preprocessing(self, data, numpy=False): + """ Test preprocessing for GT data. """ + data = copy.deepcopy(data) + # Fetch the corresponding entries + image = data["image"] + junctions = data["junc"][:, :2] + line_pos = data["Lpos"] + line_neg = data["Lneg"] + image_size = image.shape[:2] + # Convert junctions to pixel coordinates (from 128x128) + junctions[:, 0] *= image_size[0] / 128 + junctions[:, 1] *= image_size[1] / 128 + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. 
+ size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # In HW format + junctions = (junctions * np.array( + self.config['preprocessing']['resize'], np.float) + / np.array(size_old, np.float)) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + # Still need to normalize image + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Convert to positive line map and negative line map (our format) + num_junctions = junctions.shape[0] + line_map_pos = self.convert_line_map(line_pos, num_junctions) + line_map_neg = self.convert_line_map(line_neg, num_junctions) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + # Update image size + image_size = image.shape[:2] + heatmap_pos = get_line_heatmap(junctions_xy, line_map_pos, image_size) + heatmap_neg = get_line_heatmap(junctions_xy, line_map_neg, image_size) + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + junction_map = self.junc_to_junc_map(junctions, image_size) + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + if not numpy: + return { + "image": to_tensor(image), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map_pos": to_tensor( + line_map_pos).to(torch.int32)[0, ...], + "line_map_neg": to_tensor( + line_map_neg).to(torch.int32)[0, ...], + "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32), + "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + } + else: + return { + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + "line_map_pos": line_map_pos.astype(np.int32), + "line_map_neg": line_map_neg.astype(np.int32), + "heatmap_pos": heatmap_pos.astype(np.int32), + "heatmap_neg": heatmap_neg.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32) + } + + def test_preprocessing_exported(self, data, numpy=False, scale=1.): + """ Test preprocessing for the exported labels. """ + data = copy.deepcopy(data) + # Fetch the corresponding entries + image = data["image"] + junctions = data["junctions"] + line_map = data["line_map"] + image_size = image.shape[:2] + + # Resize the image before photometric and homographical augmentations + if not(list(image_size) == self.config["preprocessing"]["resize"]): + # Resize the image and the point location. 
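+ # The exported junctions are left untouched here (see the commented-out rescaling below), presumably because the exported labels already live at the target resolution.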
+ size_old = list(image.shape)[:2] # Only H and W dimensions + + image = cv2.resize( + image, tuple(self.config['preprocessing']['resize'][::-1]), + interpolation=cv2.INTER_LINEAR) + image = np.array(image, dtype=np.uint8) + + # # In HW format + # junctions = (junctions * np.array( + # self.config['preprocessing']['resize'], np.float) + # / np.array(size_old, np.float)) + + # Optionally convert the image to grayscale + if self.config["gray_scale"]: + image = (color.rgb2gray(image) * 255.).astype(np.uint8) + + # Still need to normalize image + image_transform = photoaug.normalize_image() + image = image_transform(image) + + # Generate the line heatmap after post-processing + junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1) + image_size = image.shape[:2] + heatmap = get_line_heatmap(junctions_xy, line_map, image_size) + + # Declare default valid mask (all ones) + valid_mask = np.ones(image_size) + + junction_map = self.junc_to_junc_map(junctions, image_size) + + # Convert to tensor and return the results + to_tensor = transforms.ToTensor() + if not numpy: + outputs = { + "image": to_tensor(image), + "junctions": to_tensor(junctions).to(torch.float32)[0, ...], + "junction_map": to_tensor(junction_map).to(torch.int), + "line_map": to_tensor(line_map).to(torch.int32)[0, ...], + "heatmap": to_tensor(heatmap).to(torch.int32), + "valid_mask": to_tensor(valid_mask).to(torch.int32) + } + else: + outputs = { + "image": image, + "junctions": junctions.astype(np.float32), + "junction_map": junction_map.astype(np.int32), + "line_map": line_map.astype(np.int32), + "heatmap": heatmap.astype(np.int32), + "valid_mask": valid_mask.astype(np.int32) + } + + return outputs + + def __len__(self): + return self.dataset_length + + def get_data_from_key(self, file_key): + """ Get data from file_key. """ + # Check key exists + if not file_key in self.filename_dataset.keys(): + raise ValueError("[Error] the specified key is not in the dataset.") + + # Get the data paths + data_path = self.filename_dataset[file_key] + # Read in the image and npz labels (but haven't applied any transform) + data = self.get_data_from_path(data_path) + + # Perform transform and augmentation + if self.mode == "train" or self.config["add_augmentation_to_all_splits"]: + data = self.train_preprocessing(data, numpy=True) + else: + data = self.test_preprocessing(data, numpy=True) + + # Add file key to the output + data["file_key"] = file_key + + return data + + def __getitem__(self, idx): + """Return data + file_key: str, keys used to retrieve data from the filename dataset. 
+ image: torch.float, C*H*W range 0~1, + junctions: torch.float, N*2, + junction_map: torch.int32, 1*H*W range 0 or 1, + line_map_pos: torch.int32, N*N range 0 or 1, + line_map_neg: torch.int32, N*N range 0 or 1, + heatmap_pos: torch.int32, 1*H*W range 0 or 1, + heatmap_neg: torch.int32, 1*H*W range 0 or 1, + valid_mask: torch.int32, 1*H*W range 0 or 1 + """ + # Get the corresponding datapoint and contents from filename dataset + file_key = self.datapoints[idx] + data_path = self.filename_dataset[file_key] + # Read in the image and npz labels (but haven't applied any transform) + data = self.get_data_from_path(data_path) + + # Also load the exported labels if not using the official ground truth + if not self.gt_source == "official": + with h5py.File(self.gt_source, "r") as f: + exported_label = parse_h5_data(f[file_key]) + + data["junctions"] = exported_label["junctions"] + data["line_map"] = exported_label["line_map"] + + # Perform transform and augmentation + return_type = self.config.get("return_type", "single") + if (self.mode == "train" + or self.config["add_augmentation_to_all_splits"]): + # Perform random scaling first + if self.config["augmentation"]["random_scaling"]["enable"]: + scale_range = self.config["augmentation"]["random_scaling"]["range"] + # Decide the scaling + scale = np.random.uniform(min(scale_range), max(scale_range)) + else: + scale = 1. + if self.gt_source == "official": + data = self.train_preprocessing(data) + else: + if return_type == "paired_desc": + data = self.preprocessing_exported_paired_desc( + data, scale=scale) + else: + data = self.train_preprocessing_exported(data, + scale=scale) + else: + if self.gt_source == "official": + data = self.test_preprocessing(data) + elif return_type == "paired_desc": + data = self.preprocessing_exported_paired_desc(data) + else: + data = self.test_preprocessing_exported(data) + + # Add file key to the output + data["file_key"] = file_key + + return data + + ######################## + ## Some other methods ## + ######################## + def _check_dataset_cache(self): + """ Check if dataset cache exists. """ + cache_file_path = os.path.join(self.cache_path, self.cache_name) + if os.path.exists(cache_file_path): + return True + else: + return False diff --git a/imcui/third_party/SOLD2/sold2/experiment.py b/imcui/third_party/SOLD2/sold2/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf4db1c9f148b9e33c6d7d0ba973375cd770a14 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/experiment.py @@ -0,0 +1,227 @@ +""" +Main file to launch training and testing experiments. +""" + +import yaml +import os +import argparse +import numpy as np +import torch + +from .config.project_config import Config as cfg +from .train import train_net +from .export import export_predictions, export_homograpy_adaptation + + +# Pytorch configurations +torch.cuda.empty_cache() +torch.backends.cudnn.benchmark = True + + +def load_config(config_path): + """ Load configurations from a given yaml file. """ + # Check file exists + if not os.path.exists(config_path): + raise ValueError("[Error] The provided config path is not valid.") + + # Load the configuration + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + return config + + +def update_config(path, model_cfg=None, dataset_cfg=None): + """ Update configuration file from the resume path. """ + # Check we need to update or completely override. 
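+ # The configs saved next to the checkpoint take precedence: update() lets the
+ # saved values override the ones passed in, and the merged result is written
+ # back to disk only if it differs from what was saved.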
+ model_cfg = {} if model_cfg is None else model_cfg + dataset_cfg = {} if dataset_cfg is None else dataset_cfg + + # Load saved configs + with open(os.path.join(path, "model_cfg.yaml"), "r") as f: + model_cfg_saved = yaml.safe_load(f) + model_cfg.update(model_cfg_saved) + with open(os.path.join(path, "dataset_cfg.yaml"), "r") as f: + dataset_cfg_saved = yaml.safe_load(f) + dataset_cfg.update(dataset_cfg_saved) + + # Update the saved yaml file + if not model_cfg == model_cfg_saved: + with open(os.path.join(path, "model_cfg.yaml"), "w") as f: + yaml.dump(model_cfg, f) + if not dataset_cfg == dataset_cfg_saved: + with open(os.path.join(path, "dataset_cfg.yaml"), "w") as f: + yaml.dump(dataset_cfg, f) + + return model_cfg, dataset_cfg + + +def record_config(model_cfg, dataset_cfg, output_path): + """ Record dataset config to the log path. """ + # Record model config + with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f: + yaml.safe_dump(model_cfg, f) + + # Record dataset config + with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f: + yaml.safe_dump(dataset_cfg, f) + + +def train(args, dataset_cfg, model_cfg, output_path): + """ Training function. """ + # Update model config from the resume path (only in resume mode) + if args.resume: + if os.path.realpath(output_path) != os.path.realpath(args.resume_path): + record_config(model_cfg, dataset_cfg, output_path) + + # First time, then write the config file to the output path + else: + record_config(model_cfg, dataset_cfg, output_path) + + # Launch the training + train_net(args, dataset_cfg, model_cfg, output_path) + + +def export(args, dataset_cfg, model_cfg, output_path, + export_dataset_mode=None, device=torch.device("cuda")): + """ Export function. """ + # Choose between normal predictions export or homography adaptation + if dataset_cfg.get("homography_adaptation") is not None: + print("[Info] Export predictions with homography adaptation.") + export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path, + export_dataset_mode, device) + else: + print("[Info] Export predictions normally.") + export_predictions(args, dataset_cfg, model_cfg, output_path, + export_dataset_mode) + + +def main(args, dataset_cfg, model_cfg, export_dataset_mode=None, + device=torch.device("cuda")): + """ Main function. 
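+ Dispatches to train() or export() according to args.mode; results go to cfg.EXP_PATH/<exp_name> for training and cfg.export_dataroot/<exp_name> for export.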
""" + # Make the output path + output_path = os.path.join(cfg.EXP_PATH, args.exp_name) + + if args.mode == "train": + if not os.path.exists(output_path): + os.makedirs(output_path) + print("[Info] Training mode") + print("\t Output path: %s" % output_path) + train(args, dataset_cfg, model_cfg, output_path) + elif args.mode == "export": + # Different output_path in export mode + output_path = os.path.join(cfg.export_dataroot, args.exp_name) + print("[Info] Export mode") + print("\t Output path: %s" % output_path) + export(args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device=device) + else: + raise ValueError("[Error]: Unknown mode: " + args.mode) + + +def set_random_seed(seed): + np.random.seed(seed) + torch.manual_seed(seed) + + +if __name__ == "__main__": + # Parse input arguments + parser = argparse.ArgumentParser() + parser.add_argument("--mode", type=str, default="train", + help="'train' or 'export'.") + parser.add_argument("--dataset_config", type=str, default=None, + help="Path to the dataset config.") + parser.add_argument("--model_config", type=str, default=None, + help="Path to the model config.") + parser.add_argument("--exp_name", type=str, default="exp", + help="Experiment name.") + parser.add_argument("--resume", action="store_true", default=False, + help="Load a previously trained model.") + parser.add_argument("--pretrained", action="store_true", default=False, + help="Start training from a pre-trained model.") + parser.add_argument("--resume_path", default=None, + help="Path from which to resume training.") + parser.add_argument("--pretrained_path", default=None, + help="Path to the pre-trained model.") + parser.add_argument("--checkpoint_name", default=None, + help="Name of the checkpoint to use.") + parser.add_argument("--export_dataset_mode", default=None, + help="'train' or 'test'.") + parser.add_argument("--export_batch_size", default=4, type=int, + help="Export batch size.") + + args = parser.parse_args() + + # Check if GPU is available + # Get the model + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + # Check if dataset config and model config is given. + if (((args.dataset_config is None) or (args.model_config is None)) + and (not args.resume) and (args.mode == "train")): + raise ValueError( + "[Error] The dataset config and model config should be given in non-resume mode") + + # If resume, check if the resume path has been given + if args.resume and (args.resume_path is None): + raise ValueError( + "[Error] Missing resume path.") + + # [Training] Load the config file. + if args.mode == "train" and (not args.resume): + # Check the pretrained checkpoint_path exists + if args.pretrained: + checkpoint_folder = args.resume_path + checkpoint_path = os.path.join(args.pretrained_path, + args.checkpoint_name) + if not os.path.exists(checkpoint_path): + raise ValueError("[Error] Missing checkpoint: " + + checkpoint_path) + dataset_cfg = load_config(args.dataset_config) + model_cfg = load_config(args.model_config) + + # [resume Training, Test, Export] Load the config file. 
+ elif (args.mode == "train" and args.resume) or (args.mode == "export"): + # Check checkpoint path exists + checkpoint_folder = args.resume_path + checkpoint_path = os.path.join(args.resume_path, args.checkpoint_name) + if not os.path.exists(checkpoint_path): + raise ValueError("[Error] Missing checkpoint: " + checkpoint_path) + + # Load model_cfg from checkpoint folder if not provided + if args.model_config is None: + print("[Info] No model config provided. Loading from checkpoint folder.") + model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml") + if not os.path.exists(model_cfg_path): + raise ValueError( + "[Error] Missing model config in checkpoint path.") + model_cfg = load_config(model_cfg_path) + else: + model_cfg = load_config(args.model_config) + + # Load dataset_cfg from checkpoint folder if not provided + if args.dataset_config is None: + print("[Info] No dataset config provided. Loading from checkpoint folder.") + dataset_cfg_path = os.path.join(checkpoint_folder, + "dataset_cfg.yaml") + if not os.path.exists(dataset_cfg_path): + raise ValueError( + "[Error] Missing dataset config in checkpoint path.") + dataset_cfg = load_config(dataset_cfg_path) + else: + dataset_cfg = load_config(args.dataset_config) + + # Check the --export_dataset_mode flag + if (args.mode == "export") and (args.export_dataset_mode is None): + raise ValueError("[Error] Empty --export_dataset_mode flag.") + else: + raise ValueError("[Error] Unknown mode: " + args.mode) + + # Set the random seed + seed = dataset_cfg.get("random_seed", 0) + set_random_seed(seed) + + main(args, dataset_cfg, model_cfg, + export_dataset_mode=args.export_dataset_mode, device=device) diff --git a/imcui/third_party/SOLD2/sold2/export.py b/imcui/third_party/SOLD2/sold2/export.py new file mode 100644 index 0000000000000000000000000000000000000000..19683d982c6d7fd429b27868b620fd20562d1aa7 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/export.py @@ -0,0 +1,342 @@ +import numpy as np +import copy +import cv2 +import h5py +import math +from tqdm import tqdm +import torch +from torch.nn.functional import pixel_shuffle, softmax +from torch.utils.data import DataLoader +from kornia.geometry import warp_perspective + +from .dataset.dataset_util import get_dataset +from .model.model_util import get_model +from .misc.train_utils import get_latest_checkpoint +from .train import convert_junc_predictions +from .dataset.transforms.homographic_transforms import sample_homography + + +def restore_weights(model, state_dict): + """ Restore weights in compatible mode. """ + # Try to directly load state dict + try: + model.load_state_dict(state_dict) + except: + err = model.load_state_dict(state_dict, strict=False) + # missing keys are those in model but not in state_dict + missing_keys = err.missing_keys + # Unexpected keys are those in state_dict but not in model + unexpected_keys = err.unexpected_keys + + # Load mismatched keys manually + model_dict = model.state_dict() + for idx, key in enumerate(missing_keys): + dict_keys = [_ for _ in unexpected_keys if not "tracked" in _] + model_dict[key] = state_dict[dict_keys[idx]] + model.load_state_dict(model_dict) + return model + + +def get_padded_filename(num_pad, idx): + """ Get the filename padded with 0. """ + file_len = len("%d" % (idx)) + filename = "0" * (num_pad - file_len) + "%d" % (idx) + return filename + + +def export_predictions(args, dataset_cfg, model_cfg, output_path, + export_dataset_mode): + """ Export predictions. 
""" + # Get the test configuration + test_cfg = model_cfg["test"] + + # Create the dataset and dataloader based on the export_dataset_mode + print("\t Initializing dataset and dataloader") + batch_size = 4 + export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg) + export_loader = DataLoader(export_dataset, batch_size=batch_size, + num_workers=test_cfg.get("num_workers", 4), + shuffle=False, pin_memory=False, + collate_fn=collate_fn) + print("\t Successfully intialized dataset and dataloader.") + + # Initialize model and load the checkpoint + model = get_model(model_cfg, mode="test") + checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name) + model = restore_weights(model, checkpoint["model_state_dict"]) + model = model.cuda() + model.eval() + print("\t Successfully initialized model") + + # Start the export process + print("[Info] Start exporting predictions") + output_dataset_path = output_path + ".h5" + filename_idx = 0 + with h5py.File(output_dataset_path, "w", libver="latest", swmr=True) as f: + # Iterate through all the data in dataloader + for data in tqdm(export_loader, ascii=True): + # Fetch the data + junc_map = data["junction_map"] + heatmap = data["heatmap"] + valid_mask = data["valid_mask"] + input_images = data["image"].cuda() + + # Run the forward pass + with torch.no_grad(): + outputs = model(input_images) + + # Convert predictions + junc_np = convert_junc_predictions( + outputs["junctions"], model_cfg["grid_size"], + model_cfg["detection_thresh"], 300) + junc_map_np = junc_map.numpy().transpose(0, 2, 3, 1) + heatmap_np = softmax(outputs["heatmap"].detach(), + dim=1).cpu().numpy().transpose(0, 2, 3, 1) + heatmap_gt_np = heatmap.numpy().transpose(0, 2, 3, 1) + valid_mask_np = valid_mask.numpy().transpose(0, 2, 3, 1) + + # Data entries to save + current_batch_size = input_images.shape[0] + for batch_idx in range(current_batch_size): + output_data = { + "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], + "junc_gt": junc_map_np[batch_idx], + "junc_pred": junc_np["junc_pred"][batch_idx], + "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(np.float32), + "heatmap_gt": heatmap_gt_np[batch_idx], + "heatmap_pred": heatmap_np[batch_idx], + "valid_mask": valid_mask_np[batch_idx], + "junc_points": data["junctions"][batch_idx].numpy()[0].round().astype(np.int32), + "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32) + } + + # Save data to h5 dataset + num_pad = math.ceil(math.log10(len(export_loader))) + 1 + output_key = get_padded_filename(num_pad, filename_idx) + f_group = f.create_group(output_key) + + # Store data + for key, output_data in output_data.items(): + f_group.create_dataset(key, data=output_data, + compression="gzip") + filename_idx += 1 + + +def export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path, + export_dataset_mode, device): + """ Export homography adaptation results. 
""" + # Check if the export_dataset_mode is supported + supported_modes = ["train", "test"] + if not export_dataset_mode in supported_modes: + raise ValueError( + "[Error] The specified export_dataset_mode is not supported.") + + # Get the test configuration + test_cfg = model_cfg["test"] + + # Get the homography adaptation configurations + homography_cfg = dataset_cfg.get("homography_adaptation", None) + if homography_cfg is None: + raise ValueError( + "[Error] Empty homography_adaptation entry in config.") + + # Create the dataset and dataloader based on the export_dataset_mode + print("\t Initializing dataset and dataloader") + batch_size = args.export_batch_size + + export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg) + export_loader = DataLoader(export_dataset, batch_size=batch_size, + num_workers=test_cfg.get("num_workers", 4), + shuffle=False, pin_memory=False, + collate_fn=collate_fn) + print("\t Successfully intialized dataset and dataloader.") + + # Initialize model and load the checkpoint + model = get_model(model_cfg, mode="test") + checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name, + device) + model = restore_weights(model, checkpoint["model_state_dict"]) + model = model.to(device).eval() + print("\t Successfully initialized model") + + # Start the export process + print("[Info] Start exporting predictions") + output_dataset_path = output_path + ".h5" + with h5py.File(output_dataset_path, "w", libver="latest") as f: + f.swmr_mode=True + for _, data in enumerate(tqdm(export_loader, ascii=True)): + input_images = data["image"].to(device) + file_keys = data["file_key"] + batch_size = input_images.shape[0] + + # Run the homograpy adaptation + outputs = homography_adaptation(input_images, model, + model_cfg["grid_size"], + homography_cfg) + + # Save the entries + for batch_idx in range(batch_size): + # Get the save key + save_key = file_keys[batch_idx] + output_data = { + "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], + "junc_prob_mean": outputs["junc_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], + "junc_prob_max": outputs["junc_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], + "junc_count": outputs["junc_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], + "heatmap_prob_mean": outputs["heatmap_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], + "heatmap_prob_max": outputs["heatmap_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx], + "heatmap_cout": outputs["heatmap_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx] + } + + # Create group and write data + f_group = f.create_group(save_key) + for key, output_data in output_data.items(): + f_group.create_dataset(key, data=output_data, + compression="gzip") + + +def homography_adaptation(input_images, model, grid_size, homography_cfg): + """ The homography adaptation process. + Arguments: + input_images: The images to be evaluated. + model: The pytorch model in evaluation mode. + grid_size: Grid size of the junction decoder. + homography_cfg: Homography adaptation configurations. 
+ """ + # Get the device of the current model + device = next(model.parameters()).device + + # Define some constants and placeholder + batch_size, _, H, W = input_images.shape + num_iter = homography_cfg["num_iter"] + junc_probs = torch.zeros([batch_size, num_iter, H, W], device=device) + junc_counts = torch.zeros([batch_size, 1, H, W], device=device) + heatmap_probs = torch.zeros([batch_size, num_iter, H, W], device=device) + heatmap_counts = torch.zeros([batch_size, 1, H, W], device=device) + margin = homography_cfg["valid_border_margin"] + + # Keep a config with no artifacts + homography_cfg_no_artifacts = copy.copy(homography_cfg["homographies"]) + homography_cfg_no_artifacts["allow_artifacts"] = False + + for idx in range(num_iter): + if idx <= num_iter // 5: + # Ensure that 20% of the homographies have no artifact + H_mat_lst = [sample_homography( + [H,W], **homography_cfg_no_artifacts)[0][None] + for _ in range(batch_size)] + else: + H_mat_lst = [sample_homography( + [H,W], **homography_cfg["homographies"])[0][None] + for _ in range(batch_size)] + + H_mats = np.concatenate(H_mat_lst, axis=0) + H_tensor = torch.tensor(H_mats, dtype=torch.float, device=device) + H_inv_tensor = torch.inverse(H_tensor) + + # Perform the homography warp + images_warped = warp_perspective(input_images, H_tensor, (H, W), + flags="bilinear") + + # Warp the mask + masks_junc_warped = warp_perspective( + torch.ones([batch_size, 1, H, W], device=device), + H_tensor, (H, W), flags="nearest") + masks_heatmap_warped = warp_perspective( + torch.ones([batch_size, 1, H, W], device=device), + H_tensor, (H, W), flags="nearest") + + # Run the network forward pass + with torch.no_grad(): + outputs = model(images_warped) + + # Unwarp and mask the junction prediction + junc_prob_warped = pixel_shuffle(softmax( + outputs["junctions"], dim=1)[:, :-1, :, :], grid_size) + junc_prob = warp_perspective(junc_prob_warped, H_inv_tensor, + (H, W), flags="bilinear") + + # Create the out of boundary mask + out_boundary_mask = warp_perspective( + torch.ones([batch_size, 1, H, W], device=device), + H_inv_tensor, (H, W), flags="nearest") + out_boundary_mask = adjust_border(out_boundary_mask, device, margin) + + junc_prob = junc_prob * out_boundary_mask + junc_count = warp_perspective(masks_junc_warped * out_boundary_mask, + H_inv_tensor, (H, W), flags="nearest") + + # Unwarp the mask and heatmap prediction + # Always fetch only one channel + if outputs["heatmap"].shape[1] == 2: + # Convert to single channel directly from here + heatmap_prob_warped = softmax(outputs["heatmap"], + dim=1)[:, 1:, :, :] + else: + heatmap_prob_warped = torch.sigmoid(outputs["heatmap"]) + + heatmap_prob_warped = heatmap_prob_warped * masks_heatmap_warped + heatmap_prob = warp_perspective(heatmap_prob_warped, H_inv_tensor, + (H, W), flags="bilinear") + heatmap_count = warp_perspective(masks_heatmap_warped, H_inv_tensor, + (H, W), flags="nearest") + + # Record the results + junc_probs[:, idx:idx+1, :, :] = junc_prob + heatmap_probs[:, idx:idx+1, :, :] = heatmap_prob + junc_counts += junc_count + heatmap_counts += heatmap_count + + # Perform the accumulation operation + if homography_cfg["min_counts"] > 0: + min_counts = homography_cfg["min_counts"] + junc_count_mask = (junc_counts < min_counts) + heatmap_count_mask = (heatmap_counts < min_counts) + junc_counts[junc_count_mask] = 0 + heatmap_counts[heatmap_count_mask] = 0 + else: + junc_count_mask = np.zeros_like(junc_counts, dtype=bool) + heatmap_count_mask = np.zeros_like(heatmap_counts, dtype=bool) + + # 
Compute the mean accumulation + junc_probs_mean = torch.sum(junc_probs, dim=1, keepdim=True) / junc_counts + junc_probs_mean[junc_count_mask] = 0. + heatmap_probs_mean = (torch.sum(heatmap_probs, dim=1, keepdim=True) + / heatmap_counts) + heatmap_probs_mean[heatmap_count_mask] = 0. + + # Compute the max accumulation + junc_probs_max = torch.max(junc_probs, dim=1, keepdim=True)[0] + junc_probs_max[junc_count_mask] = 0. + heatmap_probs_max = torch.max(heatmap_probs, dim=1, keepdim=True)[0] + heatmap_probs_max[heatmap_count_mask] = 0. + + return {"junc_probs_mean": junc_probs_mean, + "junc_probs_max": junc_probs_max, + "junc_counts": junc_counts, + "heatmap_probs_mean": heatmap_probs_mean, + "heatmap_probs_max": heatmap_probs_max, + "heatmap_counts": heatmap_counts} + + +def adjust_border(input_masks, device, margin=3): + """ Adjust the border of the counts and valid_mask. """ + # Convert the mask to numpy array + dtype = input_masks.dtype + input_masks = np.squeeze(input_masks.cpu().numpy(), axis=1) + + erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (margin*2, margin*2)) + batch_size = input_masks.shape[0] + + output_mask_lst = [] + # Erode all the masks + for i in range(batch_size): + output_mask = cv2.erode(input_masks[i, ...], erosion_kernel) + + output_mask_lst.append( + torch.tensor(output_mask, dtype=dtype, device=device)[None]) + + # Concat back along the batch dimension. + output_masks = torch.cat(output_mask_lst, dim=0) + return output_masks.unsqueeze(dim=1) diff --git a/imcui/third_party/SOLD2/sold2/export_line_features.py b/imcui/third_party/SOLD2/sold2/export_line_features.py new file mode 100644 index 0000000000000000000000000000000000000000..4cbde860a446d758dff254ea5320ca13bb79e6b7 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/export_line_features.py @@ -0,0 +1,74 @@ +""" + Export line detections and descriptors given a list of input images. 
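+ Each image produces one compressed .npz file with "line_seg" and "descriptors"
+ entries. A hedged usage sketch (assuming the package can be run as a module;
+ the file names are placeholders):
+     python -m sold2.export_line_features --img_list images.txt --output_folder out/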
+""" +import os +import argparse +import cv2 +import numpy as np +import torch +from tqdm import tqdm + +from .experiment import load_config +from .model.line_matcher import LineMatcher + + +def export_descriptors(images_list, ckpt_path, config, device, extension, + output_folder, multiscale=False): + # Extract the image paths + with open(images_list, 'r') as f: + image_files = f.readlines() + image_files = [path.strip('\n') for path in image_files] + + # Initialize the line matcher + line_matcher = LineMatcher( + config["model_cfg"], ckpt_path, device, config["line_detector_cfg"], + config["line_matcher_cfg"], multiscale) + print("\t Successfully initialized model") + + # Run the inference on each image and write the output on disk + for img_path in tqdm(image_files): + img = cv2.imread(img_path, 0) + img = torch.tensor(img[None, None] / 255., dtype=torch.float, + device=device) + + # Run the line detection and description + ref_detection = line_matcher.line_detection(img) + ref_line_seg = ref_detection["line_segments"] + ref_descriptors = ref_detection["descriptor"][0].cpu().numpy() + + # Write the output on disk + img_name = os.path.splitext(os.path.basename(img_path))[0] + output_file = os.path.join(output_folder, img_name + extension) + np.savez_compressed(output_file, line_seg=ref_line_seg, + descriptors=ref_descriptors) + + +if __name__ == "__main__": + # Parse input arguments + parser = argparse.ArgumentParser() + parser.add_argument("--img_list", type=str, required=True, + help="List of input images in a text file.") + parser.add_argument("--output_folder", type=str, required=True, + help="Path to the output folder.") + parser.add_argument("--config", type=str, + default="config/export_line_features.yaml") + parser.add_argument("--checkpoint_path", type=str, + default="pretrained_models/sold2_wireframe.tar") + parser.add_argument("--multiscale", action="store_true", default=False) + parser.add_argument("--extension", type=str, default=None) + args = parser.parse_args() + + # Get the device + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + # Get the model config, extension and checkpoint path + config = load_config(args.config) + ckpt_path = os.path.abspath(args.checkpoint_path) + extension = 'sold2' if args.extension is None else args.extension + extension = "." 
+ extension + + export_descriptors(args.img_list, ckpt_path, config, device, extension, + args.output_folder, args.multiscale) diff --git a/imcui/third_party/SOLD2/sold2/misc/__init__.py b/imcui/third_party/SOLD2/sold2/misc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/SOLD2/sold2/misc/geometry_utils.py b/imcui/third_party/SOLD2/sold2/misc/geometry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50f0478062cd19ebac812bff62b6c3a3d5f124c2 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/misc/geometry_utils.py @@ -0,0 +1,81 @@ +import numpy as np +import torch + + +### Point-related utils + +# Warp a list of points using a homography +def warp_points(points, homography): + # Convert to homogeneous and in xy format + new_points = np.concatenate([points[..., [1, 0]], + np.ones_like(points[..., :1])], axis=-1) + # Warp + new_points = (homography @ new_points.T).T + # Convert back to inhomogeneous and hw format + new_points = new_points[..., [1, 0]] / new_points[..., 2:] + return new_points + + +# Mask out the points that are outside of img_size +def mask_points(points, img_size): + mask = ((points[..., 0] >= 0) + & (points[..., 0] < img_size[0]) + & (points[..., 1] >= 0) + & (points[..., 1] < img_size[1])) + return mask + + +# Convert a tensor [N, 2] or batched tensor [B, N, 2] of N keypoints into +# a grid in [-1, 1]² that can be used in torch.nn.functional.interpolate +def keypoints_to_grid(keypoints, img_size): + n_points = keypoints.size()[-2] + device = keypoints.device + grid_points = keypoints.float() * 2. / torch.tensor( + img_size, dtype=torch.float, device=device) - 1. + grid_points = grid_points[..., [1, 0]].view(-1, n_points, 1, 2) + return grid_points + + +# Return a 2D matrix indicating the local neighborhood of each point +# for a given threshold and two lists of corresponding keypoints +def get_dist_mask(kp0, kp1, valid_mask, dist_thresh): + b_size, n_points, _ = kp0.size() + dist_mask0 = torch.norm(kp0.unsqueeze(2) - kp0.unsqueeze(1), dim=-1) + dist_mask1 = torch.norm(kp1.unsqueeze(2) - kp1.unsqueeze(1), dim=-1) + dist_mask = torch.min(dist_mask0, dist_mask1) + dist_mask = dist_mask <= dist_thresh + dist_mask = dist_mask.repeat(1, 1, b_size).reshape(b_size * n_points, + b_size * n_points) + dist_mask = dist_mask[valid_mask, :][:, valid_mask] + return dist_mask + + +### Line-related utils + +# Sample n points along lines of shape (num_lines, 2, 2) +def sample_line_points(lines, n): + line_points_x = np.linspace(lines[:, 0, 0], lines[:, 1, 0], n, axis=-1) + line_points_y = np.linspace(lines[:, 0, 1], lines[:, 1, 1], n, axis=-1) + line_points = np.stack([line_points_x, line_points_y], axis=2) + return line_points + + +# Return a mask of the valid lines that are within a valid mask of an image +def mask_lines(lines, valid_mask): + h, w = valid_mask.shape + int_lines = np.clip(np.round(lines).astype(int), 0, [h - 1, w - 1]) + h_valid = valid_mask[int_lines[:, 0, 0], int_lines[:, 0, 1]] + w_valid = valid_mask[int_lines[:, 1, 0], int_lines[:, 1, 1]] + valid = h_valid & w_valid + return valid + + +# Return a 2D matrix indicating for each pair of points +# if they are on the same line or not +def get_common_line_mask(line_indices, valid_mask): + b_size, n_points = line_indices.shape + common_mask = line_indices[:, :, None] == line_indices[:, None, :] + common_mask = common_mask.repeat(1, 1, b_size).reshape(b_size * n_points, + b_size * n_points) + 
common_mask = common_mask[valid_mask, :][:, valid_mask] + return common_mask diff --git a/imcui/third_party/SOLD2/sold2/misc/train_utils.py b/imcui/third_party/SOLD2/sold2/misc/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ada35eea660df1f78b9f20d9bf7ed726eaee2c --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/misc/train_utils.py @@ -0,0 +1,74 @@ +""" +This file contains some useful functions for train / val. +""" +import os +import numpy as np +import torch + + +################# +## image utils ## +################# +def convert_image(input_tensor, axis): + """ Convert single channel images to 3-channel images. """ + image_lst = [input_tensor for _ in range(3)] + outputs = np.concatenate(image_lst, axis) + return outputs + + +###################### +## checkpoint utils ## +###################### +def get_latest_checkpoint(checkpoint_root, checkpoint_name, + device=torch.device("cuda")): + """ Get the latest checkpoint or by filename. """ + # Load specific checkpoint + if checkpoint_name is not None: + checkpoint = torch.load( + os.path.join(checkpoint_root, checkpoint_name), + map_location=device) + # Load the latest checkpoint + else: + lastest_checkpoint = sorted(os.listdir(os.path.join( + checkpoint_root, "*.tar")))[-1] + checkpoint = torch.load(os.path.join( + checkpoint_root, lastest_checkpoint), map_location=device) + return checkpoint + + +def remove_old_checkpoints(checkpoint_root, max_ckpt=15): + """ Remove the outdated checkpoints. """ + # Get sorted list of checkpoints + checkpoint_list = sorted( + [_ for _ in os.listdir(os.path.join(checkpoint_root)) + if _.endswith(".tar")]) + + # Get the checkpoints to be removed + if len(checkpoint_list) > max_ckpt: + remove_list = checkpoint_list[:-max_ckpt] + for _ in remove_list: + full_name = os.path.join(checkpoint_root, _) + os.remove(full_name) + print("[Debug] Remove outdated checkpoint %s" % (full_name)) + + +def adapt_checkpoint(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + return new_state_dict + + +################ +## HDF5 utils ## +################ +def parse_h5_data(h5_data): + """ Parse h5 dataset. """ + output_data = {} + for key in h5_data.keys(): + output_data[key] = np.array(h5_data[key]) + + return output_data diff --git a/imcui/third_party/SOLD2/sold2/misc/visualize_util.py b/imcui/third_party/SOLD2/sold2/misc/visualize_util.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa46877f79724221b7caa423de6916acdc021f8 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/misc/visualize_util.py @@ -0,0 +1,526 @@ +""" Organize some frequently used visualization functions. """ +import cv2 +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +import copy +import seaborn as sns + + +# Plot junctions onto the image (return a separate copy) +def plot_junctions(input_image, junctions, junc_size=3, color=None): + """ + input_image: can be 0~1 float or 0~255 uint8. + junctions: Nx2 or 2xN np array. + junc_size: the size of the plotted circles. + """ + # Create image copy + image = copy.copy(input_image) + # Make sure the image is converted to 255 uint8 + if image.dtype == np.uint8: + pass + # A float type image ranging from 0~1 + elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.: + image = (image * 255.).astype(np.uint8) + # A float type image ranging from 0.~255. 
+ elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.: + image = image.astype(np.uint8) + else: + raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.") + + # Check whether the image is single channel + if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)): + # Squeeze to H*W first + image = image.squeeze() + + # Stack to channle 3 + image = np.concatenate([image[..., None] for _ in range(3)], axis=-1) + + # Junction dimensions should be N*2 + if not len(junctions.shape) == 2: + raise ValueError("[Error] junctions should be 2-dim array.") + + # Always convert to N*2 + if junctions.shape[-1] != 2: + if junctions.shape[0] == 2: + junctions = junctions.T + else: + raise ValueError("[Error] At least one of the two dims should be 2.") + + # Round and convert junctions to int (and check the boundary) + H, W = image.shape[:2] + junctions = (np.round(junctions)).astype(np.int) + junctions[junctions < 0] = 0 + junctions[junctions[:, 0] >= H, 0] = H-1 # (first dim) max bounded by H-1 + junctions[junctions[:, 1] >= W, 1] = W-1 # (second dim) max bounded by W-1 + + # Iterate through all the junctions + num_junc = junctions.shape[0] + if color is None: + color = (0, 255., 0) + for idx in range(num_junc): + # Fetch one junction + junc = junctions[idx, :] + cv2.circle(image, tuple(np.flip(junc)), radius=junc_size, + color=color, thickness=3) + + return image + + +# Plot line segements given junctions and line adjecent map +def plot_line_segments(input_image, junctions, line_map, junc_size=3, + color=(0, 255., 0), line_width=1, plot_survived_junc=True): + """ + input_image: can be 0~1 float or 0~255 uint8. + junctions: Nx2 or 2xN np array. + line_map: NxN np array + junc_size: the size of the plotted circles. + color: color of the line segments (can be string "random") + line_width: width of the drawn segments. + plot_survived_junc: whether we only plot the survived junctions. + """ + # Create image copy + image = copy.copy(input_image) + # Make sure the image is converted to 255 uint8 + if image.dtype == np.uint8: + pass + # A float type image ranging from 0~1 + elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.: + image = (image * 255.).astype(np.uint8) + # A float type image ranging from 0.~255. + elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.: + image = image.astype(np.uint8) + else: + raise ValueError("[Error] Unknown image data type. 
Expect 0~1 float or 0~255 uint8.") + + # Check whether the image is single channel + if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)): + # Squeeze to H*W first + image = image.squeeze() + + # Stack to channle 3 + image = np.concatenate([image[..., None] for _ in range(3)], axis=-1) + + # Junction dimensions should be 2 + if not len(junctions.shape) == 2: + raise ValueError("[Error] junctions should be 2-dim array.") + + # Always convert to N*2 + if junctions.shape[-1] != 2: + if junctions.shape[0] == 2: + junctions = junctions.T + else: + raise ValueError("[Error] At least one of the two dims should be 2.") + + # line_map dimension should be 2 + if not len(line_map.shape) == 2: + raise ValueError("[Error] line_map should be 2-dim array.") + + # Color should be "random" or a list or tuple with length 3 + if color != "random": + if not (isinstance(color, tuple) or isinstance(color, list)): + raise ValueError("[Error] color should have type list or tuple.") + else: + if len(color) != 3: + raise ValueError("[Error] color should be a list or tuple with length 3.") + + # Make a copy of the line_map + line_map_tmp = copy.copy(line_map) + + # Parse line_map back to segment pairs + segments = np.zeros([0, 4]) + for idx in range(junctions.shape[0]): + # if no connectivity, just skip it + if line_map_tmp[idx, :].sum() == 0: + continue + # record the line segment + else: + for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]: + p1 = np.flip(junctions[idx, :]) # Convert to xy format + p2 = np.flip(junctions[idx2, :]) # Convert to xy format + segments = np.concatenate((segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), axis=0) + + # Update line_map + line_map_tmp[idx, idx2] = 0 + line_map_tmp[idx2, idx] = 0 + + # Draw segment pairs + for idx in range(segments.shape[0]): + seg = np.round(segments[idx, :]).astype(np.int) + # Decide the color + if color != "random": + color = tuple(color) + else: + color = tuple(np.random.rand(3,)) + cv2.line(image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width) + + # Also draw the junctions + if not plot_survived_junc: + num_junc = junctions.shape[0] + for idx in range(num_junc): + # Fetch one junction + junc = junctions[idx, :] + cv2.circle(image, tuple(np.flip(junc)), radius=junc_size, + color=(0, 255., 0), thickness=3) + # Only plot the junctions which are part of a line segment + else: + for idx in range(segments.shape[0]): + seg = np.round(segments[idx, :]).astype(np.int) # Already in HW format. + cv2.circle(image, tuple(seg[:2]), radius=junc_size, + color=(0, 255., 0), thickness=3) + cv2.circle(image, tuple(seg[2:]), radius=junc_size, + color=(0, 255., 0), thickness=3) + + return image + + +# Plot line segments given Nx4 or Nx2x2 line segments +def plot_line_segments_from_segments(input_image, line_segments, junc_size=3, + color=(0, 255., 0), line_width=1): + # Create image copy + image = copy.copy(input_image) + # Make sure the image is converted to 255 uint8 + if image.dtype == np.uint8: + pass + # A float type image ranging from 0~1 + elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.: + image = (image * 255.).astype(np.uint8) + # A float type image ranging from 0.~255. + elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.: + image = image.astype(np.uint8) + else: + raise ValueError("[Error] Unknown image data type. 
Expect 0~1 float or 0~255 uint8.") + + # Check whether the image is single channel + if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)): + # Squeeze to H*W first + image = image.squeeze() + + # Stack to channle 3 + image = np.concatenate([image[..., None] for _ in range(3)], axis=-1) + + # Check the if line_segments are in (1) Nx4, or (2) Nx2x2. + H, W, _ = image.shape + # (1) Nx4 format + if len(line_segments.shape) == 2 and line_segments.shape[-1] == 4: + # Round to int32 + line_segments = line_segments.astype(np.int32) + + # Clip H dimension + line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H-1) + line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H-1) + + # Clip W dimension + line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W-1) + line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W-1) + + # Convert to Nx2x2 format + line_segments = np.concatenate( + [np.expand_dims(line_segments[:, :2], axis=1), + np.expand_dims(line_segments[:, 2:], axis=1)], + axis=1 + ) + + # (2) Nx2x2 format + elif len(line_segments.shape) == 3 and line_segments.shape[-1] == 2: + # Round to int32 + line_segments = line_segments.astype(np.int32) + + # Clip H dimension + line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H-1) + line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W-1) + + else: + raise ValueError("[Error] line_segments should be either Nx4 or Nx2x2 in HW format.") + + # Draw segment pairs (all segments should be in HW format) + image = image.copy() + for idx in range(line_segments.shape[0]): + seg = np.round(line_segments[idx, :, :]).astype(np.int32) + # Decide the color + if color != "random": + color = tuple(color) + else: + color = tuple(np.random.rand(3,)) + cv2.line(image, tuple(np.flip(seg[0, :])), + tuple(np.flip(seg[1, :])), + color=color, thickness=line_width) + + # Also draw the junctions + cv2.circle(image, tuple(np.flip(seg[0, :])), radius=junc_size, color=(0, 255., 0), thickness=3) + cv2.circle(image, tuple(np.flip(seg[1, :])), radius=junc_size, color=(0, 255., 0), thickness=3) + + return image + + +# Additional functions to visualize multiple images at the same time, +# e.g. for line matching +def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5): + """Plot a set of images horizontally. + Args: + imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. + """ + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + figsize = (size*n, size*3/4) if size is not None else None + fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi) + if n == 1: + ax = [ax] + for i in range(n): + ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) + ax[i].get_yaxis().set_ticks([]) + ax[i].get_xaxis().set_ticks([]) + ax[i].set_axis_off() + for spine in ax[i].spines.values(): # remove frame + spine.set_visible(False) + if titles: + ax[i].set_title(titles[i]) + fig.tight_layout(pad=pad) + + +def plot_keypoints(kpts, colors='lime', ps=4): + """Plot keypoints for existing images. + Args: + kpts: list of ndarrays of size (N, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float. 
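+ The keypoints are drawn onto the axes of the current figure, e.g. the ones created by plot_images.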
+ """ + if not isinstance(colors, list): + colors = [colors] * len(kpts) + axes = plt.gcf().axes + for a, k, c in zip(axes, kpts, colors): + a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0) + + +def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.): + """Plot matches for a pair of existing images. + Args: + kpts0, kpts1: corresponding keypoints of size (N, 2). + color: color of each match, string or RGB tuple. Random if not given. + lw: width of the lines. + ps: size of the end points (no endpoint if ps=0) + indices: indices of the images to draw the matches on. + a: alpha opacity of the match lines. + """ + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + ax0, ax1 = ax[indices[0]], ax[indices[1]] + fig.canvas.draw() + + assert len(kpts0) == len(kpts1) + if color is None: + color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() + elif len(color) > 0 and not isinstance(color[0], (tuple, list)): + color = [color] * len(kpts0) + + if lw > 0: + # transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) + fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) + fig.lines += [matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), + zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, + alpha=a) + for i in range(len(kpts0))] + + # freeze the axes to prevent the transform to change + ax0.autoscale(enable=False) + ax1.autoscale(enable=False) + + if ps > 0: + ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2) + ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2) + + +def plot_lines(lines, line_colors='orange', point_colors='cyan', + ps=4, lw=2, indices=(0, 1)): + """Plot lines and endpoints for existing images. + Args: + lines: list of ndarrays of size (N, 2, 2). + colors: string, or list of list of tuples (one for each keypoints). + ps: size of the keypoints as float pixels. + lw: line width as float pixels. + indices: indices of the images to draw the matches on. + """ + if not isinstance(line_colors, list): + line_colors = [line_colors] * len(lines) + if not isinstance(point_colors, list): + point_colors = [point_colors] * len(lines) + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines and junctions + for a, l, lc, pc in zip(axes, lines, line_colors, point_colors): + for i in range(len(l)): + line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]), + (l[i, 0, 1], l[i, 1, 1]), + zorder=1, c=lc, linewidth=lw) + a.add_line(line) + pts = l.reshape(-1, 2) + a.scatter(pts[:, 0], pts[:, 1], + c=pc, s=ps, linewidths=0, zorder=2) + + +def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.): + """Plot matches for a pair of existing images, parametrized by their middle point. + Args: + kpts0, kpts1: corresponding middle points of the lines of size (N, 2). + color: color of each match, string or RGB tuple. Random if not given. + lw: width of the lines. + indices: indices of the images to draw the matches on. + a: alpha opacity of the match lines. 
+ """ + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + ax0, ax1 = ax[indices[0]], ax[indices[1]] + fig.canvas.draw() + + assert len(kpts0) == len(kpts1) + if color is None: + color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() + elif len(color) > 0 and not isinstance(color[0], (tuple, list)): + color = [color] * len(kpts0) + + if lw > 0: + # transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) + fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) + fig.lines += [matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), + zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, + alpha=a) + for i in range(len(kpts0))] + + # freeze the axes to prevent the transform to change + ax0.autoscale(enable=False) + ax1.autoscale(enable=False) + + +def plot_color_line_matches(lines, correct_matches=None, + lw=2, indices=(0, 1)): + """Plot line matches for existing images with multiple colors. + Args: + lines: list of ndarrays of size (N, 2, 2). + correct_matches: bool array of size (N,) indicating correct matches. + lw: line width as float pixels. + indices: indices of the images to draw the matches on. + """ + n_lines = len(lines[0]) + colors = sns.color_palette('husl', n_colors=n_lines) + np.random.shuffle(colors) + alphas = np.ones(n_lines) + # If correct_matches is not None, display wrong matches with a low alpha + if correct_matches is not None: + alphas[~np.array(correct_matches)] = 0.2 + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines + for a, l in zip(axes, lines): + # Transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) + fig.lines += [matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, transform=fig.transFigure, c=colors[i], + alpha=alphas[i], linewidth=lw) for i in range(n_lines)] + + +def plot_color_lines(lines, correct_matches, wrong_matches, + lw=2, indices=(0, 1)): + """Plot line matches for existing images with multiple colors: + green for correct matches, red for wrong ones, and blue for the rest. + Args: + lines: list of ndarrays of size (N, 2, 2). + correct_matches: list of bool arrays of size N with correct matches. + wrong_matches: list of bool arrays of size (N,) with correct matches. + lw: line width as float pixels. + indices: indices of the images to draw the matches on. 
+ """ + # palette = sns.color_palette() + palette = sns.color_palette("hls", 8) + blue = palette[5] # palette[0] + red = palette[0] # palette[3] + green = palette[2] # palette[2] + colors = [np.array([blue] * len(l)) for l in lines] + for i, c in enumerate(colors): + c[np.array(correct_matches[i])] = green + c[np.array(wrong_matches[i])] = red + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines + for a, l, c in zip(axes, lines, colors): + # Transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) + fig.lines += [matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, transform=fig.transFigure, c=c[i], + linewidth=lw) for i in range(len(l))] + + +def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)): + """ Plot line matches for existing images with multiple colors and + highlight the actually matched subsegments. + Args: + lines: list of ndarrays of size (N, 2, 2). + subsegments: list of ndarrays of size (N, 2, 2). + lw: line width as float pixels. + indices: indices of the images to draw the matches on. + """ + n_lines = len(lines[0]) + colors = sns.cubehelix_palette(start=2, rot=-0.2, dark=0.3, light=.7, + gamma=1.3, hue=1, n_colors=n_lines) + + fig = plt.gcf() + ax = fig.axes + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines + for a, l, ss in zip(axes, lines, subsegments): + # Transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + + # Draw full line + endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) + fig.lines += [matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, transform=fig.transFigure, c='red', + alpha=0.7, linewidth=lw) for i in range(n_lines)] + + # Draw matched subsegment + endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1])) + fig.lines += [matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, transform=fig.transFigure, c=colors[i], + alpha=1, linewidth=lw) for i in range(n_lines)] \ No newline at end of file diff --git a/imcui/third_party/SOLD2/sold2/model/__init__.py b/imcui/third_party/SOLD2/sold2/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/SOLD2/sold2/model/line_detection.py b/imcui/third_party/SOLD2/sold2/model/line_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d1928515a8494833a8ef6509008f4299cd74c4 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/line_detection.py @@ -0,0 +1,506 @@ +""" +Implementation of the line segment detection module. +""" +import math +import numpy as np +import torch + + +class LineSegmentDetectionModule(object): + """ Module extracting line segments from junctions and line heatmaps. 
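+ Typical usage (a sketch; the threshold value is illustrative only, and + junctions/heatmap may be numpy arrays or torch tensors): + detector = LineSegmentDetectionModule(detect_thresh=0.5) + line_map, junctions, heatmap = detector.detect(junctions, heatmap)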
""" + def __init__( + self, detect_thresh, num_samples=64, sampling_method="local_max", + inlier_thresh=0., heatmap_low_thresh=0.15, heatmap_high_thresh=0.2, + max_local_patch_radius=3, lambda_radius=2., + use_candidate_suppression=False, nms_dist_tolerance=3., + use_heatmap_refinement=False, heatmap_refine_cfg=None, + use_junction_refinement=False, junction_refine_cfg=None): + """ + Parameters: + detect_thresh: The probability threshold for mean activation (0. ~ 1.) + num_samples: Number of sampling locations along the line segments. + sampling_method: Sampling method on locations ("bilinear" or "local_max"). + inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold. + heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery. + heatmap_high_thresh: The higher threshold for NMS in junction recovery. + max_local_patch_radius: The max patch to be considered in local maximum search. + lambda_radius: The lambda factor in linear local maximum search formulation + use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments. + nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line. + use_heatmap_refinement: Use heatmap refinement method or not. + heatmap_refine_cfg: The configs for heatmap refinement methods. + use_junction_refinement: Use junction refinement method or not. + junction_refine_cfg: The configs for junction refinement methods. + """ + # Line detection parameters + self.detect_thresh = detect_thresh + + # Line sampling parameters + self.num_samples = num_samples + self.sampling_method = sampling_method + self.inlier_thresh = inlier_thresh + self.local_patch_radius = max_local_patch_radius + self.lambda_radius = lambda_radius + + # Detecting junctions on the boundary parameters + self.low_thresh = heatmap_low_thresh + self.high_thresh = heatmap_high_thresh + + # Pre-compute the linspace sampler + self.sampler = np.linspace(0, 1, self.num_samples) + self.torch_sampler = torch.linspace(0, 1, self.num_samples) + + # Long line segment suppression configuration + self.use_candidate_suppression = use_candidate_suppression + self.nms_dist_tolerance = nms_dist_tolerance + + # Heatmap refinement configuration + self.use_heatmap_refinement = use_heatmap_refinement + self.heatmap_refine_cfg = heatmap_refine_cfg + if self.use_heatmap_refinement and self.heatmap_refine_cfg is None: + raise ValueError("[Error] Missing heatmap refinement config.") + + # Junction refinement configuration + self.use_junction_refinement = use_junction_refinement + self.junction_refine_cfg = junction_refine_cfg + if self.use_junction_refinement and self.junction_refine_cfg is None: + raise ValueError("[Error] Missing junction refinement config.") + + def convert_inputs(self, inputs, device): + """ Convert inputs to desired torch tensor. """ + if isinstance(inputs, np.ndarray): + outputs = torch.tensor(inputs, dtype=torch.float32, device=device) + elif isinstance(inputs, torch.Tensor): + outputs = inputs.to(torch.float32).to(device) + else: + raise ValueError( + "[Error] Inputs must either be torch tensor or numpy ndarray.") + + return outputs + + def detect(self, junctions, heatmap, device=torch.device("cpu")): + """ Main function performing line segment detection. 
""" + # Convert inputs to torch tensor + junctions = self.convert_inputs(junctions, device=device) + heatmap = self.convert_inputs(heatmap, device=device) + + # Perform the heatmap refinement + if self.use_heatmap_refinement: + if self.heatmap_refine_cfg["mode"] == "global": + heatmap = self.refine_heatmap( + heatmap, + self.heatmap_refine_cfg["ratio"], + self.heatmap_refine_cfg["valid_thresh"] + ) + elif self.heatmap_refine_cfg["mode"] == "local": + heatmap = self.refine_heatmap_local( + heatmap, + self.heatmap_refine_cfg["num_blocks"], + self.heatmap_refine_cfg["overlap_ratio"], + self.heatmap_refine_cfg["ratio"], + self.heatmap_refine_cfg["valid_thresh"] + ) + + # Initialize empty line map + num_junctions = junctions.shape[0] + line_map_pred = torch.zeros([num_junctions, num_junctions], + device=device, dtype=torch.int32) + + # Stop if there are not enough junctions + if num_junctions < 2: + return line_map_pred, junctions, heatmap + + # Generate the candidate map + candidate_map = torch.triu(torch.ones( + [num_junctions, num_junctions], device=device, dtype=torch.int32), + diagonal=1) + + # Fetch the image boundary + if len(heatmap.shape) > 2: + H, W, _ = heatmap.shape + else: + H, W = heatmap.shape + + # Optionally perform candidate filtering + if self.use_candidate_suppression: + candidate_map = self.candidate_suppression(junctions, + candidate_map) + + # Fetch the candidates + candidate_index_map = torch.where(candidate_map) + candidate_index_map = torch.cat([candidate_index_map[0][..., None], + candidate_index_map[1][..., None]], + dim=-1) + + # Get the corresponding start and end junctions + candidate_junc_start = junctions[candidate_index_map[:, 0], :] + candidate_junc_end = junctions[candidate_index_map[:, 1], :] + + # Get the sampling locations (N x 64) + sampler = self.torch_sampler.to(device)[None, ...] 
+ cand_samples_h = candidate_junc_start[:, 0:1] * sampler + \ + candidate_junc_end[:, 0:1] * (1 - sampler) + cand_samples_w = candidate_junc_start[:, 1:2] * sampler + \ + candidate_junc_end[:, 1:2] * (1 - sampler) + + # Clip to image boundary + cand_h = torch.clamp(cand_samples_h, min=0, max=H-1) + cand_w = torch.clamp(cand_samples_w, min=0, max=W-1) + + # Local maximum search + if self.sampling_method == "local_max": + # Compute normalized segment lengths + segments_length = torch.sqrt(torch.sum( + (candidate_junc_start.to(torch.float32) - + candidate_junc_end.to(torch.float32)) ** 2, dim=-1)) + normalized_seg_length = (segments_length + / (((H ** 2) + (W ** 2)) ** 0.5)) + + # Perform local max search + num_cand = cand_h.shape[0] + group_size = 10000 + if num_cand > group_size: + num_iter = math.ceil(num_cand / group_size) + sampled_feat_lst = [] + for iter_idx in range(num_iter): + if not iter_idx == num_iter-1: + cand_h_ = cand_h[iter_idx * group_size: + (iter_idx+1) * group_size, :] + cand_w_ = cand_w[iter_idx * group_size: + (iter_idx+1) * group_size, :] + normalized_seg_length_ = normalized_seg_length[ + iter_idx * group_size: (iter_idx+1) * group_size] + else: + cand_h_ = cand_h[iter_idx * group_size:, :] + cand_w_ = cand_w[iter_idx * group_size:, :] + normalized_seg_length_ = normalized_seg_length[ + iter_idx * group_size:] + sampled_feat_ = self.detect_local_max( + heatmap, cand_h_, cand_w_, H, W, + normalized_seg_length_, device) + sampled_feat_lst.append(sampled_feat_) + sampled_feat = torch.cat(sampled_feat_lst, dim=0) + else: + sampled_feat = self.detect_local_max( + heatmap, cand_h, cand_w, H, W, + normalized_seg_length, device) + # Bilinear sampling + elif self.sampling_method == "bilinear": + # Perform bilinear sampling + sampled_feat = self.detect_bilinear( + heatmap, cand_h, cand_w, H, W, device) + else: + raise ValueError("[Error] Unknown sampling method.") + + # [Simple threshold detection] + # detection_results is a mask over all candidates + detection_results = (torch.mean(sampled_feat, dim=-1) + > self.detect_thresh) + + # [Inlier threshold detection] + if self.inlier_thresh > 0.: + inlier_ratio = torch.sum( + sampled_feat > self.detect_thresh, + dim=-1).to(torch.float32) / self.num_samples + detection_results_inlier = inlier_ratio >= self.inlier_thresh + detection_results = detection_results * detection_results_inlier + + # Convert detection results back to line_map_pred + detected_junc_indexes = candidate_index_map[detection_results, :] + line_map_pred[detected_junc_indexes[:, 0], + detected_junc_indexes[:, 1]] = 1 + line_map_pred[detected_junc_indexes[:, 1], + detected_junc_indexes[:, 0]] = 1 + + # Perform junction refinement + if self.use_junction_refinement and len(detected_junc_indexes) > 0: + junctions, line_map_pred = self.refine_junction_perturb( + junctions, line_map_pred, heatmap, H, W, device) + + return line_map_pred, junctions, heatmap + + def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2): + """ Global heatmap refinement method. """ + # Grab the top 10% values + heatmap_values = heatmap[heatmap > valid_thresh] + sorted_values = torch.sort(heatmap_values, descending=True)[0] + top10_len = math.ceil(sorted_values.shape[0] * ratio) + max20 = torch.mean(sorted_values[:top10_len]) + heatmap = torch.clamp(heatmap / max20, min=0., max=1.) + return heatmap + + def refine_heatmap_local(self, heatmap, num_blocks=5, overlap_ratio=0.5, + ratio=0.2, valid_thresh=2e-3): + """ Local heatmap refinement method. 
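+ The heatmap is split into num_blocks x num_blocks overlapping blocks + (the overlap is controlled by overlap_ratio); every block whose maximum + exceeds valid_thresh is normalized with refine_heatmap(), and the + overlapping contributions are averaged through a count map.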
""" + # Get the shape of the heatmap + H, W = heatmap.shape + increase_ratio = 1 - overlap_ratio + h_block = round(H / (1 + (num_blocks - 1) * increase_ratio)) + w_block = round(W / (1 + (num_blocks - 1) * increase_ratio)) + + count_map = torch.zeros(heatmap.shape, dtype=torch.float, + device=heatmap.device) + heatmap_output = torch.zeros(heatmap.shape, dtype=torch.float, + device=heatmap.device) + # Iterate through each block + for h_idx in range(num_blocks): + for w_idx in range(num_blocks): + # Fetch the heatmap + h_start = round(h_idx * h_block * increase_ratio) + w_start = round(w_idx * w_block * increase_ratio) + h_end = h_start + h_block if h_idx < num_blocks - 1 else H + w_end = w_start + w_block if w_idx < num_blocks - 1 else W + + subheatmap = heatmap[h_start:h_end, w_start:w_end] + if subheatmap.max() > valid_thresh: + subheatmap = self.refine_heatmap( + subheatmap, ratio, valid_thresh=valid_thresh) + + # Aggregate it to the final heatmap + heatmap_output[h_start:h_end, w_start:w_end] += subheatmap + count_map[h_start:h_end, w_start:w_end] += 1 + heatmap_output = torch.clamp(heatmap_output / count_map, + max=1., min=0.) + + return heatmap_output + + def candidate_suppression(self, junctions, candidate_map): + """ Suppress overlapping long lines in the candidate segments. """ + # Define the distance tolerance + dist_tolerance = self.nms_dist_tolerance + + # Compute distance between junction pairs + # (num_junc x 1 x 2) - (1 x num_junc x 2) => num_junc x num_junc map + line_dist_map = torch.sum((torch.unsqueeze(junctions, dim=1) + - junctions[None, ...]) ** 2, dim=-1) ** 0.5 + + # Fetch all the "detected lines" + seg_indexes = torch.where(torch.triu(candidate_map, diagonal=1)) + start_point_idxs = seg_indexes[0] + end_point_idxs = seg_indexes[1] + start_points = junctions[start_point_idxs, :] + end_points = junctions[end_point_idxs, :] + + # Fetch corresponding entries + line_dists = line_dist_map[start_point_idxs, end_point_idxs] + + # Check whether they are on the line + dir_vecs = ((end_points - start_points) + / torch.norm(end_points - start_points, + dim=-1)[..., None]) + # Get the orthogonal distance + cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1) + cand_vecs_norm = torch.norm(cand_vecs, dim=-1) + # Check whether they are projected directly onto the segment + proj = (torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None]) + / line_dists[..., None, None]) + # proj is num_segs x num_junction x 1 + proj_mask = (proj >=0) * (proj <= 1) + cand_angles = torch.acos( + torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None]) + / cand_vecs_norm[..., None]) + cand_dists = cand_vecs_norm[..., None] * torch.sin(cand_angles) + junc_dist_mask = cand_dists <= dist_tolerance + junc_mask = junc_dist_mask * proj_mask + + # Minus starting points + num_segs = start_point_idxs.shape[0] + junc_counts = torch.sum(junc_mask, dim=[1, 2]) + junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs), + start_point_idxs].to(torch.int) + junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs), + end_point_idxs].to(torch.int) + + # Get the invalid candidate mask + final_mask = junc_counts > 0 + candidate_map[start_point_idxs[final_mask], + end_point_idxs[final_mask]] = 0 + + return candidate_map + + def refine_junction_perturb(self, junctions, line_map_pred, + heatmap, H, W, device): + """ Refine the line endpoints in a similar way as in LSD. 
""" + # Get the config + junction_refine_cfg = self.junction_refine_cfg + + # Fetch refinement parameters + num_perturbs = junction_refine_cfg["num_perturbs"] + perturb_interval = junction_refine_cfg["perturb_interval"] + side_perturbs = (num_perturbs - 1) // 2 + # Fetch the 2D perturb mat + perturb_vec = torch.arange( + start=-perturb_interval*side_perturbs, + end=perturb_interval*(side_perturbs+1), + step=perturb_interval, device=device) + w1_grid, h1_grid, w2_grid, h2_grid = torch.meshgrid( + perturb_vec, perturb_vec, perturb_vec, perturb_vec) + perturb_tensor = torch.cat([ + w1_grid[..., None], h1_grid[..., None], + w2_grid[..., None], h2_grid[..., None]], dim=-1) + perturb_tensor_flat = perturb_tensor.view(-1, 2, 2) + + # Fetch the junctions and line_map + junctions = junctions.clone() + line_map = line_map_pred + + # Fetch all the detected lines + detected_seg_indexes = torch.where(torch.triu(line_map, diagonal=1)) + start_point_idxs = detected_seg_indexes[0] + end_point_idxs = detected_seg_indexes[1] + start_points = junctions[start_point_idxs, :] + end_points = junctions[end_point_idxs, :] + + line_segments = torch.cat([start_points.unsqueeze(dim=1), + end_points.unsqueeze(dim=1)], dim=1) + + line_segment_candidates = (line_segments.unsqueeze(dim=1) + + perturb_tensor_flat[None, ...]) + # Clip the boundaries + line_segment_candidates[..., 0] = torch.clamp( + line_segment_candidates[..., 0], min=0, max=H - 1) + line_segment_candidates[..., 1] = torch.clamp( + line_segment_candidates[..., 1], min=0, max=W - 1) + + # Iterate through all the segments + refined_segment_lst = [] + num_segments = line_segments.shape[0] + for idx in range(num_segments): + segment = line_segment_candidates[idx, ...] + # Get the corresponding start and end junctions + candidate_junc_start = segment[:, 0, :] + candidate_junc_end = segment[:, 1, :] + + # Get the sampling locations (N x 64) + sampler = self.torch_sampler.to(device)[None, ...] + cand_samples_h = (candidate_junc_start[:, 0:1] * sampler + + candidate_junc_end[:, 0:1] * (1 - sampler)) + cand_samples_w = (candidate_junc_start[:, 1:2] * sampler + + candidate_junc_end[:, 1:2] * (1 - sampler)) + + # Clip to image boundary + cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1) + cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1) + + # Perform bilinear sampling + segment_feat = self.detect_bilinear( + heatmap, cand_h, cand_w, H, W, device) + segment_results = torch.mean(segment_feat, dim=-1) + max_idx = torch.argmax(segment_results) + refined_segment_lst.append(segment[max_idx, ...][None, ...]) + + # Concatenate back to segments + refined_segments = torch.cat(refined_segment_lst, dim=0) + + # Convert back to junctions and line_map + junctions_new = torch.cat( + [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0) + junctions_new = torch.unique(junctions_new, dim=0) + line_map_new = self.segments_to_line_map(junctions_new, + refined_segments) + + return junctions_new, line_map_new + + def segments_to_line_map(self, junctions, segments): + """ Convert the list of segments to line map. """ + # Create empty line map + device = junctions.device + num_junctions = junctions.shape[0] + line_map = torch.zeros([num_junctions, num_junctions], device=device) + + # Iterate through every segment + for idx in range(segments.shape[0]): + # Get the junctions from a single segement + seg = segments[idx, ...] 
+ junction1 = seg[0, :] + junction2 = seg[1, :] + + # Get index + idx_junction1 = torch.where( + (junctions == junction1).sum(axis=1) == 2)[0] + idx_junction2 = torch.where( + (junctions == junction2).sum(axis=1) == 2)[0] + + # label the corresponding entries + line_map[idx_junction1, idx_junction2] = 1 + line_map[idx_junction2, idx_junction1] = 1 + + return line_map + + def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device): + """ Detection by bilinear sampling. """ + # Get the floor and ceiling locations + cand_h_floor = torch.floor(cand_h).to(torch.long) + cand_h_ceil = torch.ceil(cand_h).to(torch.long) + cand_w_floor = torch.floor(cand_w).to(torch.long) + cand_w_ceil = torch.ceil(cand_w).to(torch.long) + + # Perform the bilinear sampling + cand_samples_feat = ( + heatmap[cand_h_floor, cand_w_floor] * (cand_h_ceil - cand_h) + * (cand_w_ceil - cand_w) + heatmap[cand_h_floor, cand_w_ceil] + * (cand_h_ceil - cand_h) * (cand_w - cand_w_floor) + + heatmap[cand_h_ceil, cand_w_floor] * (cand_h - cand_h_floor) + * (cand_w_ceil - cand_w) + heatmap[cand_h_ceil, cand_w_ceil] + * (cand_h - cand_h_floor) * (cand_w - cand_w_floor)) + + return cand_samples_feat + + def detect_local_max(self, heatmap, cand_h, cand_w, H, W, + normalized_seg_length, device): + """ Detection by local maximum search. """ + # Compute the distance threshold + dist_thresh = (0.5 * (2 ** 0.5) + + self.lambda_radius * normalized_seg_length) + # Make it N x 64 + dist_thresh = torch.repeat_interleave(dist_thresh[..., None], + self.num_samples, dim=-1) + + # Compute the candidate points + cand_points = torch.cat([cand_h[..., None], cand_w[..., None]], + dim=-1) + cand_points_round = torch.round(cand_points) # N x 64 x 2 + + # Construct local patches 9x9 = 81 + patch_mask = torch.zeros([int(2 * self.local_patch_radius + 1), + int(2 * self.local_patch_radius + 1)], + device=device) + patch_center = torch.tensor( + [[self.local_patch_radius, self.local_patch_radius]], + device=device, dtype=torch.float32) + H_patch_points, W_patch_points = torch.where(patch_mask >= 0) + patch_points = torch.cat([H_patch_points[..., None], + W_patch_points[..., None]], dim=-1) + # Fetch the circle region + patch_center_dist = torch.sqrt(torch.sum( + (patch_points - patch_center) ** 2, dim=-1)) + patch_points = (patch_points[patch_center_dist + <= self.local_patch_radius, :]) + # Shift [0, 0] to the center + patch_points = patch_points - self.local_patch_radius + + # Construct local patch mask + patch_points_shifted = (torch.unsqueeze(cand_points_round, dim=2) + + patch_points[None, None, ...]) + patch_dist = torch.sqrt(torch.sum((torch.unsqueeze(cand_points, dim=2) + - patch_points_shifted) ** 2, + dim=-1)) + patch_dist_mask = patch_dist < dist_thresh[..., None] + + # Get all points => num_points_center x num_patch_points x 2 + points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0, + max=H - 1).to(torch.long) + points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0, + max=W - 1).to(torch.long) + points = torch.cat([points_H[..., None], points_W[..., None]], dim=-1) + + # Sample the feature (N x 64 x 81) + sampled_feat = heatmap[points[:, :, :, 0], points[:, :, :, 1]] + # Filtering using the valid mask + sampled_feat = sampled_feat * patch_dist_mask.to(torch.float32) + if len(sampled_feat) == 0: + sampled_feat_lmax = torch.empty(0, 64) + else: + sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1) + + return sampled_feat_lmax diff --git a/imcui/third_party/SOLD2/sold2/model/line_detector.py 
b/imcui/third_party/SOLD2/sold2/model/line_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..2f3d059e130178c482e8e569171ef9e0370424c7 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/line_detector.py @@ -0,0 +1,127 @@ +""" +Line segment detection from raw images. +""" +import time +import numpy as np +import torch +from torch.nn.functional import softmax + +from .model_util import get_model +from .loss import get_loss_and_weights +from .line_detection import LineSegmentDetectionModule +from ..train import convert_junc_predictions +from ..misc.train_utils import adapt_checkpoint + + +def line_map_to_segments(junctions, line_map): + """ Convert a line map to a Nx2x2 list of segments. """ + line_map_tmp = line_map.copy() + + output_segments = np.zeros([0, 2, 2]) + for idx in range(junctions.shape[0]): + # if no connectivity, just skip it + if line_map_tmp[idx, :].sum() == 0: + continue + # Record the line segment + else: + for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]: + p1 = junctions[idx, :] # HW format + p2 = junctions[idx2, :] + single_seg = np.concatenate([p1[None, ...], p2[None, ...]], + axis=0) + output_segments = np.concatenate( + (output_segments, single_seg[None, ...]), axis=0) + + # Update line_map + line_map_tmp[idx, idx2] = 0 + line_map_tmp[idx2, idx] = 0 + + return output_segments + + +class LineDetector(object): + def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg, + junc_detect_thresh=None): + """ SOLD² line detector taking raw images as input. + Parameters: + model_cfg: config for CNN model + ckpt_path: path to the weights + line_detector_cfg: config file for the line detection module + """ + # Get loss weights if dynamic weighting + _, loss_weights = get_loss_and_weights(model_cfg, device) + self.device = device + + # Initialize the cnn backbone + self.model = get_model(model_cfg, loss_weights) + checkpoint = torch.load(ckpt_path, map_location=self.device) + checkpoint = adapt_checkpoint(checkpoint["model_state_dict"]) + self.model.load_state_dict(checkpoint) + self.model = self.model.to(self.device) + self.model = self.model.eval() + + self.grid_size = model_cfg["grid_size"] + + if junc_detect_thresh is not None: + self.junc_detect_thresh = junc_detect_thresh + else: + self.junc_detect_thresh = model_cfg.get("detection_thresh", 1/65) + self.max_num_junctions = model_cfg.get("max_num_junctions", 300) + + # Initialize the line detector + self.line_detector_cfg = line_detector_cfg + self.line_detector = LineSegmentDetectionModule(**line_detector_cfg) + + def __call__(self, input_image, valid_mask=None, + return_heatmap=False, profile=False): + # Now we restrict input_image to 4D torch tensor + if ((not len(input_image.shape) == 4) + or (not isinstance(input_image, torch.Tensor))): + raise ValueError( + "[Error] the input image should be a 4D torch tensor.") + + # Move the input to corresponding device + input_image = input_image.to(self.device) + + # Forward of the CNN backbone + start_time = time.time() + with torch.no_grad(): + net_outputs = self.model(input_image) + + junc_np = convert_junc_predictions( + net_outputs["junctions"], self.grid_size, + self.junc_detect_thresh, self.max_num_junctions) + if valid_mask is None: + junctions = np.where(junc_np["junc_pred_nms"].squeeze()) + else: + junctions = np.where(junc_np["junc_pred_nms"].squeeze() + * valid_mask) + junctions = np.concatenate( + [junctions[0][..., None], junctions[1][..., None]], axis=-1) + + if net_outputs["heatmap"].shape[1] == 2: + # Convert to single 
channel directly from here + heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :] + else: + heatmap = torch.sigmoid(net_outputs["heatmap"]) + heatmap = heatmap.cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, 0] + + # Run the line detector. + line_map, junctions, heatmap = self.line_detector.detect( + junctions, heatmap, device=self.device) + heatmap = heatmap.cpu().numpy() + if isinstance(line_map, torch.Tensor): + line_map = line_map.cpu().numpy() + if isinstance(junctions, torch.Tensor): + junctions = junctions.cpu().numpy() + line_segments = line_map_to_segments(junctions, line_map) + end_time = time.time() + + outputs = {"line_segments": line_segments} + + if return_heatmap: + outputs["heatmap"] = heatmap + if profile: + outputs["time"] = end_time - start_time + + return outputs diff --git a/imcui/third_party/SOLD2/sold2/model/line_matcher.py b/imcui/third_party/SOLD2/sold2/model/line_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5a003573c91313e2295c75871edcb1c113662a --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/line_matcher.py @@ -0,0 +1,279 @@ +""" +Implements the full pipeline from raw images to line matches. +""" +import time +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.functional import softmax + +from .model_util import get_model +from .loss import get_loss_and_weights +from .metrics import super_nms +from .line_detection import LineSegmentDetectionModule +from .line_matching import WunschLineMatcher +from ..train import convert_junc_predictions +from ..misc.train_utils import adapt_checkpoint +from .line_detector import line_map_to_segments + + +class LineMatcher(object): + """ Full line matcher including line detection and matching + with the Needleman-Wunsch algorithm. 
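+ Typical usage (a sketch; the config dictionaries, checkpoint path and + image tensors are placeholders, and the images are expected as 4D torch + tensors): + matcher = LineMatcher(model_cfg, ckpt_path, device, + line_detector_cfg, line_matcher_cfg) + outputs = matcher([img0, img1]) + line_seg0, line_seg1 = outputs["line_segments"] + matches = outputs["matches"]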
""" + def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg, + line_matcher_cfg, multiscale=False, scales=[1., 2.]): + # Get loss weights if dynamic weighting + _, loss_weights = get_loss_and_weights(model_cfg, device) + self.device = device + + # Initialize the cnn backbone + self.model = get_model(model_cfg, loss_weights) + checkpoint = torch.load(ckpt_path, map_location=self.device) + checkpoint = adapt_checkpoint(checkpoint["model_state_dict"]) + self.model.load_state_dict(checkpoint) + self.model = self.model.to(self.device) + self.model = self.model.eval() + + self.grid_size = model_cfg["grid_size"] + self.junc_detect_thresh = model_cfg["detection_thresh"] + self.max_num_junctions = model_cfg.get("max_num_junctions", 300) + + # Initialize the line detector + self.line_detector = LineSegmentDetectionModule(**line_detector_cfg) + self.multiscale = multiscale + self.scales = scales + + # Initialize the line matcher + self.line_matcher = WunschLineMatcher(**line_matcher_cfg) + + # Print some debug messages + for key, val in line_detector_cfg.items(): + print(f"[Debug] {key}: {val}") + # print("[Debug] detect_thresh: %f" % (line_detector_cfg["detect_thresh"])) + # print("[Debug] num_samples: %d" % (line_detector_cfg["num_samples"])) + + + + # Perform line detection and descriptor inference on a single image + def line_detection(self, input_image, valid_mask=None, + desc_only=False, profile=False): + # Restrict input_image to 4D torch tensor + if ((not len(input_image.shape) == 4) + or (not isinstance(input_image, torch.Tensor))): + raise ValueError( + "[Error] the input image should be a 4D torch tensor") + + # Move the input to corresponding device + input_image = input_image.to(self.device) + + # Forward of the CNN backbone + start_time = time.time() + with torch.no_grad(): + net_outputs = self.model(input_image) + + outputs = {"descriptor": net_outputs["descriptors"]} + + if not desc_only: + junc_np = convert_junc_predictions( + net_outputs["junctions"], self.grid_size, + self.junc_detect_thresh, self.max_num_junctions) + if valid_mask is None: + junctions = np.where(junc_np["junc_pred_nms"].squeeze()) + else: + junctions = np.where( + junc_np["junc_pred_nms"].squeeze() * valid_mask) + junctions = np.concatenate([junctions[0][..., None], + junctions[1][..., None]], axis=-1) + + if net_outputs["heatmap"].shape[1] == 2: + # Convert to single channel directly from here + heatmap = softmax( + net_outputs["heatmap"], + dim=1)[:, 1:, :, :].cpu().numpy().transpose(0, 2, 3, 1) + else: + heatmap = torch.sigmoid( + net_outputs["heatmap"]).cpu().numpy().transpose(0, 2, 3, 1) + heatmap = heatmap[0, :, :, 0] + + # Run the line detector. 
+ line_map, junctions, heatmap = self.line_detector.detect( + junctions, heatmap, device=self.device) + if isinstance(line_map, torch.Tensor): + line_map = line_map.cpu().numpy() + if isinstance(junctions, torch.Tensor): + junctions = junctions.cpu().numpy() + outputs["heatmap"] = heatmap.cpu().numpy() + outputs["junctions"] = junctions + + # If it's a line map with multiple detect_thresh and inlier_thresh + if len(line_map.shape) > 2: + num_detect_thresh = line_map.shape[0] + num_inlier_thresh = line_map.shape[1] + line_segments = [] + for detect_idx in range(num_detect_thresh): + line_segments_inlier = [] + for inlier_idx in range(num_inlier_thresh): + line_map_tmp = line_map[detect_idx, inlier_idx, :, :] + line_segments_tmp = line_map_to_segments(junctions, line_map_tmp) + line_segments_inlier.append(line_segments_tmp) + line_segments.append(line_segments_inlier) + else: + line_segments = line_map_to_segments(junctions, line_map) + + outputs["line_segments"] = line_segments + + end_time = time.time() + + if profile: + outputs["time"] = end_time - start_time + + return outputs + + # Perform line detection and descriptor inference at multiple scales + def multiscale_line_detection(self, input_image, valid_mask=None, + desc_only=False, profile=False, + scales=[1., 2.], aggregation='mean'): + # Restrict input_image to 4D torch tensor + if ((not len(input_image.shape) == 4) + or (not isinstance(input_image, torch.Tensor))): + raise ValueError( + "[Error] the input image should be a 4D torch tensor") + + # Move the input to corresponding device + input_image = input_image.to(self.device) + img_size = input_image.shape[2:4] + desc_size = tuple(np.array(img_size) // 4) + + # Run the inference at multiple image scales + start_time = time.time() + junctions, heatmaps, descriptors = [], [], [] + for s in scales: + # Resize the image + resized_img = F.interpolate(input_image, scale_factor=s, + mode='bilinear') + + # Forward of the CNN backbone + with torch.no_grad(): + net_outputs = self.model(resized_img) + + descriptors.append(F.interpolate( + net_outputs["descriptors"], size=desc_size, mode="bilinear")) + + if not desc_only: + junc_prob = convert_junc_predictions( + net_outputs["junctions"], self.grid_size)["junc_pred"] + junctions.append(cv2.resize(junc_prob.squeeze(), + (img_size[1], img_size[0]), + interpolation=cv2.INTER_LINEAR)) + + if net_outputs["heatmap"].shape[1] == 2: + # Convert to single channel directly from here + heatmap = softmax(net_outputs["heatmap"], + dim=1)[:, 1:, :, :] + else: + heatmap = torch.sigmoid(net_outputs["heatmap"]) + heatmaps.append(F.interpolate(heatmap, size=img_size, + mode="bilinear")) + + # Aggregate the results + if aggregation == 'mean': + # Aggregation through the mean activation + descriptors = torch.stack(descriptors, dim=0).mean(0) + else: + # Aggregation through the max activation + descriptors = torch.stack(descriptors, dim=0).max(0)[0] + outputs = {"descriptor": descriptors} + + if not desc_only: + if aggregation == 'mean': + junctions = np.stack(junctions, axis=0).mean(0)[None] + heatmap = torch.stack(heatmaps, dim=0).mean(0)[0, 0, :, :] + heatmap = heatmap.cpu().numpy() + else: + junctions = np.stack(junctions, axis=0).max(0)[None] + heatmap = torch.stack(heatmaps, dim=0).max(0)[0][0, 0, :, :] + heatmap = heatmap.cpu().numpy() + + # Extract junctions + junc_pred_nms = super_nms( + junctions[..., None], self.grid_size, + self.junc_detect_thresh, self.max_num_junctions) + if valid_mask is None: + junctions = np.where(junc_pred_nms.squeeze()) + else: 
+ junctions = np.where(junc_pred_nms.squeeze() * valid_mask) + junctions = np.concatenate([junctions[0][..., None], + junctions[1][..., None]], axis=-1) + + # Run the line detector. + line_map, junctions, heatmap = self.line_detector.detect( + junctions, heatmap, device=self.device) + if isinstance(line_map, torch.Tensor): + line_map = line_map.cpu().numpy() + if isinstance(junctions, torch.Tensor): + junctions = junctions.cpu().numpy() + outputs["heatmap"] = heatmap.cpu().numpy() + outputs["junctions"] = junctions + + # If it's a line map with multiple detect_thresh and inlier_thresh + if len(line_map.shape) > 2: + num_detect_thresh = line_map.shape[0] + num_inlier_thresh = line_map.shape[1] + line_segments = [] + for detect_idx in range(num_detect_thresh): + line_segments_inlier = [] + for inlier_idx in range(num_inlier_thresh): + line_map_tmp = line_map[detect_idx, inlier_idx, :, :] + line_segments_tmp = line_map_to_segments( + junctions, line_map_tmp) + line_segments_inlier.append(line_segments_tmp) + line_segments.append(line_segments_inlier) + else: + line_segments = line_map_to_segments(junctions, line_map) + + outputs["line_segments"] = line_segments + + end_time = time.time() + + if profile: + outputs["time"] = end_time - start_time + + return outputs + + def __call__(self, images, valid_masks=[None, None], profile=False): + # Line detection and descriptor inference on both images + if self.multiscale: + forward_outputs = [ + self.multiscale_line_detection( + images[0], valid_masks[0], profile=profile, + scales=self.scales), + self.multiscale_line_detection( + images[1], valid_masks[1], profile=profile, + scales=self.scales)] + else: + forward_outputs = [ + self.line_detection(images[0], valid_masks[0], + profile=profile), + self.line_detection(images[1], valid_masks[1], + profile=profile)] + line_seg1 = forward_outputs[0]["line_segments"] + line_seg2 = forward_outputs[1]["line_segments"] + desc1 = forward_outputs[0]["descriptor"] + desc2 = forward_outputs[1]["descriptor"] + + # Match the lines in both images + start_time = time.time() + matches = self.line_matcher.forward(line_seg1, line_seg2, + desc1, desc2) + end_time = time.time() + + outputs = {"line_segments": [line_seg1, line_seg2], + "matches": matches} + + if profile: + outputs["line_detection_time"] = (forward_outputs[0]["time"] + + forward_outputs[1]["time"]) + outputs["line_matching_time"] = end_time - start_time + + return outputs diff --git a/imcui/third_party/SOLD2/sold2/model/line_matching.py b/imcui/third_party/SOLD2/sold2/model/line_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..89b71879e3104f9a8b52c1cf5e534cd124fe83b2 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/line_matching.py @@ -0,0 +1,390 @@ +""" +Implementation of the line matching methods. +""" +import numpy as np +import cv2 +import torch +import torch.nn.functional as F + +from ..misc.geometry_utils import keypoints_to_grid + + +class WunschLineMatcher(object): + """ Class matching two sets of line segments + with the Needleman-Wunsch algorithm. 
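+ forward() takes the two sets of segments (N1 x 2 x 2 and N2 x 2 x 2) + together with their dense descriptor maps and returns, for every segment + of the first image, the index of its match in the second image, or -1 + when no (mutual) match is found.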
""" + def __init__(self, cross_check=True, num_samples=10, min_dist_pts=8, + top_k_candidates=10, grid_size=8, sampling="regular", + line_score=False): + self.cross_check = cross_check + self.num_samples = num_samples + self.min_dist_pts = min_dist_pts + self.top_k_candidates = top_k_candidates + self.grid_size = grid_size + self.line_score = line_score # True to compute saliency on a line + self.sampling_mode = sampling + if sampling not in ["regular", "d2_net", "asl_feat"]: + raise ValueError("Wrong sampling mode: " + sampling) + + def forward(self, line_seg1, line_seg2, desc1, desc2): + """ + Find the best matches between two sets of line segments + and their corresponding descriptors. + """ + img_size1 = (desc1.shape[2] * self.grid_size, + desc1.shape[3] * self.grid_size) + img_size2 = (desc2.shape[2] * self.grid_size, + desc2.shape[3] * self.grid_size) + device = desc1.device + + # Default case when an image has no lines + if len(line_seg1) == 0: + return np.empty((0), dtype=int) + if len(line_seg2) == 0: + return -np.ones(len(line_seg1), dtype=int) + + # Sample points regularly along each line + if self.sampling_mode == "regular": + line_points1, valid_points1 = self.sample_line_points(line_seg1) + line_points2, valid_points2 = self.sample_line_points(line_seg2) + else: + line_points1, valid_points1 = self.sample_salient_points( + line_seg1, desc1, img_size1, self.sampling_mode) + line_points2, valid_points2 = self.sample_salient_points( + line_seg2, desc2, img_size2, self.sampling_mode) + line_points1 = torch.tensor(line_points1.reshape(-1, 2), + dtype=torch.float, device=device) + line_points2 = torch.tensor(line_points2.reshape(-1, 2), + dtype=torch.float, device=device) + + # Extract the descriptors for each point + grid1 = keypoints_to_grid(line_points1, img_size1) + grid2 = keypoints_to_grid(line_points2, img_size2) + desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0) + desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0) + + # Precompute the distance between line points for every pair of lines + # Assign a score of -1 for unvalid points + scores = desc1.t() @ desc2 + scores[~valid_points1.flatten()] = -1 + scores[:, ~valid_points2.flatten()] = -1 + scores = scores.reshape(len(line_seg1), self.num_samples, + len(line_seg2), self.num_samples) + scores = scores.permute(0, 2, 1, 3) + # scores.shape = (n_lines1, n_lines2, num_samples, num_samples) + + # Pre-filter the line candidates and find the best match for each line + matches = self.filter_and_match_lines(scores) + + # [Optionally] filter matches with mutual nearest neighbor filtering + if self.cross_check: + matches2 = self.filter_and_match_lines( + scores.permute(1, 0, 3, 2)) + mutual = matches2[matches] == np.arange(len(line_seg1)) + matches[~mutual] = -1 + + return matches + + def d2_net_saliency_score(self, desc): + """ Compute the D2-Net saliency score + on a 3D or 4D descriptor. 
""" + is_3d = len(desc.shape) == 3 + b_size = len(desc) + feat = F.relu(desc) + + # Compute the soft local max + exp = torch.exp(feat) + if is_3d: + sum_exp = 3 * F.avg_pool1d(exp, kernel_size=3, stride=1, + padding=1) + else: + sum_exp = 9 * F.avg_pool2d(exp, kernel_size=3, stride=1, + padding=1) + soft_local_max = exp / sum_exp + + # Compute the depth-wise maximum + depth_wise_max = torch.max(feat, dim=1)[0] + depth_wise_max = feat / depth_wise_max.unsqueeze(1) + + # Total saliency score + score = torch.max(soft_local_max * depth_wise_max, dim=1)[0] + normalization = torch.sum(score.reshape(b_size, -1), dim=1) + if is_3d: + normalization = normalization.reshape(b_size, 1) + else: + normalization = normalization.reshape(b_size, 1, 1) + score = score / normalization + return score + + def asl_feat_saliency_score(self, desc): + """ Compute the ASLFeat saliency score on a 3D or 4D descriptor. """ + is_3d = len(desc.shape) == 3 + b_size = len(desc) + + # Compute the soft local peakiness + if is_3d: + local_avg = F.avg_pool1d(desc, kernel_size=3, stride=1, padding=1) + else: + local_avg = F.avg_pool2d(desc, kernel_size=3, stride=1, padding=1) + soft_local_score = F.softplus(desc - local_avg) + + # Compute the depth-wise peakiness + depth_wise_mean = torch.mean(desc, dim=1).unsqueeze(1) + depth_wise_score = F.softplus(desc - depth_wise_mean) + + # Total saliency score + score = torch.max(soft_local_score * depth_wise_score, dim=1)[0] + normalization = torch.sum(score.reshape(b_size, -1), dim=1) + if is_3d: + normalization = normalization.reshape(b_size, 1) + else: + normalization = normalization.reshape(b_size, 1, 1) + score = score / normalization + return score + + def sample_salient_points(self, line_seg, desc, img_size, + saliency_type='d2_net'): + """ + Sample the most salient points along each line segments, with a + minimal distance between each point. Pad the remaining points. + Inputs: + line_seg: an Nx2x2 torch.Tensor. + desc: a NxDxHxW torch.Tensor. + image_size: the original image size. + saliency_type: 'd2_net' or 'asl_feat'. + Outputs: + line_points: an Nxnum_samplesx2 np.array. + valid_points: a boolean Nxnum_samples np.array. 
+ """ + device = desc.device + if not self.line_score: + # Compute the score map + if saliency_type == "d2_net": + score = self.d2_net_saliency_score(desc) + else: + score = self.asl_feat_saliency_score(desc) + + num_lines = len(line_seg) + line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1) + + # The number of samples depends on the length of the line + num_samples_lst = np.clip(line_lengths // self.min_dist_pts, + 2, self.num_samples) + line_points = np.empty((num_lines, self.num_samples, 2), dtype=float) + valid_points = np.empty((num_lines, self.num_samples), dtype=bool) + + # Sample the score on a fixed number of points of each line + n_samples_per_region = 4 + for n in np.arange(2, self.num_samples + 1): + sample_rate = n * n_samples_per_region + # Consider all lines where we can fit up to n points + cur_mask = num_samples_lst == n + cur_line_seg = line_seg[cur_mask] + cur_num_lines = len(cur_line_seg) + if cur_num_lines == 0: + continue + line_points_x = np.linspace(cur_line_seg[:, 0, 0], + cur_line_seg[:, 1, 0], + sample_rate, axis=-1) + line_points_y = np.linspace(cur_line_seg[:, 0, 1], + cur_line_seg[:, 1, 1], + sample_rate, axis=-1) + cur_line_points = np.stack([line_points_x, line_points_y], + axis=-1).reshape(-1, 2) + # cur_line_points is of shape (n_cur_lines * sample_rate, 2) + cur_line_points = torch.tensor(cur_line_points, dtype=torch.float, + device=device) + grid_points = keypoints_to_grid(cur_line_points, img_size) + + if self.line_score: + # The saliency score is high when the activation are locally + # maximal along the line (and not in a square neigborhood) + line_desc = F.grid_sample(desc, grid_points).squeeze() + line_desc = line_desc.reshape(-1, cur_num_lines, sample_rate) + line_desc = line_desc.permute(1, 0, 2) + if saliency_type == "d2_net": + scores = self.d2_net_saliency_score(line_desc) + else: + scores = self.asl_feat_saliency_score(line_desc) + else: + scores = F.grid_sample(score.unsqueeze(1), + grid_points).squeeze() + + # Take the most salient point in n distinct regions + scores = scores.reshape(-1, n, n_samples_per_region) + best = torch.max(scores, dim=2, keepdim=True)[1].cpu().numpy() + cur_line_points = cur_line_points.reshape(-1, n, + n_samples_per_region, 2) + cur_line_points = np.take_along_axis( + cur_line_points, best[..., None], axis=2)[:, :, 0] + + # Pad + cur_valid_points = np.ones((cur_num_lines, self.num_samples), + dtype=bool) + cur_valid_points[:, n:] = False + cur_line_points = np.concatenate([ + cur_line_points, + np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)], + axis=1) + + line_points[cur_mask] = cur_line_points + valid_points[cur_mask] = cur_valid_points + + return line_points, valid_points + + def sample_line_points(self, line_seg): + """ + Regularly sample points along each line segments, with a minimal + distance between each point. Pad the remaining points. + Inputs: + line_seg: an Nx2x2 torch.Tensor. + Outputs: + line_points: an Nxnum_samplesx2 np.array. + valid_points: a boolean Nxnum_samples np.array. 
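+ For example, with the default min_dist_pts=8 and num_samples=10, a line + of length 50 pixels gets clip(50 // 8, 2, 10) = 6 regularly spaced + samples; the remaining 4 slots are zero-padded and flagged as invalid in + valid_points.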
+ """ + num_lines = len(line_seg) + line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1) + + # Sample the points separated by at least min_dist_pts along each line + # The number of samples depends on the length of the line + num_samples_lst = np.clip(line_lengths // self.min_dist_pts, + 2, self.num_samples) + line_points = np.empty((num_lines, self.num_samples, 2), dtype=float) + valid_points = np.empty((num_lines, self.num_samples), dtype=bool) + for n in np.arange(2, self.num_samples + 1): + # Consider all lines where we can fit up to n points + cur_mask = num_samples_lst == n + cur_line_seg = line_seg[cur_mask] + line_points_x = np.linspace(cur_line_seg[:, 0, 0], + cur_line_seg[:, 1, 0], + n, axis=-1) + line_points_y = np.linspace(cur_line_seg[:, 0, 1], + cur_line_seg[:, 1, 1], + n, axis=-1) + cur_line_points = np.stack([line_points_x, line_points_y], axis=-1) + + # Pad + cur_num_lines = len(cur_line_seg) + cur_valid_points = np.ones((cur_num_lines, self.num_samples), + dtype=bool) + cur_valid_points[:, n:] = False + cur_line_points = np.concatenate([ + cur_line_points, + np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)], + axis=1) + + line_points[cur_mask] = cur_line_points + valid_points[cur_mask] = cur_valid_points + + return line_points, valid_points + + def filter_and_match_lines(self, scores): + """ + Use the scores to keep the top k best lines, compute the Needleman- + Wunsch algorithm on each candidate pairs, and keep the highest score. + Inputs: + scores: a (N, M, n, n) torch.Tensor containing the pairwise scores + of the elements to match. + Outputs: + matches: a (N) np.array containing the indices of the best match + """ + # Pre-filter the pairs and keep the top k best candidate lines + line_scores1 = scores.max(3)[0] + valid_scores1 = line_scores1 != -1 + line_scores1 = ((line_scores1 * valid_scores1).sum(2) + / valid_scores1.sum(2)) + line_scores2 = scores.max(2)[0] + valid_scores2 = line_scores2 != -1 + line_scores2 = ((line_scores2 * valid_scores2).sum(2) + / valid_scores2.sum(2)) + line_scores = (line_scores1 + line_scores2) / 2 + topk_lines = torch.argsort(line_scores, + dim=1)[:, -self.top_k_candidates:] + scores, topk_lines = scores.cpu().numpy(), topk_lines.cpu().numpy() + # topk_lines.shape = (n_lines1, top_k_candidates) + top_scores = np.take_along_axis(scores, topk_lines[:, :, None, None], + axis=1) + + # Consider the reversed line segments as well + top_scores = np.concatenate([top_scores, top_scores[..., ::-1]], + axis=1) + + # Compute the line distance matrix with Needleman-Wunsch algo and + # retrieve the closest line neighbor + n_lines1, top2k, n, m = top_scores.shape + top_scores = top_scores.reshape(n_lines1 * top2k, n, m) + nw_scores = self.needleman_wunsch(top_scores) + nw_scores = nw_scores.reshape(n_lines1, top2k) + matches = np.mod(np.argmax(nw_scores, axis=1), top2k // 2) + matches = topk_lines[np.arange(n_lines1), matches] + return matches + + def needleman_wunsch(self, scores): + """ + Batched implementation of the Needleman-Wunsch algorithm. + The cost of the InDel operation is set to 0 by subtracting the gap + penalty to the scores. + Inputs: + scores: a (B, N, M) np.array containing the pairwise scores + of the elements to match. 
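+ The dynamic programming recurrence used below is + grid[i+1, j+1] = max(grid[i+1, j], grid[i, j+1], grid[i, j] + score[i, j] - gap), + with gap = 0.1, and the final alignment score is read from grid[-1, -1].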
+ """ + b, n, m = scores.shape + + # Recalibrate the scores to get a gap score of 0 + gap = 0.1 + nw_scores = scores - gap + + # Run the dynamic programming algorithm + nw_grid = np.zeros((b, n + 1, m + 1), dtype=float) + for i in range(n): + for j in range(m): + nw_grid[:, i + 1, j + 1] = np.maximum( + np.maximum(nw_grid[:, i + 1, j], nw_grid[:, i, j + 1]), + nw_grid[:, i, j] + nw_scores[:, i, j]) + + return nw_grid[:, -1, -1] + + def get_pairwise_distance(self, line_seg1, line_seg2, desc1, desc2): + """ + Compute the OPPOSITE of the NW score for pairs of line segments + and their corresponding descriptors. + """ + num_lines = len(line_seg1) + assert num_lines == len(line_seg2), "The same number of lines is required in pairwise score." + img_size1 = (desc1.shape[2] * self.grid_size, + desc1.shape[3] * self.grid_size) + img_size2 = (desc2.shape[2] * self.grid_size, + desc2.shape[3] * self.grid_size) + device = desc1.device + + # Sample points regularly along each line + line_points1, valid_points1 = self.sample_line_points(line_seg1) + line_points2, valid_points2 = self.sample_line_points(line_seg2) + line_points1 = torch.tensor(line_points1.reshape(-1, 2), + dtype=torch.float, device=device) + line_points2 = torch.tensor(line_points2.reshape(-1, 2), + dtype=torch.float, device=device) + + # Extract the descriptors for each point + grid1 = keypoints_to_grid(line_points1, img_size1) + grid2 = keypoints_to_grid(line_points2, img_size2) + desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0) + desc1 = desc1.reshape(-1, num_lines, self.num_samples) + desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0) + desc2 = desc2.reshape(-1, num_lines, self.num_samples) + + # Compute the distance between line points for every pair of lines + # Assign a score of -1 for unvalid points + scores = torch.einsum('dns,dnt->nst', desc1, desc2).cpu().numpy() + scores = scores.reshape(num_lines * self.num_samples, + self.num_samples) + scores[~valid_points1.flatten()] = -1 + scores = scores.reshape(num_lines, self.num_samples, self.num_samples) + scores = scores.transpose(1, 0, 2).reshape(self.num_samples, -1) + scores[:, ~valid_points2.flatten()] = -1 + scores = scores.reshape(self.num_samples, num_lines, self.num_samples) + scores = scores.transpose(1, 0, 2) + # scores.shape = (num_lines, num_samples, num_samples) + + # Compute the NW score for each pair of lines + pairwise_scores = np.array([self.needleman_wunsch(s) for s in scores]) + return -pairwise_scores diff --git a/imcui/third_party/SOLD2/sold2/model/loss.py b/imcui/third_party/SOLD2/sold2/model/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..aaad3c67f3fd59db308869901f8a56623901e318 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/loss.py @@ -0,0 +1,445 @@ +""" +Loss function implementations. +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from kornia.geometry import warp_perspective + +from ..misc.geometry_utils import (keypoints_to_grid, get_dist_mask, + get_common_line_mask) + + +def get_loss_and_weights(model_cfg, device=torch.device("cuda")): + """ Get loss functions and either static or dynamic weighting. 
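+ Returns a pair (loss_func, loss_weight): loss_func maps "junc_loss", + "heatmap_loss" and (optionally) "descriptor_loss" to loss modules, while + loss_weight maps "w_junc", "w_heatmap" and (optionally) "w_desc" to fixed + tensors or trainable nn.Parameters, depending on the weighting policy.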
""" + # Get the global weighting policy + w_policy = model_cfg.get("weighting_policy", "static") + if not w_policy in ["static", "dynamic"]: + raise ValueError("[Error] Not supported weighting policy.") + + loss_func = {} + loss_weight = {} + # Get junction loss function and weight + w_junc, junc_loss_func = get_junction_loss_and_weight(model_cfg, w_policy) + loss_func["junc_loss"] = junc_loss_func.to(device) + loss_weight["w_junc"] = w_junc + + # Get heatmap loss function and weight + w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight( + model_cfg, w_policy, device) + loss_func["heatmap_loss"] = heatmap_loss_func.to(device) + loss_weight["w_heatmap"] = w_heatmap + + # [Optionally] get descriptor loss function and weight + if model_cfg.get("descriptor_loss_func", None) is not None: + w_descriptor, descriptor_loss_func = get_descriptor_loss_and_weight( + model_cfg, w_policy) + loss_func["descriptor_loss"] = descriptor_loss_func.to(device) + loss_weight["w_desc"] = w_descriptor + + return loss_func, loss_weight + + +def get_junction_loss_and_weight(model_cfg, global_w_policy): + """ Get the junction loss function and weight. """ + junction_loss_cfg = model_cfg.get("junction_loss_cfg", {}) + + # Get the junction loss weight + w_policy = junction_loss_cfg.get("policy", global_w_policy) + if w_policy == "static": + w_junc = torch.tensor(model_cfg["w_junc"], dtype=torch.float32) + elif w_policy == "dynamic": + w_junc = nn.Parameter( + torch.tensor(model_cfg["w_junc"], dtype=torch.float32), + requires_grad=True) + else: + raise ValueError( + "[Error] Unknown weighting policy for junction loss weight.") + + # Get the junction loss function + junc_loss_name = model_cfg.get("junction_loss_func", "superpoint") + if junc_loss_name == "superpoint": + junc_loss_func = JunctionDetectionLoss(model_cfg["grid_size"], + model_cfg["keep_border_valid"]) + else: + raise ValueError("[Error] Not supported junction loss function.") + + return w_junc, junc_loss_func + + +def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device): + """ Get the heatmap loss function and weight. """ + heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {}) + + # Get the heatmap loss weight + w_policy = heatmap_loss_cfg.get("policy", global_w_policy) + if w_policy == "static": + w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32) + elif w_policy == "dynamic": + w_heatmap = nn.Parameter( + torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32), + requires_grad=True) + else: + raise ValueError( + "[Error] Unknown weighting policy for junction loss weight.") + + # Get the corresponding heatmap loss based on the config + heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy") + if heatmap_loss_name == "cross_entropy": + # Get the heatmap class weight (always static) + heatmap_class_w = model_cfg.get("w_heatmap_class", 1.) + class_weight = torch.tensor( + np.array([1., heatmap_class_w])).to(torch.float).to(device) + heatmap_loss_func = HeatmapLoss(class_weight=class_weight) + else: + raise ValueError("[Error] Not supported heatmap loss function.") + + return w_heatmap, heatmap_loss_func + + +def get_descriptor_loss_and_weight(model_cfg, global_w_policy): + """ Get the descriptor loss function and weight. 
""" + descriptor_loss_cfg = model_cfg.get("descriptor_loss_cfg", {}) + + # Get the descriptor loss weight + w_policy = descriptor_loss_cfg.get("policy", global_w_policy) + if w_policy == "static": + w_descriptor = torch.tensor(model_cfg["w_desc"], dtype=torch.float32) + elif w_policy == "dynamic": + w_descriptor = nn.Parameter(torch.tensor(model_cfg["w_desc"], + dtype=torch.float32), requires_grad=True) + else: + raise ValueError( + "[Error] Unknown weighting policy for descriptor loss weight.") + + # Get the descriptor loss function + descriptor_loss_name = model_cfg.get("descriptor_loss_func", + "regular_sampling") + if descriptor_loss_name == "regular_sampling": + descriptor_loss_func = TripletDescriptorLoss( + descriptor_loss_cfg["grid_size"], + descriptor_loss_cfg["dist_threshold"], + descriptor_loss_cfg["margin"]) + else: + raise ValueError("[Error] Not supported descriptor loss function.") + + return w_descriptor, descriptor_loss_func + + +def space_to_depth(input_tensor, grid_size): + """ PixelUnshuffle for pytorch. """ + N, C, H, W = input_tensor.size() + # (N, C, H//bs, bs, W//bs, bs) + x = input_tensor.view(N, C, H // grid_size, grid_size, W // grid_size, grid_size) + # (N, bs, bs, C, H//bs, W//bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() + # (N, C*bs^2, H//bs, W//bs) + x = x.view(N, C * (grid_size ** 2), H // grid_size, W // grid_size) + return x + + +def junction_detection_loss(junction_map, junc_predictions, valid_mask=None, + grid_size=8, keep_border=True): + """ Junction detection loss. """ + # Convert junc_map to channel tensor + junc_map = space_to_depth(junction_map, grid_size) + map_shape = junc_map.shape[-2:] + batch_size = junc_map.shape[0] + dust_bin_label = torch.ones( + [batch_size, 1, map_shape[0], + map_shape[1]]).to(junc_map.device).to(torch.int) + junc_map = torch.cat([junc_map*2, dust_bin_label], dim=1) + labels = torch.argmax( + junc_map.to(torch.float) + + torch.distributions.Uniform(0, 0.1).sample(junc_map.shape).to(junc_map.device), + dim=1) + + # Also convert the valid mask to channel tensor + valid_mask = (torch.ones(junction_map.shape) if valid_mask is None + else valid_mask) + valid_mask = space_to_depth(valid_mask, grid_size) + + # Compute junction loss on the border patch or not + if keep_border: + valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int), + dim=1, keepdim=True) > 0 + else: + valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int), + dim=1, keepdim=True) >= grid_size * grid_size + + # Compute the classification loss + loss_func = nn.CrossEntropyLoss(reduction="none") + # The loss still need NCHW format + loss = loss_func(input=junc_predictions, + target=labels.to(torch.long)) + + # Weighted sum by the valid mask + loss_ = torch.sum(loss * torch.squeeze(valid_mask.to(torch.float), + dim=1), dim=[0, 1, 2]) + loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float), + dim=1)) + + return loss_final + + +def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, + class_weight=None): + """ Heatmap prediction loss. 
""" + # Compute the classification loss on each pixel + if class_weight is None: + loss_func = nn.CrossEntropyLoss(reduction="none") + else: + loss_func = nn.CrossEntropyLoss(class_weight, reduction="none") + + loss = loss_func(input=heatmap_pred, + target=torch.squeeze(heatmap_gt.to(torch.long), dim=1)) + + # Weighted sum by the valid mask + # Sum over H and W + loss_spatial_sum = torch.sum(loss * torch.squeeze( + valid_mask.to(torch.float), dim=1), dim=[1, 2]) + valid_spatial_sum = torch.sum(torch.squeeze(valid_mask.to(torch.float32), + dim=1), dim=[1, 2]) + # Mean to single scalar over batch dimension + loss = torch.sum(loss_spatial_sum) / torch.sum(valid_spatial_sum) + + return loss + + +class JunctionDetectionLoss(nn.Module): + """ Junction detection loss. """ + def __init__(self, grid_size, keep_border): + super(JunctionDetectionLoss, self).__init__() + self.grid_size = grid_size + self.keep_border = keep_border + + def forward(self, prediction, target, valid_mask=None): + return junction_detection_loss(target, prediction, valid_mask, + self.grid_size, self.keep_border) + + +class HeatmapLoss(nn.Module): + """ Heatmap prediction loss. """ + def __init__(self, class_weight): + super(HeatmapLoss, self).__init__() + self.class_weight = class_weight + + def forward(self, prediction, target, valid_mask=None): + return heatmap_loss(target, prediction, valid_mask, self.class_weight) + + +class RegularizationLoss(nn.Module): + """ Module for regularization loss. """ + def __init__(self): + super(RegularizationLoss, self).__init__() + self.name = "regularization_loss" + self.loss_init = torch.zeros([]) + + def forward(self, loss_weights): + # Place it to the same device + loss = self.loss_init.to(loss_weights["w_junc"].device) + for _, val in loss_weights.items(): + if isinstance(val, nn.Parameter): + loss += val + + return loss + + +def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices, + epoch, grid_size=8, dist_threshold=8, + init_dist_threshold=64, margin=1): + """ Regular triplet loss for descriptor learning. 
""" + b_size, _, Hc, Wc = desc_pred1.size() + img_size = (Hc * grid_size, Wc * grid_size) + device = desc_pred1.device + + # Extract valid keypoints + n_points = line_indices.size()[1] + valid_points = line_indices.bool().flatten() + n_correct_points = torch.sum(valid_points).item() + if n_correct_points == 0: + return torch.tensor(0., dtype=torch.float, device=device) + + # Check which keypoints are too close to be matched + # dist_threshold is decreased at each epoch for easier training + dist_threshold = max(dist_threshold, + 2 * init_dist_threshold // (epoch + 1)) + dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold) + + # Additionally ban negative mining along the same line + common_line_mask = get_common_line_mask(line_indices, valid_points) + dist_mask = dist_mask | common_line_mask + + # Convert the keypoints to a grid suitable for interpolation + grid1 = keypoints_to_grid(points1, img_size) + grid2 = keypoints_to_grid(points2, img_size) + + # Extract the descriptors + desc1 = F.grid_sample(desc_pred1, grid1).permute( + 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] + desc1 = F.normalize(desc1, dim=1) + desc2 = F.grid_sample(desc_pred2, grid2).permute( + 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] + desc2 = F.normalize(desc2, dim=1) + desc_dists = 2 - 2 * (desc1 @ desc2.t()) + + # Positive distance loss + pos_dist = torch.diag(desc_dists) + + # Negative distance loss + max_dist = torch.tensor(4., dtype=torch.float, device=device) + desc_dists[ + torch.arange(n_correct_points, dtype=torch.long), + torch.arange(n_correct_points, dtype=torch.long)] = max_dist + desc_dists[dist_mask] = max_dist + neg_dist = torch.min(torch.min(desc_dists, dim=1)[0], + torch.min(desc_dists, dim=0)[0]) + + triplet_loss = F.relu(margin + pos_dist - neg_dist) + return triplet_loss, grid1, grid2, valid_points + + +class TripletDescriptorLoss(nn.Module): + """ Triplet descriptor loss. """ + def __init__(self, grid_size, dist_threshold, margin): + super(TripletDescriptorLoss, self).__init__() + self.grid_size = grid_size + self.init_dist_threshold = 64 + self.dist_threshold = dist_threshold + self.margin = margin + + def forward(self, desc_pred1, desc_pred2, points1, + points2, line_indices, epoch): + return self.descriptor_loss(desc_pred1, desc_pred2, points1, + points2, line_indices, epoch) + + # The descriptor loss based on regularly sampled points along the lines + def descriptor_loss(self, desc_pred1, desc_pred2, points1, + points2, line_indices, epoch): + return torch.mean(triplet_loss( + desc_pred1, desc_pred2, points1, points2, line_indices, epoch, + self.grid_size, self.dist_threshold, self.init_dist_threshold, + self.margin)[0]) + + +class TotalLoss(nn.Module): + """ Total loss summing junction, heatma, descriptor + and regularization losses. """ + def __init__(self, loss_funcs, loss_weights, weighting_policy): + super(TotalLoss, self).__init__() + # Whether we need to compute the descriptor loss + self.compute_descriptors = "descriptor_loss" in loss_funcs.keys() + + self.loss_funcs = loss_funcs + self.loss_weights = loss_weights + self.weighting_policy = weighting_policy + + # Always add regularization loss (it will return zero if not used) + self.loss_funcs["reg_loss"] = RegularizationLoss().cuda() + + def forward(self, junc_pred, junc_target, heatmap_pred, + heatmap_target, valid_mask=None): + """ Detection only loss. 
""" + # Compute the junction loss + junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target, + valid_mask) + # Compute the heatmap loss + heatmap_loss = self.loss_funcs["heatmap_loss"]( + heatmap_pred, heatmap_target, valid_mask) + + # Compute the total loss. + if self.weighting_policy == "dynamic": + reg_loss = self.loss_funcs["reg_loss"](self.loss_weights) + total_loss = junc_loss * torch.exp(-self.loss_weights["w_junc"]) + \ + heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"]) + \ + reg_loss + + return { + "total_loss": total_loss, + "junc_loss": junc_loss, + "heatmap_loss": heatmap_loss, + "reg_loss": reg_loss, + "w_junc": torch.exp(-self.loss_weights["w_junc"]).item(), + "w_heatmap": torch.exp(-self.loss_weights["w_heatmap"]).item(), + } + + elif self.weighting_policy == "static": + total_loss = junc_loss * self.loss_weights["w_junc"] + \ + heatmap_loss * self.loss_weights["w_heatmap"] + + return { + "total_loss": total_loss, + "junc_loss": junc_loss, + "heatmap_loss": heatmap_loss + } + + else: + raise ValueError("[Error] Unknown weighting policy.") + + def forward_descriptors(self, + junc_map_pred1, junc_map_pred2, junc_map_target1, + junc_map_target2, heatmap_pred1, heatmap_pred2, heatmap_target1, + heatmap_target2, line_points1, line_points2, line_indices, + desc_pred1, desc_pred2, epoch, valid_mask1=None, + valid_mask2=None): + """ Loss for detection + description. """ + # Compute junction loss + junc_loss = self.loss_funcs["junc_loss"]( + torch.cat([junc_map_pred1, junc_map_pred2], dim=0), + torch.cat([junc_map_target1, junc_map_target2], dim=0), + torch.cat([valid_mask1, valid_mask2], dim=0) + ) + # Get junction loss weight (dynamic or not) + if isinstance(self.loss_weights["w_junc"], nn.Parameter): + w_junc = torch.exp(-self.loss_weights["w_junc"]) + else: + w_junc = self.loss_weights["w_junc"] + + # Compute heatmap loss + heatmap_loss = self.loss_funcs["heatmap_loss"]( + torch.cat([heatmap_pred1, heatmap_pred2], dim=0), + torch.cat([heatmap_target1, heatmap_target2], dim=0), + torch.cat([valid_mask1, valid_mask2], dim=0) + ) + # Get heatmap loss weight (dynamic or not) + if isinstance(self.loss_weights["w_heatmap"], nn.Parameter): + w_heatmap = torch.exp(-self.loss_weights["w_heatmap"]) + else: + w_heatmap = self.loss_weights["w_heatmap"] + + # Compute the descriptor loss + descriptor_loss = self.loss_funcs["descriptor_loss"]( + desc_pred1, desc_pred2, line_points1, + line_points2, line_indices, epoch) + # Get descriptor loss weight (dynamic or not) + if isinstance(self.loss_weights["w_desc"], nn.Parameter): + w_descriptor = torch.exp(-self.loss_weights["w_desc"]) + else: + w_descriptor = self.loss_weights["w_desc"] + + # Update the total loss + total_loss = (junc_loss * w_junc + + heatmap_loss * w_heatmap + + descriptor_loss * w_descriptor) + outputs = { + "junc_loss": junc_loss, + "heatmap_loss": heatmap_loss, + "w_junc": w_junc.item() \ + if isinstance(w_junc, nn.Parameter) else w_junc, + "w_heatmap": w_heatmap.item() \ + if isinstance(w_heatmap, nn.Parameter) else w_heatmap, + "descriptor_loss": descriptor_loss, + "w_desc": w_descriptor.item() \ + if isinstance(w_descriptor, nn.Parameter) else w_descriptor + } + + # Compute the regularization loss + reg_loss = self.loss_funcs["reg_loss"](self.loss_weights) + total_loss += reg_loss + outputs.update({ + "reg_loss": reg_loss, + "total_loss": total_loss + }) + + return outputs diff --git a/imcui/third_party/SOLD2/sold2/model/lr_scheduler.py b/imcui/third_party/SOLD2/sold2/model/lr_scheduler.py new file mode 
100644 index 0000000000000000000000000000000000000000..3faa4f68a67564719008a932b40c16c5e908949f --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/lr_scheduler.py @@ -0,0 +1,22 @@ +""" +This file implements different learning rate schedulers +""" +import torch + + +def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer): + """ Get the learning rate scheduler according to the config. """ + # If no lr_decay is specified => return None + if (lr_decay == False) or (lr_decay_cfg is None): + schduler = None + # Exponential decay + elif (lr_decay == True) and (lr_decay_cfg["policy"] == "exp"): + schduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer, + gamma=lr_decay_cfg["gamma"] + ) + # Unknown policy + else: + raise ValueError("[Error] Unknow learning rate decay policy!") + + return schduler \ No newline at end of file diff --git a/imcui/third_party/SOLD2/sold2/model/metrics.py b/imcui/third_party/SOLD2/sold2/model/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..0894a7207ee4afa344cb332c605c715b14db73a4 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/metrics.py @@ -0,0 +1,528 @@ +""" +This file implements the evaluation metrics. +""" +import torch +import torch.nn.functional as F +import numpy as np +from torchvision.ops.boxes import batched_nms + +from ..misc.geometry_utils import keypoints_to_grid + + +class Metrics(object): + """ Metric evaluation calculator. """ + def __init__(self, detection_thresh, prob_thresh, grid_size, + junc_metric_lst=None, heatmap_metric_lst=None, + pr_metric_lst=None, desc_metric_lst=None): + # List supported metrics + self.supported_junc_metrics = ["junc_precision", "junc_precision_nms", + "junc_recall", "junc_recall_nms"] + self.supported_heatmap_metrics = ["heatmap_precision", + "heatmap_recall"] + self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"] + self.supported_desc_metrics = ["matching_score"] + + # If metric_lst is None, default to use all metrics + if junc_metric_lst is None: + self.junc_metric_lst = self.supported_junc_metrics + else: + self.junc_metric_lst = junc_metric_lst + if heatmap_metric_lst is None: + self.heatmap_metric_lst = self.supported_heatmap_metrics + else: + self.heatmap_metric_lst = heatmap_metric_lst + if pr_metric_lst is None: + self.pr_metric_lst = self.supported_pr_metrics + else: + self.pr_metric_lst = pr_metric_lst + # For the descriptors, the default None assumes no desc metric at all + if desc_metric_lst is None: + self.desc_metric_lst = [] + elif desc_metric_lst == 'all': + self.desc_metric_lst = self.supported_desc_metrics + else: + self.desc_metric_lst = desc_metric_lst + + if not self._check_metrics(): + raise ValueError( + "[Error] Some elements in the metric_lst are invalid.") + + # Metric mapping table + self.metric_table = { + "junc_precision": junction_precision(detection_thresh), + "junc_precision_nms": junction_precision(detection_thresh), + "junc_recall": junction_recall(detection_thresh), + "junc_recall_nms": junction_recall(detection_thresh), + "heatmap_precision": heatmap_precision(prob_thresh), + "heatmap_recall": heatmap_recall(prob_thresh), + "junc_pr": junction_pr(), + "junc_nms_pr": junction_pr(), + "matching_score": matching_score(grid_size) + } + + # Initialize the results + self.metric_results = {} + for key in self.metric_table.keys(): + self.metric_results[key] = 0. 
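A short usage sketch for the exponential decay policy handled by `get_lr_scheduler` above (the model, optimizer and `gamma` value are placeholders):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for epoch in range(3):
    optimizer.step()                                  # (dummy) update for the epoch
    scheduler.step()                                  # lr <- lr * gamma
    print(epoch, optimizer.param_groups[0]["lr"])
```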
+ + def evaluate(self, junc_pred, junc_pred_nms, junc_gt, heatmap_pred, + heatmap_gt, valid_mask, line_points1=None, line_points2=None, + desc_pred1=None, desc_pred2=None, valid_points=None): + """ Perform evaluation. """ + for metric in self.junc_metric_lst: + # If nms metrics then use nms to compute it. + if "nms" in metric: + junc_pred_input = junc_pred_nms + # Use normal inputs instead. + else: + junc_pred_input = junc_pred + self.metric_results[metric] = self.metric_table[metric]( + junc_pred_input, junc_gt, valid_mask) + + for metric in self.heatmap_metric_lst: + self.metric_results[metric] = self.metric_table[metric]( + heatmap_pred, heatmap_gt, valid_mask) + + for metric in self.pr_metric_lst: + if "nms" in metric: + self.metric_results[metric] = self.metric_table[metric]( + junc_pred_nms, junc_gt, valid_mask) + else: + self.metric_results[metric] = self.metric_table[metric]( + junc_pred, junc_gt, valid_mask) + + for metric in self.desc_metric_lst: + self.metric_results[metric] = self.metric_table[metric]( + line_points1, line_points2, desc_pred1, + desc_pred2, valid_points) + + def _check_metrics(self): + """ Check if all input metrics are valid. """ + flag = True + for metric in self.junc_metric_lst: + if not metric in self.supported_junc_metrics: + flag = False + break + for metric in self.heatmap_metric_lst: + if not metric in self.supported_heatmap_metrics: + flag = False + break + for metric in self.desc_metric_lst: + if not metric in self.supported_desc_metrics: + flag = False + break + + return flag + + +class AverageMeter(object): + def __init__(self, junc_metric_lst=None, heatmap_metric_lst=None, + is_training=True, desc_metric_lst=None): + # List supported metrics + self.supported_junc_metrics = ["junc_precision", "junc_precision_nms", + "junc_recall", "junc_recall_nms"] + self.supported_heatmap_metrics = ["heatmap_precision", + "heatmap_recall"] + self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"] + self.supported_desc_metrics = ["matching_score"] + # Record loss in training mode + # if is_training: + self.supported_loss = [ + "junc_loss", "heatmap_loss", "descriptor_loss", "total_loss"] + + self.is_training = is_training + + # If metric_lst is None, default to use all metrics + if junc_metric_lst is None: + self.junc_metric_lst = self.supported_junc_metrics + else: + self.junc_metric_lst = junc_metric_lst + if heatmap_metric_lst is None: + self.heatmap_metric_lst = self.supported_heatmap_metrics + else: + self.heatmap_metric_lst = heatmap_metric_lst + # For the descriptors, the default None assumes no desc metric at all + if desc_metric_lst is None: + self.desc_metric_lst = [] + elif desc_metric_lst == 'all': + self.desc_metric_lst = self.supported_desc_metrics + else: + self.desc_metric_lst = desc_metric_lst + + if not self._check_metrics(): + raise ValueError( + "[Error] Some elements in the metric_lst are invalid.") + + # Initialize the results + self.metric_results = {} + for key in (self.supported_junc_metrics + + self.supported_heatmap_metrics + + self.supported_loss + self.supported_desc_metrics): + self.metric_results[key] = 0. 
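The bookkeeping in `AverageMeter` is essentially a sample-weighted running average of each scalar metric; a minimal sketch of that core pattern:

```python
class RunningAverage:
    """Sample-weighted running average, as used for the scalar metrics."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, num_samples=1):
        self.total += value * num_samples
        self.count += num_samples

    def average(self):
        return self.total / self.count if self.count else 0.0

meter = RunningAverage()
meter.update(0.5, num_samples=4)
meter.update(1.0, num_samples=2)
print(meter.average())  # 0.666...
```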
+ for key in self.supported_pr_metrics: + zero_lst = [0 for _ in range(50)] + self.metric_results[key] = { + "tp": zero_lst, + "tn": zero_lst, + "fp": zero_lst, + "fn": zero_lst, + "precision": zero_lst, + "recall": zero_lst + } + + # Initialize total count + self.count = 0 + + def update(self, metrics, loss_dict=None, num_samples=1): + # loss should be given in the training mode + if self.is_training and (loss_dict is None): + raise ValueError( + "[Error] loss info should be given in the training mode.") + + # update total counts + self.count += num_samples + + # update all the metrics + for met in (self.supported_junc_metrics + + self.supported_heatmap_metrics + + self.supported_desc_metrics): + self.metric_results[met] += (num_samples + * metrics.metric_results[met]) + + # Update all the losses + for loss in loss_dict.keys(): + self.metric_results[loss] += num_samples * loss_dict[loss] + + # Update all pr counts + for pr_met in self.supported_pr_metrics: + # Update all tp, tn, fp, fn, precision, and recall. + for key in metrics.metric_results[pr_met].keys(): + # Update each interval + for idx in range(len(self.metric_results[pr_met][key])): + self.metric_results[pr_met][key][idx] += ( + num_samples + * metrics.metric_results[pr_met][key][idx]) + + def average(self): + results = {} + for met in self.metric_results.keys(): + # Skip pr curve metrics + if not met in self.supported_pr_metrics: + results[met] = self.metric_results[met] / self.count + # Only update precision and recall in pr metrics + else: + met_results = { + "tp": self.metric_results[met]["tp"], + "tn": self.metric_results[met]["tn"], + "fp": self.metric_results[met]["fp"], + "fn": self.metric_results[met]["fn"], + "precision": [], + "recall": [] + } + for idx in range(len(self.metric_results[met]["precision"])): + met_results["precision"].append( + self.metric_results[met]["precision"][idx] + / self.count) + met_results["recall"].append( + self.metric_results[met]["recall"][idx] / self.count) + + results[met] = met_results + + return results + + def _check_metrics(self): + """ Check if all input metrics are valid. """ + flag = True + for metric in self.junc_metric_lst: + if not metric in self.supported_junc_metrics: + flag = False + break + for metric in self.heatmap_metric_lst: + if not metric in self.supported_heatmap_metrics: + flag = False + break + for metric in self.desc_metric_lst: + if not metric in self.supported_desc_metrics: + flag = False + break + + return flag + + +class junction_precision(object): + """ Junction precision. """ + def __init__(self, detection_thresh): + self.detection_thresh = detection_thresh + + # Compute the evaluation result + def __call__(self, junc_pred, junc_gt, valid_mask): + # Convert prediction to discrete detection + junc_pred = (junc_pred >= self.detection_thresh).astype(np.int) + junc_pred = junc_pred * valid_mask.squeeze() + + # Deal with the corner case of the prediction + if np.sum(junc_pred) > 0: + precision = (np.sum(junc_pred * junc_gt.squeeze()) + / np.sum(junc_pred)) + else: + precision = 0 + + return float(precision) + + +class junction_recall(object): + """ Junction recall. """ + def __init__(self, detection_thresh): + self.detection_thresh = detection_thresh + + # Compute the evaluation result + def __call__(self, junc_pred, junc_gt, valid_mask): + # Convert prediction to discrete detection + junc_pred = (junc_pred >= self.detection_thresh).astype(np.int) + junc_pred = junc_pred * valid_mask.squeeze() + + # Deal with the corner case of the recall. 
+ if np.sum(junc_gt): + recall = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_gt) + else: + recall = 0 + + return float(recall) + + +class junction_pr(object): + """ Junction precision-recall info. """ + def __init__(self, num_threshold=50): + self.max = 0.4 + step = self.max / num_threshold + self.min = step + self.intervals = np.flip(np.arange(self.min, self.max + step, step)) + + def __call__(self, junc_pred_raw, junc_gt, valid_mask): + tp_lst = [] + fp_lst = [] + tn_lst = [] + fn_lst = [] + precision_lst = [] + recall_lst = [] + + valid_mask = valid_mask.squeeze() + # Iterate through all the thresholds + for thresh in list(self.intervals): + # Convert prediction to discrete detection + junc_pred = (junc_pred_raw >= thresh).astype(np.int) + junc_pred = junc_pred * valid_mask + + # Compute tp, fp, tn, fn + junc_gt = junc_gt.squeeze() + tp = np.sum(junc_pred * junc_gt) + tn = np.sum((junc_pred == 0).astype(np.float) + * (junc_gt == 0).astype(np.float) * valid_mask) + fp = np.sum((junc_pred == 1).astype(np.float) + * (junc_gt == 0).astype(np.float) * valid_mask) + fn = np.sum((junc_pred == 0).astype(np.float) + * (junc_gt == 1).astype(np.float) * valid_mask) + + tp_lst.append(tp) + tn_lst.append(tn) + fp_lst.append(fp) + fn_lst.append(fn) + precision_lst.append(tp / (tp + fp)) + recall_lst.append(tp / (tp + fn)) + + return { + "tp": np.array(tp_lst), + "tn": np.array(tn_lst), + "fp": np.array(fp_lst), + "fn": np.array(fn_lst), + "precision": np.array(precision_lst), + "recall": np.array(recall_lst) + } + + +class heatmap_precision(object): + """ Heatmap precision. """ + def __init__(self, prob_thresh): + self.prob_thresh = prob_thresh + + def __call__(self, heatmap_pred, heatmap_gt, valid_mask): + # Assume NHWC (Handle L1 and L2 cases) NxHxWx1 + heatmap_pred = np.squeeze(heatmap_pred > self.prob_thresh) + heatmap_pred = heatmap_pred * valid_mask.squeeze() + + # Deal with the corner case of the prediction + if np.sum(heatmap_pred) > 0: + precision = (np.sum(heatmap_pred * heatmap_gt.squeeze()) + / np.sum(heatmap_pred)) + else: + precision = 0. + + return precision + + +class heatmap_recall(object): + """ Heatmap recall. """ + def __init__(self, prob_thresh): + self.prob_thresh = prob_thresh + + def __call__(self, heatmap_pred, heatmap_gt, valid_mask): + # Assume NHWC (Handle L1 and L2 cases) NxHxWx1 + heatmap_pred = np.squeeze(heatmap_pred > self.prob_thresh) + heatmap_pred = heatmap_pred * valid_mask.squeeze() + + # Deal with the corner case of the ground truth + if np.sum(heatmap_gt) > 0: + recall = (np.sum(heatmap_pred * heatmap_gt.squeeze()) + / np.sum(heatmap_gt)) + else: + recall = 0. + + return recall + + +class matching_score(object): + """ Descriptors matching score. 
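The matching-score metric defined here counts mutual nearest neighbours between the two descriptor sets: a pair is correct only if descriptor i in image 1 picks j in image 2 and j picks i back. A standalone sketch with random descriptors:

```python
import torch
import torch.nn.functional as F

desc1 = F.normalize(torch.randn(10, 128), dim=1)
desc2 = F.normalize(torch.randn(10, 128), dim=1)

dist = 2 - 2 * desc1 @ desc2.t()
matches0 = dist.argmin(dim=1)                    # best match in image 2 for each desc1
matches1 = dist.argmin(dim=0)                    # best match in image 1 for each desc2
mutual = matches1[matches0] == torch.arange(len(matches0))
print(mutual.float().mean().item())              # fraction of mutual nearest neighbours
```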
""" + def __init__(self, grid_size): + self.grid_size = grid_size + + def __call__(self, points1, points2, desc_pred1, + desc_pred2, line_indices): + b_size, _, Hc, Wc = desc_pred1.size() + img_size = (Hc * self.grid_size, Wc * self.grid_size) + device = desc_pred1.device + + # Extract valid keypoints + n_points = line_indices.size()[1] + valid_points = line_indices.bool().flatten() + n_correct_points = torch.sum(valid_points).item() + if n_correct_points == 0: + return torch.tensor(0., dtype=torch.float, device=device) + + # Convert the keypoints to a grid suitable for interpolation + grid1 = keypoints_to_grid(points1, img_size) + grid2 = keypoints_to_grid(points2, img_size) + + # Extract the descriptors + desc1 = F.grid_sample(desc_pred1, grid1).permute( + 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] + desc1 = F.normalize(desc1, dim=1) + desc2 = F.grid_sample(desc_pred2, grid2).permute( + 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] + desc2 = F.normalize(desc2, dim=1) + desc_dists = 2 - 2 * (desc1 @ desc2.t()) + + # Compute percentage of correct matches + matches0 = torch.min(desc_dists, dim=1)[1] + matches1 = torch.min(desc_dists, dim=0)[1] + matching_score = (matches1[matches0] + == torch.arange(len(matches0)).to(device)) + matching_score = matching_score.float().mean() + return matching_score + + +def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0): + """ Non-maximum suppression adapted from SuperPoint. """ + # Iterate through batch dimension + im_h = prob_predictions.shape[1] + im_w = prob_predictions.shape[2] + output_lst = [] + for i in range(prob_predictions.shape[0]): + # print(i) + prob_pred = prob_predictions[i, ...] + # Filter the points using prob_thresh + coord = np.where(prob_pred >= prob_thresh) # HW format + points = np.concatenate((coord[0][..., None], coord[1][..., None]), + axis=1) # HW format + + # Get the probability score + prob_score = prob_pred[points[:, 0], points[:, 1]] + + # Perform super nms + # Modify the in_points to xy format (instead of HW format) + in_points = np.concatenate((coord[1][..., None], coord[0][..., None], + prob_score), axis=1).T + keep_points_, keep_inds = nms_fast(in_points, im_h, im_w, dist_thresh) + # Remember to flip outputs back to HW format + keep_points = np.round(np.flip(keep_points_[:2, :], axis=0).T) + keep_score = keep_points_[-1, :].T + + # Whether we only keep the topk value + if (top_k > 0) or (top_k is None): + k = min([keep_points.shape[0], top_k]) + keep_points = keep_points[:k, :] + keep_score = keep_score[:k] + + # Re-compose the probability map + output_map = np.zeros([im_h, im_w]) + output_map[keep_points[:, 0].astype(np.int), + keep_points[:, 1].astype(np.int)] = keep_score.squeeze() + + output_lst.append(output_map[None, ...]) + + return np.concatenate(output_lst, axis=0) + + +def nms_fast(in_corners, H, W, dist_thresh): + """ + Run a faster approximate Non-Max-Suppression on numpy corners shaped: + 3xN [x_i,y_i,conf_i]^T + + Algo summary: Create a grid sized HxW. Assign each corner location a 1, + rest are zeros. Iterate through all the 1's and convert them to -1 or 0. + Suppress points by setting nearby values to 0. + + Grid Value Legend: + -1 : Kept. + 0 : Empty or suppressed. + 1 : To be processed (converted to either kept or supressed). + + NOTE: The NMS first rounds points to integers, so NMS distance might not + be exactly dist_thresh. It also assumes points are within image boundary. + + Inputs + in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T. 
+ H - Image height. + W - Image width. + dist_thresh - Distance to suppress, measured as an infinite distance. + Returns + nmsed_corners - 3xN numpy matrix with surviving corners. + nmsed_inds - N length numpy vector with surviving corner indices. + """ + grid = np.zeros((H, W)).astype(int) # Track NMS data. + inds = np.zeros((H, W)).astype(int) # Store indices of points. + # Sort by confidence and round to nearest int. + inds1 = np.argsort(-in_corners[2, :]) + corners = in_corners[:, inds1] + rcorners = corners[:2, :].round().astype(int) # Rounded corners. + # Check for edge case of 0 or 1 corners. + if rcorners.shape[1] == 0: + return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int) + if rcorners.shape[1] == 1: + out = np.vstack((rcorners, in_corners[2])).reshape(3, 1) + return out, np.zeros((1)).astype(int) + # Initialize the grid. + for i, rc in enumerate(rcorners.T): + grid[rcorners[1, i], rcorners[0, i]] = 1 + inds[rcorners[1, i], rcorners[0, i]] = i + # Pad the border of the grid, so that we can NMS points near the border. + pad = dist_thresh + grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant') + # Iterate through points, highest to lowest conf, suppress neighborhood. + count = 0 + for i, rc in enumerate(rcorners.T): + # Account for top and left padding. + pt = (rc[0] + pad, rc[1] + pad) + if grid[pt[1], pt[0]] == 1: # If not yet suppressed. + grid[pt[1] - pad:pt[1] + pad + 1, pt[0] - pad:pt[0] + pad + 1] = 0 + grid[pt[1], pt[0]] = -1 + count += 1 + # Get all surviving -1's and return sorted array of remaining corners. + keepy, keepx = np.where(grid == -1) + keepy, keepx = keepy - pad, keepx - pad + inds_keep = inds[keepy, keepx] + out = corners[:, inds_keep] + values = out[-1, :] + inds2 = np.argsort(-values) + out = out[:, inds2] + out_inds = inds1[inds_keep[inds2]] + return out, out_inds diff --git a/imcui/third_party/SOLD2/sold2/model/model_util.py b/imcui/third_party/SOLD2/sold2/model/model_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f70d80da40a72c207edfcfc1509e820846f0b731 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/model_util.py @@ -0,0 +1,203 @@ +import torch +import torch.nn as nn +import torch.nn.init as init + +from .nets.backbone import HourglassBackbone, SuperpointBackbone +from .nets.junction_decoder import SuperpointDecoder +from .nets.heatmap_decoder import PixelShuffleDecoder +from .nets.descriptor_decoder import SuperpointDescriptor + + +def get_model(model_cfg=None, loss_weights=None, mode="train"): + """ Get model based on the model configuration. 
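`nms_fast` above suppresses detections by marking all candidates on a padded grid, then walking them from highest to lowest confidence and zeroing a square neighbourhood around each kept point. A toy sketch of the same idea (the radius and example points are placeholders):

```python
import numpy as np

def simple_grid_nms(points, scores, h, w, r):
    """points: (N, 2) integer (row, col); returns indices of the kept points."""
    grid = np.zeros((h + 2 * r, w + 2 * r), dtype=int)   # padded so borders are safe
    order = np.argsort(-scores)                          # high confidence first
    for i in order:
        y, x = points[i] + r
        grid[y, x] = 1                                   # mark every candidate
    keep = []
    for i in order:
        y, x = points[i] + r
        if grid[y, x] == 1:                              # not yet suppressed
            keep.append(i)
            grid[y - r:y + r + 1, x - r:x + r + 1] = 0   # suppress the neighbourhood
    return np.array(keep)

pts = np.array([[5, 5], [5, 6], [20, 20]])
print(simple_grid_nms(pts, np.array([0.9, 0.8, 0.7]), 32, 32, r=4))  # [0 2]
```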
""" + # Check dataset config is given + if model_cfg is None: + raise ValueError("[Error] The model config is required!") + + # List the supported options here + print("\n\n\t--------Initializing model----------") + supported_arch = ["simple"] + if not model_cfg["model_architecture"] in supported_arch: + raise ValueError( + "[Error] The model architecture is not in supported arch!") + + if model_cfg["model_architecture"] == "simple": + model = SOLD2Net(model_cfg) + else: + raise ValueError( + "[Error] The model architecture is not in supported arch!") + + # Optionally register loss weights to the model + if mode == "train": + if loss_weights is not None: + for param_name, param in loss_weights.items(): + if isinstance(param, nn.Parameter): + print("\t [Debug] Adding %s with value %f to model" + % (param_name, param.item())) + model.register_parameter(param_name, param) + else: + raise ValueError( + "[Error] the loss weights can not be None in dynamic weighting mode during training.") + + # Display some summary info. + print("\tModel architecture: %s" % model_cfg["model_architecture"]) + print("\tBackbone: %s" % model_cfg["backbone"]) + print("\tJunction decoder: %s" % model_cfg["junction_decoder"]) + print("\tHeatmap decoder: %s" % model_cfg["heatmap_decoder"]) + print("\t-------------------------------------") + + return model + + +class SOLD2Net(nn.Module): + """ Full network for SOLD². """ + def __init__(self, model_cfg): + super(SOLD2Net, self).__init__() + self.name = model_cfg["model_name"] + self.cfg = model_cfg + + # List supported network options + self.supported_backbone = ["lcnn", "superpoint"] + self.backbone_net, self.feat_channel = self.get_backbone() + + # List supported junction decoder options + self.supported_junction_decoder = ["superpoint_decoder"] + self.junction_decoder = self.get_junction_decoder() + + # List supported heatmap decoder options + self.supported_heatmap_decoder = ["pixel_shuffle", + "pixel_shuffle_single"] + self.heatmap_decoder = self.get_heatmap_decoder() + + # List supported descriptor decoder options + if "descriptor_decoder" in self.cfg: + self.supported_descriptor_decoder = ["superpoint_descriptor"] + self.descriptor_decoder = self.get_descriptor_decoder() + + # Initialize the model weights + self.apply(weight_init) + + def forward(self, input_images): + # The backbone + features = self.backbone_net(input_images) + + # junction decoder + junctions = self.junction_decoder(features) + + # heatmap decoder + heatmaps = self.heatmap_decoder(features) + + outputs = {"junctions": junctions, "heatmap": heatmaps} + + # Descriptor decoder + if "descriptor_decoder" in self.cfg: + outputs["descriptors"] = self.descriptor_decoder(features) + + return outputs + + def get_backbone(self): + """ Retrieve the backbone encoder network. """ + if not self.cfg["backbone"] in self.supported_backbone: + raise ValueError( + "[Error] The backbone selection is not supported.") + + # lcnn backbone (stacked hourglass) + if self.cfg["backbone"] == "lcnn": + backbone_cfg = self.cfg["backbone_cfg"] + backbone = HourglassBackbone(**backbone_cfg) + feat_channel = 256 + + elif self.cfg["backbone"] == "superpoint": + backbone_cfg = self.cfg["backbone_cfg"] + backbone = SuperpointBackbone() + feat_channel = 128 + + else: + raise ValueError( + "[Error] The backbone selection is not supported.") + + return backbone, feat_channel + + def get_junction_decoder(self): + """ Get the junction decoder. 
""" + if (not self.cfg["junction_decoder"] + in self.supported_junction_decoder): + raise ValueError( + "[Error] The junction decoder selection is not supported.") + + # superpoint decoder + if self.cfg["junction_decoder"] == "superpoint_decoder": + decoder = SuperpointDecoder(self.feat_channel, + self.cfg["backbone"]) + else: + raise ValueError( + "[Error] The junction decoder selection is not supported.") + + return decoder + + def get_heatmap_decoder(self): + """ Get the heatmap decoder. """ + if not self.cfg["heatmap_decoder"] in self.supported_heatmap_decoder: + raise ValueError( + "[Error] The heatmap decoder selection is not supported.") + + # Pixel_shuffle decoder + if self.cfg["heatmap_decoder"] == "pixel_shuffle": + if self.cfg["backbone"] == "lcnn": + decoder = PixelShuffleDecoder(self.feat_channel, + num_upsample=2) + elif self.cfg["backbone"] == "superpoint": + decoder = PixelShuffleDecoder(self.feat_channel, + num_upsample=3) + else: + raise ValueError("[Error] Unknown backbone option.") + # Pixel_shuffle decoder with single channel output + elif self.cfg["heatmap_decoder"] == "pixel_shuffle_single": + if self.cfg["backbone"] == "lcnn": + decoder = PixelShuffleDecoder( + self.feat_channel, num_upsample=2, output_channel=1) + elif self.cfg["backbone"] == "superpoint": + decoder = PixelShuffleDecoder( + self.feat_channel, num_upsample=3, output_channel=1) + else: + raise ValueError("[Error] Unknown backbone option.") + else: + raise ValueError( + "[Error] The heatmap decoder selection is not supported.") + + return decoder + + def get_descriptor_decoder(self): + """ Get the descriptor decoder. """ + if (not self.cfg["descriptor_decoder"] + in self.supported_descriptor_decoder): + raise ValueError( + "[Error] The descriptor decoder selection is not supported.") + + # SuperPoint descriptor + if self.cfg["descriptor_decoder"] == "superpoint_descriptor": + decoder = SuperpointDescriptor(self.feat_channel) + else: + raise ValueError( + "[Error] The descriptor decoder selection is not supported.") + + return decoder + + +def weight_init(m): + """ Weight initialization function. """ + # Conv2D + if isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight.data) + if m.bias is not None: + init.normal_(m.bias.data) + # Batchnorm + elif isinstance(m, nn.BatchNorm2d): + init.normal_(m.weight.data, mean=1, std=0.02) + init.constant_(m.bias.data, 0) + # Linear + elif isinstance(m, nn.Linear): + init.xavier_normal_(m.weight.data) + init.normal_(m.bias.data) + else: + pass diff --git a/imcui/third_party/SOLD2/sold2/model/nets/__init__.py b/imcui/third_party/SOLD2/sold2/model/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/SOLD2/sold2/model/nets/backbone.py b/imcui/third_party/SOLD2/sold2/model/nets/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..71f260aef108c77d54319cab7bc082c3c51112e7 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/nets/backbone.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn + +from .lcnn_hourglass import MultitaskHead, hg + + +class HourglassBackbone(nn.Module): + """ Hourglass backbone. 
""" + def __init__(self, input_channel=1, depth=4, num_stacks=2, + num_blocks=1, num_classes=5): + super(HourglassBackbone, self).__init__() + self.head = MultitaskHead + self.net = hg(**{ + "head": self.head, + "depth": depth, + "num_stacks": num_stacks, + "num_blocks": num_blocks, + "num_classes": num_classes, + "input_channels": input_channel + }) + + def forward(self, input_images): + return self.net(input_images)[1] + + +class SuperpointBackbone(nn.Module): + """ SuperPoint backbone. """ + def __init__(self): + super(SuperpointBackbone, self).__init__() + self.relu = torch.nn.ReLU(inplace=True) + self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4 = 64, 64, 128, 128 + # Shared Encoder. + self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3, + stride=1, padding=1) + self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, + stride=1, padding=1) + self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, + stride=1, padding=1) + self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, + stride=1, padding=1) + self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, + stride=1, padding=1) + self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, + stride=1, padding=1) + self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, + stride=1, padding=1) + self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, + stride=1, padding=1) + + def forward(self, input_images): + # Shared Encoder. + x = self.relu(self.conv1a(input_images)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + return x diff --git a/imcui/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py b/imcui/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed4306fad764efab2c22ede9cae253c9b17d6c2 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py @@ -0,0 +1,19 @@ +import torch +import torch.nn as nn + + +class SuperpointDescriptor(nn.Module): + """ Descriptor decoder based on the SuperPoint arcihtecture. """ + def __init__(self, input_feat_dim=128): + super(SuperpointDescriptor, self).__init__() + self.relu = torch.nn.ReLU(inplace=True) + self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3, + stride=1, padding=1) + self.convPb = torch.nn.Conv2d(256, 128, kernel_size=1, + stride=1, padding=0) + + def forward(self, input_features): + feat = self.relu(self.convPa(input_features)) + semi = self.convPb(feat) + + return semi \ No newline at end of file diff --git a/imcui/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py b/imcui/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5157ca740c8c7e25f2183b2a3c1fefa813deca --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py @@ -0,0 +1,59 @@ +import torch.nn as nn + + +class PixelShuffleDecoder(nn.Module): + """ Pixel shuffle decoder. 
""" + def __init__(self, input_feat_dim=128, num_upsample=2, output_channel=2): + super(PixelShuffleDecoder, self).__init__() + # Get channel parameters + self.channel_conf = self.get_channel_conf(num_upsample) + + # Define the pixel shuffle + self.pixshuffle = nn.PixelShuffle(2) + + # Process the feature + self.conv_block_lst = [] + # The input block + self.conv_block_lst.append( + nn.Sequential( + nn.Conv2d(input_feat_dim, self.channel_conf[0], + kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(self.channel_conf[0]), + nn.ReLU(inplace=True) + )) + + # Intermediate block + for channel in self.channel_conf[1:-1]: + self.conv_block_lst.append( + nn.Sequential( + nn.Conv2d(channel, channel, kernel_size=3, + stride=1, padding=1), + nn.BatchNorm2d(channel), + nn.ReLU(inplace=True) + )) + + # Output block + self.conv_block_lst.append( + nn.Conv2d(self.channel_conf[-1], output_channel, + kernel_size=1, stride=1, padding=0) + ) + self.conv_block_lst = nn.ModuleList(self.conv_block_lst) + + # Get num of channels based on number of upsampling. + def get_channel_conf(self, num_upsample): + if num_upsample == 2: + return [256, 64, 16] + elif num_upsample == 3: + return [256, 64, 16, 4] + + def forward(self, input_features): + # Iterate til output block + out = input_features + for block in self.conv_block_lst[:-1]: + out = block(out) + out = self.pixshuffle(out) + + # Output layer + out = self.conv_block_lst[-1](out) + + return out diff --git a/imcui/third_party/SOLD2/sold2/model/nets/junction_decoder.py b/imcui/third_party/SOLD2/sold2/model/nets/junction_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d2bb649518896501c784940028a772d688c2b3a7 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/nets/junction_decoder.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn + + +class SuperpointDecoder(nn.Module): + """ Junction decoder based on the SuperPoint architecture. """ + def __init__(self, input_feat_dim=128, backbone_name="lcnn"): + super(SuperpointDecoder, self).__init__() + self.relu = torch.nn.ReLU(inplace=True) + # Perform strided convolution when using lcnn backbone. 
+ if backbone_name == "lcnn": + self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3, + stride=2, padding=1) + elif backbone_name == "superpoint": + self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3, + stride=1, padding=1) + else: + raise ValueError("[Error] Unknown backbone option.") + + self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1, + stride=1, padding=0) + + def forward(self, input_features): + feat = self.relu(self.convPa(input_features)) + semi = self.convPb(feat) + + return semi \ No newline at end of file diff --git a/imcui/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py b/imcui/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py new file mode 100644 index 0000000000000000000000000000000000000000..a9dc78eef34e7ee146166b1b66c10070799d63f3 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py @@ -0,0 +1,226 @@ +""" +Hourglass network, taken from https://github.com/zhou13/lcnn +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["HourglassNet", "hg"] + + +class MultitaskHead(nn.Module): + def __init__(self, input_channels, num_class): + super(MultitaskHead, self).__init__() + + m = int(input_channels / 4) + head_size = [[2], [1], [2]] + heads = [] + for output_channels in sum(head_size, []): + heads.append( + nn.Sequential( + nn.Conv2d(input_channels, m, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(m, output_channels, kernel_size=1), + ) + ) + self.heads = nn.ModuleList(heads) + assert num_class == sum(sum(head_size, [])) + + def forward(self, x): + return torch.cat([head(x) for head in self.heads], dim=1) + + +class Bottleneck2D(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck2D, self).__init__() + + self.bn1 = nn.BatchNorm2d(inplanes) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=stride, padding=1) + self.bn3 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.bn1(x) + out = self.relu(out) + out = self.conv1(out) + + out = self.bn2(out) + out = self.relu(out) + out = self.conv2(out) + + out = self.bn3(out) + out = self.relu(out) + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + + return out + + +class Hourglass(nn.Module): + def __init__(self, block, num_blocks, planes, depth): + super(Hourglass, self).__init__() + self.depth = depth + self.block = block + self.hg = self._make_hour_glass(block, num_blocks, planes, depth) + + def _make_residual(self, block, num_blocks, planes): + layers = [] + for i in range(0, num_blocks): + layers.append(block(planes * block.expansion, planes)) + return nn.Sequential(*layers) + + def _make_hour_glass(self, block, num_blocks, planes, depth): + hg = [] + for i in range(depth): + res = [] + for j in range(3): + res.append(self._make_residual(block, num_blocks, planes)) + if i == 0: + res.append(self._make_residual(block, num_blocks, planes)) + hg.append(nn.ModuleList(res)) + return nn.ModuleList(hg) + + def _hour_glass_forward(self, n, x): + up1 = self.hg[n - 1][0](x) + low1 = F.max_pool2d(x, 2, stride=2) + low1 = self.hg[n - 1][1](low1) + + if n > 1: + low2 = self._hour_glass_forward(n - 1, low1) + else: + low2 = 
self.hg[n - 1][3](low1) + low3 = self.hg[n - 1][2](low2) + # up2 = F.interpolate(low3, scale_factor=2) + up2 = F.interpolate(low3, size=up1.shape[2:]) + out = up1 + up2 + return out + + def forward(self, x): + return self._hour_glass_forward(self.depth, x) + + +class HourglassNet(nn.Module): + """Hourglass model from Newell et al ECCV 2016""" + + def __init__(self, block, head, depth, num_stacks, num_blocks, + num_classes, input_channels): + super(HourglassNet, self).__init__() + + self.inplanes = 64 + self.num_feats = 128 + self.num_stacks = num_stacks + self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, + stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_residual(block, self.inplanes, 1) + self.layer2 = self._make_residual(block, self.inplanes, 1) + self.layer3 = self._make_residual(block, self.num_feats, 1) + self.maxpool = nn.MaxPool2d(2, stride=2) + + # build hourglass modules + ch = self.num_feats * block.expansion + # vpts = [] + hg, res, fc, score, fc_, score_ = [], [], [], [], [], [] + for i in range(num_stacks): + hg.append(Hourglass(block, num_blocks, self.num_feats, depth)) + res.append(self._make_residual(block, self.num_feats, num_blocks)) + fc.append(self._make_fc(ch, ch)) + score.append(head(ch, num_classes)) + # vpts.append(VptsHead(ch)) + # vpts.append(nn.Linear(ch, 9)) + # score.append(nn.Conv2d(ch, num_classes, kernel_size=1)) + # score[i].bias.data[0] += 4.6 + # score[i].bias.data[2] += 4.6 + if i < num_stacks - 1: + fc_.append(nn.Conv2d(ch, ch, kernel_size=1)) + score_.append(nn.Conv2d(num_classes, ch, kernel_size=1)) + self.hg = nn.ModuleList(hg) + self.res = nn.ModuleList(res) + self.fc = nn.ModuleList(fc) + self.score = nn.ModuleList(score) + # self.vpts = nn.ModuleList(vpts) + self.fc_ = nn.ModuleList(fc_) + self.score_ = nn.ModuleList(score_) + + def _make_residual(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + ) + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_fc(self, inplanes, outplanes): + bn = nn.BatchNorm2d(inplanes) + conv = nn.Conv2d(inplanes, outplanes, kernel_size=1) + return nn.Sequential(conv, bn, self.relu) + + def forward(self, x): + out = [] + # out_vps = [] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.maxpool(x) + x = self.layer2(x) + x = self.layer3(x) + + for i in range(self.num_stacks): + y = self.hg[i](x) + y = self.res[i](y) + y = self.fc[i](y) + score = self.score[i](y) + # pre_vpts = F.adaptive_avg_pool2d(x, (1, 1)) + # pre_vpts = pre_vpts.reshape(-1, 256) + # vpts = self.vpts[i](x) + out.append(score) + # out_vps.append(vpts) + if i < self.num_stacks - 1: + fc_ = self.fc_[i](y) + score_ = self.score_[i](score) + x = x + fc_ + score_ + + return out[::-1], y # , out_vps[::-1] + + +def hg(**kwargs): + model = HourglassNet( + Bottleneck2D, + head=kwargs.get("head", + lambda c_in, c_out: nn.Conv2D(c_in, c_out, 1)), + depth=kwargs["depth"], + num_stacks=kwargs["num_stacks"], + num_blocks=kwargs["num_blocks"], + num_classes=kwargs["num_classes"], + input_channels=kwargs["input_channels"] + ) + return model diff 
--git a/imcui/third_party/SOLD2/sold2/postprocess/__init__.py b/imcui/third_party/SOLD2/sold2/postprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/SOLD2/sold2/postprocess/convert_homography_results.py b/imcui/third_party/SOLD2/sold2/postprocess/convert_homography_results.py new file mode 100644 index 0000000000000000000000000000000000000000..352eebbde00f6d8a9c20517dccd7024fd0758ffd --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/postprocess/convert_homography_results.py @@ -0,0 +1,136 @@ +""" +Convert the aggregation results from the homography adaptation to GT labels. +""" +import sys +sys.path.append("../") +import os +import yaml +import argparse +import numpy as np +import h5py +import torch +from tqdm import tqdm + +from config.project_config import Config as cfg +from model.line_detection import LineSegmentDetectionModule +from model.metrics import super_nms +from misc.train_utils import parse_h5_data + + +def convert_raw_exported_predictions(input_data, grid_size=8, + detect_thresh=1/65, topk=300): + """ Convert the exported junctions and heatmaps predictions + to a standard format. + Arguments: + input_data: the raw data (dict) decoded from the hdf5 dataset + outputs: dict containing required entries including: + junctions_pred: Nx2 ndarray containing nms junction predictions. + heatmap_pred: HxW ndarray containing predicted heatmaps + valid_mask: HxW ndarray containing the valid mask + """ + # Check the input_data is from (1) single prediction, + # or (2) homography adaptation. + # Homography adaptation raw predictions + if (("junc_prob_mean" in input_data.keys()) + and ("heatmap_prob_mean" in input_data.keys())): + # Get the junction predictions and convert if to Nx2 format + junc_prob = input_data["junc_prob_mean"] + junc_pred_np = junc_prob[None, ...] 
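The conversion that follows turns the NMS'd probability map into an `(N, 2)` array of `(row, col)` junction coordinates; in isolation it is just:

```python
import numpy as np

junc_map = np.zeros((8, 8))
junc_map[2, 3] = 0.9
junc_map[5, 1] = 0.7

rows, cols = np.where(junc_map > 0)
junctions = np.stack([rows, cols], axis=-1)   # (N, 2) in (row, col) order
print(junctions)                              # [[2 3] [5 1]]
```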
+ junc_pred_np_nms = super_nms(junc_pred_np, grid_size, + detect_thresh, topk) + junctions = np.where(junc_pred_np_nms.squeeze()) + junc_points_pred = np.concatenate([junctions[0][..., None], + junctions[1][..., None]], axis=-1) + + # Get the heatmap predictions + heatmap_pred = input_data["heatmap_prob_mean"].squeeze() + valid_mask = np.ones(heatmap_pred.shape, dtype=np.int32) + + # Single predictions + else: + # Get the junction point predictions and convert to Nx2 format + junc_points_pred = np.where(input_data["junc_pred_nms"]) + junc_points_pred = np.concatenate( + [junc_points_pred[0][..., None], + junc_points_pred[1][..., None]], axis=-1) + + # Get the heatmap predictions + heatmap_pred = input_data["heatmap_pred"] + valid_mask = input_data["valid_mask"] + + return { + "junctions_pred": junc_points_pred, + "heatmap_pred": heatmap_pred, + "valid_mask": valid_mask + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input_dataset", type=str, + help="Name of the exported dataset.") + parser.add_argument("output_dataset", type=str, + help="Name of the output dataset.") + parser.add_argument("config", type=str, + help="Path to the model config.") + args = parser.parse_args() + + # Define the path to the input exported dataset + exported_dataset_path = os.path.join(cfg.export_dataroot, + args.input_dataset) + if not os.path.exists(exported_dataset_path): + raise ValueError("Missing input dataset: " + exported_dataset_path) + exported_dataset = h5py.File(exported_dataset_path, "r") + + # Define the output path for the results + output_dataset_path = os.path.join(cfg.export_dataroot, + args.output_dataset) + + device = torch.device("cuda") + nms_device = torch.device("cuda") + + # Read the config file + if not os.path.exists(args.config): + raise ValueError("Missing config file: " + args.config) + with open(args.config, "r") as f: + config = yaml.safe_load(f) + model_cfg = config["model_cfg"] + line_detector_cfg = config["line_detector_cfg"] + + # Initialize the line detection module + line_detector = LineSegmentDetectionModule(**line_detector_cfg) + + # Iterate through all the dataset keys + with h5py.File(output_dataset_path, "w") as output_dataset: + for idx, output_key in enumerate(tqdm(list(exported_dataset.keys()), + ascii=True)): + # Get the data + data = parse_h5_data(exported_dataset[output_key]) + + # Preprocess the data + converted_data = convert_raw_exported_predictions( + data, grid_size=model_cfg["grid_size"], + detect_thresh=model_cfg["detection_thresh"]) + junctions_pred_raw = converted_data["junctions_pred"] + heatmap_pred = converted_data["heatmap_pred"] + valid_mask = converted_data["valid_mask"] + + line_map_pred, junctions_pred, heatmap_pred = line_detector.detect( + junctions_pred_raw, heatmap_pred, device=device) + if isinstance(line_map_pred, torch.Tensor): + line_map_pred = line_map_pred.cpu().numpy() + if isinstance(junctions_pred, torch.Tensor): + junctions_pred = junctions_pred.cpu().numpy() + if isinstance(heatmap_pred, torch.Tensor): + heatmap_pred = heatmap_pred.cpu().numpy() + + output_data = {"junctions": junctions_pred, + "line_map": line_map_pred} + + # Record it to the h5 dataset + f_group = output_dataset.create_group(output_key) + + # Store data + for key, output_data in output_data.items(): + f_group.create_dataset(key, data=output_data, + compression="gzip") diff --git a/imcui/third_party/SOLD2/sold2/train.py b/imcui/third_party/SOLD2/sold2/train.py new file mode 100644 index 
0000000000000000000000000000000000000000..2064e00e6d192f9202f011c3626d6f53c4fe6270 --- /dev/null +++ b/imcui/third_party/SOLD2/sold2/train.py @@ -0,0 +1,752 @@ +""" +This file implements the training process and all the summaries +""" +import os +import numpy as np +import cv2 +import torch +from torch.nn.functional import pixel_shuffle, softmax +from torch.utils.data import DataLoader +import torch.utils.data.dataloader as torch_loader +from tensorboardX import SummaryWriter + +from .dataset.dataset_util import get_dataset +from .model.model_util import get_model +from .model.loss import TotalLoss, get_loss_and_weights +from .model.metrics import AverageMeter, Metrics, super_nms +from .model.lr_scheduler import get_lr_scheduler +from .misc.train_utils import (convert_image, get_latest_checkpoint, + remove_old_checkpoints) + + +def customized_collate_fn(batch): + """ Customized collate_fn. """ + batch_keys = ["image", "junction_map", "heatmap", "valid_mask"] + list_keys = ["junctions", "line_map"] + + outputs = {} + for key in batch_keys: + outputs[key] = torch_loader.default_collate([b[key] for b in batch]) + for key in list_keys: + outputs[key] = [b[key] for b in batch] + + return outputs + + +def restore_weights(model, state_dict, strict=True): + """ Restore weights in compatible mode. """ + # Try to directly load state dict + try: + model.load_state_dict(state_dict, strict=strict) + # Deal with some version compatibility issue (catch version incompatible) + except: + err = model.load_state_dict(state_dict, strict=False) + + # missing keys are those in model but not in state_dict + missing_keys = err.missing_keys + # Unexpected keys are those in state_dict but not in model + unexpected_keys = err.unexpected_keys + + # Load mismatched keys manually + model_dict = model.state_dict() + for idx, key in enumerate(missing_keys): + dict_keys = [_ for _ in unexpected_keys if not "tracked" in _] + model_dict[key] = state_dict[dict_keys[idx]] + model.load_state_dict(model_dict) + + return model + + +def train_net(args, dataset_cfg, model_cfg, output_path): + """ Main training function. """ + # Add some version compatibility check + if model_cfg.get("weighting_policy") is None: + # Default to static + model_cfg["weighting_policy"] = "static" + + # Get the train, val, test config + train_cfg = model_cfg["train"] + test_cfg = model_cfg["test"] + + # Create train and test dataset + print("\t Initializing dataset...") + train_dataset, train_collate_fn = get_dataset("train", dataset_cfg) + test_dataset, test_collate_fn = get_dataset("test", dataset_cfg) + + # Create the dataloader + train_loader = DataLoader(train_dataset, + batch_size=train_cfg["batch_size"], + num_workers=8, + shuffle=True, pin_memory=True, + collate_fn=train_collate_fn) + test_loader = DataLoader(test_dataset, + batch_size=test_cfg.get("batch_size", 1), + num_workers=test_cfg.get("num_workers", 1), + shuffle=False, pin_memory=False, + collate_fn=test_collate_fn) + print("\t Successfully intialized dataloaders.") + + + # Get the loss function and weight first + loss_funcs, loss_weights = get_loss_and_weights(model_cfg) + + # If resume. 
+ if args.resume: + # Create model and load the state dict + checkpoint = get_latest_checkpoint(args.resume_path, + args.checkpoint_name) + model = get_model(model_cfg, loss_weights) + model = restore_weights(model, checkpoint["model_state_dict"]) + model = model.cuda() + optimizer = torch.optim.Adam( + [{"params": model.parameters(), + "initial_lr": model_cfg["learning_rate"]}], + model_cfg["learning_rate"], + amsgrad=True) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + # Optionally get the learning rate scheduler + scheduler = get_lr_scheduler( + lr_decay=model_cfg.get("lr_decay", False), + lr_decay_cfg=model_cfg.get("lr_decay_cfg", None), + optimizer=optimizer) + # If we start to use learning rate scheduler from the middle + if ((scheduler is not None) + and (checkpoint.get("scheduler_state_dict", None) is not None)): + scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + start_epoch = checkpoint["epoch"] + 1 + # Initialize all the components. + else: + # Create model and optimizer + model = get_model(model_cfg, loss_weights) + # Optionally get the pretrained wieghts + if args.pretrained: + print("\t [Debug] Loading pretrained weights...") + checkpoint = get_latest_checkpoint(args.pretrained_path, + args.checkpoint_name) + # If auto weighting restore from non-auto weighting + model = restore_weights(model, checkpoint["model_state_dict"], + strict=False) + print("\t [Debug] Finished loading pretrained weights!") + + model = model.cuda() + optimizer = torch.optim.Adam( + [{"params": model.parameters(), + "initial_lr": model_cfg["learning_rate"]}], + model_cfg["learning_rate"], + amsgrad=True) + # Optionally get the learning rate scheduler + scheduler = get_lr_scheduler( + lr_decay=model_cfg.get("lr_decay", False), + lr_decay_cfg=model_cfg.get("lr_decay_cfg", None), + optimizer=optimizer) + start_epoch = 0 + + print("\t Successfully initialized model") + + # Define the total loss + policy = model_cfg.get("weighting_policy", "static") + loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda() + if "descriptor_decoder" in model_cfg: + metric_func = Metrics(model_cfg["detection_thresh"], + model_cfg["prob_thresh"], + model_cfg["descriptor_loss_cfg"]["grid_size"], + desc_metric_lst='all') + else: + metric_func = Metrics(model_cfg["detection_thresh"], + model_cfg["prob_thresh"], + model_cfg["grid_size"]) + + # Define the summary writer + logdir = os.path.join(output_path, "log") + writer = SummaryWriter(logdir=logdir) + + # Start the training loop + for epoch in range(start_epoch, model_cfg["epochs"]): + # Record the learning rate + current_lr = optimizer.state_dict()["param_groups"][0]["lr"] + writer.add_scalar("LR/lr", current_lr, epoch) + + # Train for one epochs + print("\n\n================== Training ====================") + train_single_epoch( + model=model, + model_cfg=model_cfg, + optimizer=optimizer, + loss_func=loss_func, + metric_func=metric_func, + train_loader=train_loader, + writer=writer, + epoch=epoch) + + # Do the validation + print("\n\n================== Validation ==================") + validate( + model=model, + model_cfg=model_cfg, + loss_func=loss_func, + metric_func=metric_func, + val_loader=test_loader, + writer=writer, + epoch=epoch) + + # Update the scheduler + if scheduler is not None: + scheduler.step() + + # Save checkpoints + file_name = os.path.join(output_path, + "checkpoint-epoch%03d-end.tar"%(epoch)) + print("[Info] Saving checkpoint %s ..." 
% file_name) + save_dict = { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "model_cfg": model_cfg} + if scheduler is not None: + save_dict.update({"scheduler_state_dict": scheduler.state_dict()}) + torch.save(save_dict, file_name) + + # Remove the outdated checkpoints + remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15)) + + +def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func, + train_loader, writer, epoch): + """ Train for one epoch. """ + # Switch the model to training mode + model.train() + + # Initialize the average meter + compute_descriptors = loss_func.compute_descriptors + if compute_descriptors: + average_meter = AverageMeter(is_training=True, desc_metric_lst='all') + else: + average_meter = AverageMeter(is_training=True) + + # The training loop + for idx, data in enumerate(train_loader): + if compute_descriptors: + junc_map = data["ref_junction_map"].cuda() + junc_map2 = data["target_junction_map"].cuda() + heatmap = data["ref_heatmap"].cuda() + heatmap2 = data["target_heatmap"].cuda() + line_points = data["ref_line_points"].cuda() + line_points2 = data["target_line_points"].cuda() + line_indices = data["ref_line_indices"].cuda() + valid_mask = data["ref_valid_mask"].cuda() + valid_mask2 = data["target_valid_mask"].cuda() + input_images = data["ref_image"].cuda() + input_images2 = data["target_image"].cuda() + + # Run the forward pass + outputs = model(input_images) + outputs2 = model(input_images2) + + # Compute losses + losses = loss_func.forward_descriptors( + outputs["junctions"], outputs2["junctions"], + junc_map, junc_map2, outputs["heatmap"], outputs2["heatmap"], + heatmap, heatmap2, line_points, line_points2, + line_indices, outputs['descriptors'], outputs2['descriptors'], + epoch, valid_mask, valid_mask2) + else: + junc_map = data["junction_map"].cuda() + heatmap = data["heatmap"].cuda() + valid_mask = data["valid_mask"].cuda() + input_images = data["image"].cuda() + + # Run the forward pass + outputs = model(input_images) + + # Compute losses + losses = loss_func( + outputs["junctions"], junc_map, + outputs["heatmap"], heatmap, + valid_mask) + + total_loss = losses["total_loss"] + + # Update the model + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + # Compute the global step + global_step = epoch * len(train_loader) + idx + ############## Measure the metric error ######################### + # Only do this when needed + if (((idx % model_cfg["disp_freq"]) == 0) + or ((idx % model_cfg["summary_freq"]) == 0)): + junc_np = convert_junc_predictions( + outputs["junctions"], model_cfg["grid_size"], + model_cfg["detection_thresh"], 300) + junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1) + + # Always fetch only one channel (compatible with L1, L2, and CE) + if outputs["heatmap"].shape[1] == 2: + heatmap_np = softmax(outputs["heatmap"].detach(), + dim=1).cpu().numpy() + heatmap_np = heatmap_np.transpose(0, 2, 3, 1)[:, :, :, 1:] + else: + heatmap_np = torch.sigmoid(outputs["heatmap"].detach()) + heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1) + + heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1) + valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1) + + # Evaluate metric results + if compute_descriptors: + metric_func.evaluate( + junc_np["junc_pred"], junc_np["junc_pred_nms"], + junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np, + line_points, line_points2, outputs["descriptors"], + outputs2["descriptors"], 
line_indices) + else: + metric_func.evaluate( + junc_np["junc_pred"], junc_np["junc_pred_nms"], + junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np) + # Update average meter + junc_loss = losses["junc_loss"].item() + heatmap_loss = losses["heatmap_loss"].item() + loss_dict = { + "junc_loss": junc_loss, + "heatmap_loss": heatmap_loss, + "total_loss": total_loss.item()} + if compute_descriptors: + descriptor_loss = losses["descriptor_loss"].item() + loss_dict["descriptor_loss"] = losses["descriptor_loss"].item() + + average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0]) + + # Display the progress + if (idx % model_cfg["disp_freq"]) == 0: + results = metric_func.metric_results + average = average_meter.average() + # Get gpu memory usage in GB + gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) + if compute_descriptors: + print("Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB" + % (epoch, model_cfg["epochs"], idx, len(train_loader), + total_loss.item(), average["total_loss"], junc_loss, + average["junc_loss"], heatmap_loss, + average["heatmap_loss"], descriptor_loss, + average["descriptor_loss"], gpu_mem_usage)) + else: + print("Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB" + % (epoch, model_cfg["epochs"], idx, len(train_loader), + total_loss.item(), average["total_loss"], + junc_loss, average["junc_loss"], heatmap_loss, + average["heatmap_loss"], gpu_mem_usage)) + print("\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % (results["junc_precision"], average["junc_precision"], + results["junc_recall"], average["junc_recall"])) + print("\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % (results["junc_precision_nms"], + average["junc_precision_nms"], + results["junc_recall_nms"], average["junc_recall_nms"])) + print("\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)" + %(results["heatmap_precision"], + average["heatmap_precision"], + results["heatmap_recall"], average["heatmap_recall"])) + if compute_descriptors: + print("\t Descriptors matching score=%.4f (%.4f)" + %(results["matching_score"], average["matching_score"])) + + # Record summaries + if (idx % model_cfg["summary_freq"]) == 0: + results = metric_func.metric_results + average = average_meter.average() + # Add the shared losses + scalar_summaries = { + "junc_loss": junc_loss, + "heatmap_loss": heatmap_loss, + "total_loss": total_loss.detach().cpu().numpy(), + "metrics": results, + "average": average} + # Add descriptor terms + if compute_descriptors: + scalar_summaries["descriptor_loss"] = descriptor_loss + scalar_summaries["w_desc"] = losses["w_desc"] + + # Add weighting terms (even for static terms) + scalar_summaries["w_junc"] = losses["w_junc"] + scalar_summaries["w_heatmap"] = losses["w_heatmap"] + scalar_summaries["reg_loss"] = losses["reg_loss"].item() + + num_images = 3 + junc_pred_binary = (junc_np["junc_pred"][:num_images, ...] + > model_cfg["detection_thresh"]) + junc_pred_nms_binary = (junc_np["junc_pred_nms"][:num_images, ...] 
+ > model_cfg["detection_thresh"]) + image_summaries = { + "image": input_images.cpu().numpy()[:num_images, ...], + "valid_mask": valid_mask_np[:num_images, ...], + "junc_map_pred": junc_pred_binary, + "junc_map_pred_nms": junc_pred_nms_binary, + "junc_map_gt": junc_map_np[:num_images, ...], + "junc_prob_map": junc_np["junc_prob"][:num_images, ...], + "heatmap_pred": heatmap_np[:num_images, ...], + "heatmap_gt": heatmap_gt_np[:num_images, ...]} + # Record the training summary + record_train_summaries( + writer, global_step, scalars=scalar_summaries, + images=image_summaries) + + +def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch): + """ Validation. """ + # Switch the model to eval mode + model.eval() + + # Initialize the average meter + compute_descriptors = loss_func.compute_descriptors + if compute_descriptors: + average_meter = AverageMeter(is_training=True, desc_metric_lst='all') + else: + average_meter = AverageMeter(is_training=True) + + # The validation loop + for idx, data in enumerate(val_loader): + if compute_descriptors: + junc_map = data["ref_junction_map"].cuda() + junc_map2 = data["target_junction_map"].cuda() + heatmap = data["ref_heatmap"].cuda() + heatmap2 = data["target_heatmap"].cuda() + line_points = data["ref_line_points"].cuda() + line_points2 = data["target_line_points"].cuda() + line_indices = data["ref_line_indices"].cuda() + valid_mask = data["ref_valid_mask"].cuda() + valid_mask2 = data["target_valid_mask"].cuda() + input_images = data["ref_image"].cuda() + input_images2 = data["target_image"].cuda() + + # Run the forward pass + with torch.no_grad(): + outputs = model(input_images) + outputs2 = model(input_images2) + + # Compute losses + losses = loss_func.forward_descriptors( + outputs["junctions"], outputs2["junctions"], + junc_map, junc_map2, outputs["heatmap"], + outputs2["heatmap"], heatmap, heatmap2, line_points, + line_points2, line_indices, outputs['descriptors'], + outputs2['descriptors'], epoch, valid_mask, valid_mask2) + else: + junc_map = data["junction_map"].cuda() + heatmap = data["heatmap"].cuda() + valid_mask = data["valid_mask"].cuda() + input_images = data["image"].cuda() + + # Run the forward pass + with torch.no_grad(): + outputs = model(input_images) + + # Compute losses + losses = loss_func( + outputs["junctions"], junc_map, + outputs["heatmap"], heatmap, + valid_mask) + total_loss = losses["total_loss"] + + ############## Measure the metric error ######################### + junc_np = convert_junc_predictions( + outputs["junctions"], model_cfg["grid_size"], + model_cfg["detection_thresh"], 300) + junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1) + # Always fetch only one channel (compatible with L1, L2, and CE) + if outputs["heatmap"].shape[1] == 2: + heatmap_np = softmax(outputs["heatmap"].detach(), + dim=1).cpu().numpy().transpose(0, 2, 3, 1) + heatmap_np = heatmap_np[:, :, :, 1:] + else: + heatmap_np = torch.sigmoid(outputs["heatmap"].detach()) + heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1) + + + heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1) + valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1) + + # Evaluate metric results + if compute_descriptors: + metric_func.evaluate( + junc_np["junc_pred"], junc_np["junc_pred_nms"], + junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np, + line_points, line_points2, outputs["descriptors"], + outputs2["descriptors"], line_indices) + else: + metric_func.evaluate( + junc_np["junc_pred"], junc_np["junc_pred_nms"], 
junc_map_np, + heatmap_np, heatmap_gt_np, valid_mask_np) + # Update average meter + junc_loss = losses["junc_loss"].item() + heatmap_loss = losses["heatmap_loss"].item() + loss_dict = { + "junc_loss": junc_loss, + "heatmap_loss": heatmap_loss, + "total_loss": total_loss.item()} + if compute_descriptors: + descriptor_loss = losses["descriptor_loss"].item() + loss_dict["descriptor_loss"] = losses["descriptor_loss"].item() + average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0]) + + # Display the progress + if (idx % model_cfg["disp_freq"]) == 0: + results = metric_func.metric_results + average = average_meter.average() + if compute_descriptors: + print("Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)" + % (idx, len(val_loader), + total_loss.item(), average["total_loss"], + junc_loss, average["junc_loss"], + heatmap_loss, average["heatmap_loss"], + descriptor_loss, average["descriptor_loss"])) + else: + print("Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)" + % (idx, len(val_loader), + total_loss.item(), average["total_loss"], + junc_loss, average["junc_loss"], + heatmap_loss, average["heatmap_loss"])) + print("\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % (results["junc_precision"], average["junc_precision"], + results["junc_recall"], average["junc_recall"])) + print("\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % (results["junc_precision_nms"], + average["junc_precision_nms"], + results["junc_recall_nms"], average["junc_recall_nms"])) + print("\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)" + % (results["heatmap_precision"], + average["heatmap_precision"], + results["heatmap_recall"], average["heatmap_recall"])) + if compute_descriptors: + print("\t Descriptors matching score=%.4f (%.4f)" + %(results["matching_score"], average["matching_score"])) + + # Record summaries + average = average_meter.average() + scalar_summaries = {"average": average} + # Record the training summary + record_test_summaries(writer, epoch, scalar_summaries) + + +def convert_junc_predictions(predictions, grid_size, + detect_thresh=1/65, topk=300): + """ Convert torch predictions to numpy arrays for evaluation. """ + # Convert to probability outputs first + junc_prob = softmax(predictions.detach(), dim=1).cpu() + junc_pred = junc_prob[:, :-1, :, :] + + junc_prob_np = junc_prob.numpy().transpose(0, 2, 3, 1)[:, :, :, :-1] + junc_prob_np = np.sum(junc_prob_np, axis=-1) + junc_pred_np = pixel_shuffle( + junc_pred, grid_size).cpu().numpy().transpose(0, 2, 3, 1) + junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk) + junc_pred_np = junc_pred_np.squeeze(-1) + + return {"junc_pred": junc_pred_np, "junc_pred_nms": junc_pred_np_nms, + "junc_prob": junc_prob_np} + + +def record_train_summaries(writer, global_step, scalars, images): + """ Record training summaries. 
""" + # Record the scalar summaries + results = scalars["metrics"] + average = scalars["average"] + + # GPU memory part + # Get gpu memory usage in GB + gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) + writer.add_scalar("GPU/GPU_memory_usage", gpu_mem_usage, global_step) + + # Loss part + writer.add_scalar("Train_loss/junc_loss", scalars["junc_loss"], + global_step) + writer.add_scalar("Train_loss/heatmap_loss", scalars["heatmap_loss"], + global_step) + writer.add_scalar("Train_loss/total_loss", scalars["total_loss"], + global_step) + # Add regularization loss + if "reg_loss" in scalars.keys(): + writer.add_scalar("Train_loss/reg_loss", scalars["reg_loss"], + global_step) + # Add descriptor loss + if "descriptor_loss" in scalars.keys(): + key = "descriptor_loss" + writer.add_scalar("Train_loss/%s"%(key), scalars[key], global_step) + writer.add_scalar("Train_loss_average/%s"%(key), average[key], + global_step) + + # Record weighting + for key in scalars.keys(): + if "w_" in key: + writer.add_scalar("Train_weight/%s"%(key), scalars[key], + global_step) + + # Smoothed loss + writer.add_scalar("Train_loss_average/junc_loss", average["junc_loss"], + global_step) + writer.add_scalar("Train_loss_average/heatmap_loss", + average["heatmap_loss"], global_step) + writer.add_scalar("Train_loss_average/total_loss", average["total_loss"], + global_step) + # Add smoothed descriptor loss + if "descriptor_loss" in average.keys(): + writer.add_scalar("Train_loss_average/descriptor_loss", + average["descriptor_loss"], global_step) + + # Metrics part + writer.add_scalar("Train_metrics/junc_precision", + results["junc_precision"], global_step) + writer.add_scalar("Train_metrics/junc_precision_nms", + results["junc_precision_nms"], global_step) + writer.add_scalar("Train_metrics/junc_recall", + results["junc_recall"], global_step) + writer.add_scalar("Train_metrics/junc_recall_nms", + results["junc_recall_nms"], global_step) + writer.add_scalar("Train_metrics/heatmap_precision", + results["heatmap_precision"], global_step) + writer.add_scalar("Train_metrics/heatmap_recall", + results["heatmap_recall"], global_step) + # Add descriptor metric + if "matching_score" in results.keys(): + writer.add_scalar("Train_metrics/matching_score", + results["matching_score"], global_step) + + # Average part + writer.add_scalar("Train_metrics_average/junc_precision", + average["junc_precision"], global_step) + writer.add_scalar("Train_metrics_average/junc_precision_nms", + average["junc_precision_nms"], global_step) + writer.add_scalar("Train_metrics_average/junc_recall", + average["junc_recall"], global_step) + writer.add_scalar("Train_metrics_average/junc_recall_nms", + average["junc_recall_nms"], global_step) + writer.add_scalar("Train_metrics_average/heatmap_precision", + average["heatmap_precision"], global_step) + writer.add_scalar("Train_metrics_average/heatmap_recall", + average["heatmap_recall"], global_step) + # Add smoothed descriptor metric + if "matching_score" in average.keys(): + writer.add_scalar("Train_metrics_average/matching_score", + average["matching_score"], global_step) + + # Record the image summary + # Image part + image_tensor = convert_image(images["image"], 1) + valid_masks = convert_image(images["valid_mask"], -1) + writer.add_images("Train/images", image_tensor, global_step, + dataformats="NCHW") + writer.add_images("Train/valid_map", valid_masks, global_step, + dataformats="NHWC") + + # Heatmap part + writer.add_images("Train/heatmap_gt", + convert_image(images["heatmap_gt"], -1), 
global_step, + dataformats="NHWC") + writer.add_images("Train/heatmap_pred", + convert_image(images["heatmap_pred"], -1), global_step, + dataformats="NHWC") + + # Junction prediction part + junc_plots = plot_junction_detection( + image_tensor, images["junc_map_pred"], + images["junc_map_pred_nms"], images["junc_map_gt"]) + writer.add_images("Train/junc_gt", junc_plots["junc_gt_plot"] / 255., + global_step, dataformats="NHWC") + writer.add_images("Train/junc_pred", junc_plots["junc_pred_plot"] / 255., + global_step, dataformats="NHWC") + writer.add_images("Train/junc_pred_nms", + junc_plots["junc_pred_nms_plot"] / 255., global_step, + dataformats="NHWC") + writer.add_images( + "Train/junc_prob_map", + convert_image(images["junc_prob_map"][..., None], axis=-1), + global_step, dataformats="NHWC") + + +def record_test_summaries(writer, epoch, scalars): + """ Record testing summaries. """ + average = scalars["average"] + + # Average loss + writer.add_scalar("Val_loss/junc_loss", average["junc_loss"], epoch) + writer.add_scalar("Val_loss/heatmap_loss", average["heatmap_loss"], epoch) + writer.add_scalar("Val_loss/total_loss", average["total_loss"], epoch) + # Add descriptor loss + if "descriptor_loss" in average.keys(): + key = "descriptor_loss" + writer.add_scalar("Val_loss/%s"%(key), average[key], epoch) + + # Average metrics + writer.add_scalar("Val_metrics/junc_precision", average["junc_precision"], + epoch) + writer.add_scalar("Val_metrics/junc_precision_nms", + average["junc_precision_nms"], epoch) + writer.add_scalar("Val_metrics/junc_recall", + average["junc_recall"], epoch) + writer.add_scalar("Val_metrics/junc_recall_nms", + average["junc_recall_nms"], epoch) + writer.add_scalar("Val_metrics/heatmap_precision", + average["heatmap_precision"], epoch) + writer.add_scalar("Val_metrics/heatmap_recall", + average["heatmap_recall"], epoch) + # Add descriptor metric + if "matching_score" in average.keys(): + writer.add_scalar("Val_metrics/matching_score", + average["matching_score"], epoch) + + +def plot_junction_detection(image_tensor, junc_pred_tensor, + junc_pred_nms_tensor, junc_gt_tensor): + """ Plot the junction points on images. """ + # Get the batch_size + batch_size = image_tensor.shape[0] + + # Process through batch dimension + junc_pred_lst = [] + junc_pred_nms_lst = [] + junc_gt_lst = [] + for i in range(batch_size): + # Convert image to 255 uint8 + image = (image_tensor[i, :, :, :] + * 255.).astype(np.uint8).transpose(1,2,0) + + # Plot groundtruth onto image + junc_gt = junc_gt_tensor[i, ...] + coord_gt = np.where(junc_gt.squeeze() > 0) + points_gt = np.concatenate((coord_gt[0][..., None], + coord_gt[1][..., None]), + axis=1) + plot_gt = image.copy() + for id in range(points_gt.shape[0]): + cv2.circle(plot_gt, tuple(np.flip(points_gt[id, :])), 3, + color=(255, 0, 0), thickness=2) + junc_gt_lst.append(plot_gt[None, ...]) + + # Plot junc_pred + junc_pred = junc_pred_tensor[i, ...] + coord_pred = np.where(junc_pred > 0) + points_pred = np.concatenate((coord_pred[0][..., None], + coord_pred[1][..., None]), + axis=1) + plot_pred = image.copy() + for id in range(points_pred.shape[0]): + cv2.circle(plot_pred, tuple(np.flip(points_pred[id, :])), 3, + color=(0, 255, 0), thickness=2) + junc_pred_lst.append(plot_pred[None, ...]) + + # Plot junc_pred_nms + junc_pred_nms = junc_pred_nms_tensor[i, ...] 
+ coord_pred_nms = np.where(junc_pred_nms > 0) + points_pred_nms = np.concatenate((coord_pred_nms[0][..., None], + coord_pred_nms[1][..., None]), + axis=1) + plot_pred_nms = image.copy() + for id in range(points_pred_nms.shape[0]): + cv2.circle(plot_pred_nms, tuple(np.flip(points_pred_nms[id, :])), + 3, color=(0, 255, 0), thickness=2) + junc_pred_nms_lst.append(plot_pred_nms[None, ...]) + + return {"junc_gt_plot": np.concatenate(junc_gt_lst, axis=0), + "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0), + "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0)} diff --git a/imcui/third_party/SuperGluePretrainedNetwork/demo_superglue.py b/imcui/third_party/SuperGluePretrainedNetwork/demo_superglue.py new file mode 100644 index 0000000000000000000000000000000000000000..32d4ad3c7df1b7da141c4c6aa51f871a7d756aaf --- /dev/null +++ b/imcui/third_party/SuperGluePretrainedNetwork/demo_superglue.py @@ -0,0 +1,259 @@ +#! /usr/bin/env python3 +# +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
+# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# Daniel DeTone +# Tomasz Malisiewicz +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +from pathlib import Path +import argparse +import cv2 +import matplotlib.cm as cm +import torch + +from models.matching import Matching +from models.utils import (AverageTimer, VideoStreamer, + make_matching_plot_fast, frame2tensor) + +torch.set_grad_enabled(False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='SuperGlue demo', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--input', type=str, default='0', + help='ID of a USB webcam, URL of an IP camera, ' + 'or path to an image directory or movie file') + parser.add_argument( + '--output_dir', type=str, default=None, + help='Directory where to write output frames (If None, no output)') + + parser.add_argument( + '--image_glob', type=str, nargs='+', default=['*.png', '*.jpg', '*.jpeg'], + help='Glob if a directory of images is specified') + parser.add_argument( + '--skip', type=int, default=1, + help='Images to skip if input is a movie or directory') + parser.add_argument( + '--max_length', type=int, default=1000000, + help='Maximum length if input is a movie or directory') + parser.add_argument( + '--resize', type=int, nargs='+', default=[640, 480], + help='Resize the input image before running inference. If two numbers, ' + 'resize to the exact dimensions, if one number, resize the max ' + 'dimension, if -1, do not resize') + + parser.add_argument( + '--superglue', choices={'indoor', 'outdoor'}, default='indoor', + help='SuperGlue weights') + parser.add_argument( + '--max_keypoints', type=int, default=-1, + help='Maximum number of keypoints detected by Superpoint' + ' (\'-1\' keeps all keypoints)') + parser.add_argument( + '--keypoint_threshold', type=float, default=0.005, + help='SuperPoint keypoint detector confidence threshold') + parser.add_argument( + '--nms_radius', type=int, default=4, + help='SuperPoint Non Maximum Suppression (NMS) radius' + ' (Must be positive)') + parser.add_argument( + '--sinkhorn_iterations', type=int, default=20, + help='Number of Sinkhorn iterations performed by SuperGlue') + parser.add_argument( + '--match_threshold', type=float, default=0.2, + help='SuperGlue match threshold') + + parser.add_argument( + '--show_keypoints', action='store_true', + help='Show the detected keypoints') + parser.add_argument( + '--no_display', action='store_true', + help='Do not display images to screen. 
Useful if running remotely') + parser.add_argument( + '--force_cpu', action='store_true', + help='Force pytorch to run in CPU mode.') + + opt = parser.parse_args() + print(opt) + + if len(opt.resize) == 2 and opt.resize[1] == -1: + opt.resize = opt.resize[0:1] + if len(opt.resize) == 2: + print('Will resize to {}x{} (WxH)'.format( + opt.resize[0], opt.resize[1])) + elif len(opt.resize) == 1 and opt.resize[0] > 0: + print('Will resize max dimension to {}'.format(opt.resize[0])) + elif len(opt.resize) == 1: + print('Will not resize images') + else: + raise ValueError('Cannot specify more than two integers for --resize') + + device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu' + print('Running inference on device \"{}\"'.format(device)) + config = { + 'superpoint': { + 'nms_radius': opt.nms_radius, + 'keypoint_threshold': opt.keypoint_threshold, + 'max_keypoints': opt.max_keypoints + }, + 'superglue': { + 'weights': opt.superglue, + 'sinkhorn_iterations': opt.sinkhorn_iterations, + 'match_threshold': opt.match_threshold, + } + } + matching = Matching(config).eval().to(device) + keys = ['keypoints', 'scores', 'descriptors'] + + vs = VideoStreamer(opt.input, opt.resize, opt.skip, + opt.image_glob, opt.max_length) + frame, ret = vs.next_frame() + assert ret, 'Error when reading the first frame (try different --input?)' + + frame_tensor = frame2tensor(frame, device) + last_data = matching.superpoint({'image': frame_tensor}) + last_data = {k+'0': last_data[k] for k in keys} + last_data['image0'] = frame_tensor + last_frame = frame + last_image_id = 0 + + if opt.output_dir is not None: + print('==> Will write outputs to {}'.format(opt.output_dir)) + Path(opt.output_dir).mkdir(exist_ok=True) + + # Create a window to display the demo. + if not opt.no_display: + cv2.namedWindow('SuperGlue matches', cv2.WINDOW_NORMAL) + cv2.resizeWindow('SuperGlue matches', 640*2, 480) + else: + print('Skipping visualization, will not show a GUI.') + + # Print the keyboard help menu. 
+ print('==> Keyboard control:\n' + '\tn: select the current frame as the anchor\n' + '\te/r: increase/decrease the keypoint confidence threshold\n' + '\td/f: increase/decrease the match filtering threshold\n' + '\tk: toggle the visualization of keypoints\n' + '\tq: quit') + + timer = AverageTimer() + + while True: + frame, ret = vs.next_frame() + if not ret: + print('Finished demo_superglue.py') + break + timer.update('data') + stem0, stem1 = last_image_id, vs.i - 1 + + frame_tensor = frame2tensor(frame, device) + pred = matching({**last_data, 'image1': frame_tensor}) + kpts0 = last_data['keypoints0'][0].cpu().numpy() + kpts1 = pred['keypoints1'][0].cpu().numpy() + matches = pred['matches0'][0].cpu().numpy() + confidence = pred['matching_scores0'][0].cpu().numpy() + timer.update('forward') + + valid = matches > -1 + mkpts0 = kpts0[valid] + mkpts1 = kpts1[matches[valid]] + color = cm.jet(confidence[valid]) + text = [ + 'SuperGlue', + 'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)), + 'Matches: {}'.format(len(mkpts0)) + ] + k_thresh = matching.superpoint.config['keypoint_threshold'] + m_thresh = matching.superglue.config['match_threshold'] + small_text = [ + 'Keypoint Threshold: {:.4f}'.format(k_thresh), + 'Match Threshold: {:.2f}'.format(m_thresh), + 'Image Pair: {:06}:{:06}'.format(stem0, stem1), + ] + out = make_matching_plot_fast( + last_frame, frame, kpts0, kpts1, mkpts0, mkpts1, color, text, + path=None, show_keypoints=opt.show_keypoints, small_text=small_text) + + if not opt.no_display: + cv2.imshow('SuperGlue matches', out) + key = chr(cv2.waitKey(1) & 0xFF) + if key == 'q': + vs.cleanup() + print('Exiting (via q) demo_superglue.py') + break + elif key == 'n': # set the current frame as anchor + last_data = {k+'0': pred[k+'1'] for k in keys} + last_data['image0'] = frame_tensor + last_frame = frame + last_image_id = (vs.i - 1) + elif key in ['e', 'r']: + # Increase/decrease keypoint threshold by 10% each keypress. + d = 0.1 * (-1 if key == 'e' else 1) + matching.superpoint.config['keypoint_threshold'] = min(max( + 0.0001, matching.superpoint.config['keypoint_threshold']*(1+d)), 1) + print('\nChanged the keypoint threshold to {:.4f}'.format( + matching.superpoint.config['keypoint_threshold'])) + elif key in ['d', 'f']: + # Increase/decrease match threshold by 0.05 each keypress. + d = 0.05 * (-1 if key == 'd' else 1) + matching.superglue.config['match_threshold'] = min(max( + 0.05, matching.superglue.config['match_threshold']+d), .95) + print('\nChanged the match threshold to {:.2f}'.format( + matching.superglue.config['match_threshold'])) + elif key == 'k': + opt.show_keypoints = not opt.show_keypoints + + timer.update('viz') + timer.print() + + if opt.output_dir is not None: + #stem = 'matches_{:06}_{:06}'.format(last_image_id, vs.i-1) + stem = 'matches_{:06}_{:06}'.format(stem0, stem1) + out_file = str(Path(opt.output_dir, stem + '.png')) + print('\nWriting image to {}'.format(out_file)) + cv2.imwrite(out_file, out) + + cv2.destroyAllWindows() + vs.cleanup() diff --git a/imcui/third_party/SuperGluePretrainedNetwork/match_pairs.py b/imcui/third_party/SuperGluePretrainedNetwork/match_pairs.py new file mode 100644 index 0000000000000000000000000000000000000000..7079687cf69fd71d810ec80442548ad2a7b869e0 --- /dev/null +++ b/imcui/third_party/SuperGluePretrainedNetwork/match_pairs.py @@ -0,0 +1,425 @@ +#! /usr/bin/env python3 +# +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. 
("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. +# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# Daniel DeTone +# Tomasz Malisiewicz +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +from pathlib import Path +import argparse +import random +import numpy as np +import matplotlib.cm as cm +import torch + + +from models.matching import Matching +from models.utils import (compute_pose_error, compute_epipolar_error, + estimate_pose, make_matching_plot, + error_colormap, AverageTimer, pose_auc, read_image, + rotate_intrinsics, rotate_pose_inplane, + scale_intrinsics) + +torch.set_grad_enabled(False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Image pair matching and pose evaluation with SuperGlue', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument( + '--input_pairs', type=str, default='assets/scannet_sample_pairs_with_gt.txt', + help='Path to the list of image pairs') + parser.add_argument( + '--input_dir', type=str, default='assets/scannet_sample_images/', + help='Path to the directory that contains the images') + parser.add_argument( + '--output_dir', type=str, default='dump_match_pairs/', + help='Path to the directory in which the .npz results and optionally,' + 'the visualization images are written') + + parser.add_argument( + '--max_length', type=int, default=-1, + help='Maximum number of pairs to evaluate') + parser.add_argument( + '--resize', type=int, nargs='+', default=[640, 480], + help='Resize the input image before running inference. 
If two numbers, ' + 'resize to the exact dimensions, if one number, resize the max ' + 'dimension, if -1, do not resize') + parser.add_argument( + '--resize_float', action='store_true', + help='Resize the image after casting uint8 to float') + + parser.add_argument( + '--superglue', choices={'indoor', 'outdoor'}, default='indoor', + help='SuperGlue weights') + parser.add_argument( + '--max_keypoints', type=int, default=1024, + help='Maximum number of keypoints detected by Superpoint' + ' (\'-1\' keeps all keypoints)') + parser.add_argument( + '--keypoint_threshold', type=float, default=0.005, + help='SuperPoint keypoint detector confidence threshold') + parser.add_argument( + '--nms_radius', type=int, default=4, + help='SuperPoint Non Maximum Suppression (NMS) radius' + ' (Must be positive)') + parser.add_argument( + '--sinkhorn_iterations', type=int, default=20, + help='Number of Sinkhorn iterations performed by SuperGlue') + parser.add_argument( + '--match_threshold', type=float, default=0.2, + help='SuperGlue match threshold') + + parser.add_argument( + '--viz', action='store_true', + help='Visualize the matches and dump the plots') + parser.add_argument( + '--eval', action='store_true', + help='Perform the evaluation' + ' (requires ground truth pose and intrinsics)') + parser.add_argument( + '--fast_viz', action='store_true', + help='Use faster image visualization with OpenCV instead of Matplotlib') + parser.add_argument( + '--cache', action='store_true', + help='Skip the pair if output .npz files are already found') + parser.add_argument( + '--show_keypoints', action='store_true', + help='Plot the keypoints in addition to the matches') + parser.add_argument( + '--viz_extension', type=str, default='png', choices=['png', 'pdf'], + help='Visualization file extension. Use pdf for highest-quality.') + parser.add_argument( + '--opencv_display', action='store_true', + help='Visualize via OpenCV before saving output images') + parser.add_argument( + '--shuffle', action='store_true', + help='Shuffle ordering of pairs before processing') + parser.add_argument( + '--force_cpu', action='store_true', + help='Force pytorch to run in CPU mode.') + + opt = parser.parse_args() + print(opt) + + assert not (opt.opencv_display and not opt.viz), 'Must use --viz with --opencv_display' + assert not (opt.opencv_display and not opt.fast_viz), 'Cannot use --opencv_display without --fast_viz' + assert not (opt.fast_viz and not opt.viz), 'Must use --viz with --fast_viz' + assert not (opt.fast_viz and opt.viz_extension == 'pdf'), 'Cannot use pdf extension with --fast_viz' + + if len(opt.resize) == 2 and opt.resize[1] == -1: + opt.resize = opt.resize[0:1] + if len(opt.resize) == 2: + print('Will resize to {}x{} (WxH)'.format( + opt.resize[0], opt.resize[1])) + elif len(opt.resize) == 1 and opt.resize[0] > 0: + print('Will resize max dimension to {}'.format(opt.resize[0])) + elif len(opt.resize) == 1: + print('Will not resize images') + else: + raise ValueError('Cannot specify more than two integers for --resize') + + with open(opt.input_pairs, 'r') as f: + pairs = [l.split() for l in f.readlines()] + + if opt.max_length > -1: + pairs = pairs[0:np.min([len(pairs), opt.max_length])] + + if opt.shuffle: + random.Random(0).shuffle(pairs) + + if opt.eval: + if not all([len(p) == 38 for p in pairs]): + raise ValueError( + 'All pairs should have ground truth info for evaluation.' + 'File \"{}\" needs 38 valid entries per row'.format(opt.input_pairs)) + + # Load the SuperPoint and SuperGlue models. 
+ device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu' + print('Running inference on device \"{}\"'.format(device)) + config = { + 'superpoint': { + 'nms_radius': opt.nms_radius, + 'keypoint_threshold': opt.keypoint_threshold, + 'max_keypoints': opt.max_keypoints + }, + 'superglue': { + 'weights': opt.superglue, + 'sinkhorn_iterations': opt.sinkhorn_iterations, + 'match_threshold': opt.match_threshold, + } + } + matching = Matching(config).eval().to(device) + + # Create the output directories if they do not exist already. + input_dir = Path(opt.input_dir) + print('Looking for data in directory \"{}\"'.format(input_dir)) + output_dir = Path(opt.output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + print('Will write matches to directory \"{}\"'.format(output_dir)) + if opt.eval: + print('Will write evaluation results', + 'to directory \"{}\"'.format(output_dir)) + if opt.viz: + print('Will write visualization images to', + 'directory \"{}\"'.format(output_dir)) + + timer = AverageTimer(newline=True) + for i, pair in enumerate(pairs): + name0, name1 = pair[:2] + stem0, stem1 = Path(name0).stem, Path(name1).stem + matches_path = output_dir / '{}_{}_matches.npz'.format(stem0, stem1) + eval_path = output_dir / '{}_{}_evaluation.npz'.format(stem0, stem1) + viz_path = output_dir / '{}_{}_matches.{}'.format(stem0, stem1, opt.viz_extension) + viz_eval_path = output_dir / \ + '{}_{}_evaluation.{}'.format(stem0, stem1, opt.viz_extension) + + # Handle --cache logic. + do_match = True + do_eval = opt.eval + do_viz = opt.viz + do_viz_eval = opt.eval and opt.viz + if opt.cache: + if matches_path.exists(): + try: + results = np.load(matches_path) + except: + raise IOError('Cannot load matches .npz file: %s' % + matches_path) + + kpts0, kpts1 = results['keypoints0'], results['keypoints1'] + matches, conf = results['matches'], results['match_confidence'] + do_match = False + if opt.eval and eval_path.exists(): + try: + results = np.load(eval_path) + except: + raise IOError('Cannot load eval .npz file: %s' % eval_path) + err_R, err_t = results['error_R'], results['error_t'] + precision = results['precision'] + matching_score = results['matching_score'] + num_correct = results['num_correct'] + epi_errs = results['epipolar_errors'] + do_eval = False + if opt.viz and viz_path.exists(): + do_viz = False + if opt.viz and opt.eval and viz_eval_path.exists(): + do_viz_eval = False + timer.update('load_cache') + + if not (do_match or do_eval or do_viz or do_viz_eval): + timer.print('Finished pair {:5} of {:5}'.format(i, len(pairs))) + continue + + # If a rotation integer is provided (e.g. from EXIF data), use it: + if len(pair) >= 5: + rot0, rot1 = int(pair[2]), int(pair[3]) + else: + rot0, rot1 = 0, 0 + + # Load the image pair. + image0, inp0, scales0 = read_image( + input_dir / name0, device, opt.resize, rot0, opt.resize_float) + image1, inp1, scales1 = read_image( + input_dir / name1, device, opt.resize, rot1, opt.resize_float) + if image0 is None or image1 is None: + print('Problem reading image pair: {} {}'.format( + input_dir/name0, input_dir/name1)) + exit(1) + timer.update('load_image') + + if do_match: + # Perform the matching. + pred = matching({'image0': inp0, 'image1': inp1}) + pred = {k: v[0].cpu().numpy() for k, v in pred.items()} + kpts0, kpts1 = pred['keypoints0'], pred['keypoints1'] + matches, conf = pred['matches0'], pred['matching_scores0'] + timer.update('matcher') + + # Write the matches to disk. 
+ out_matches = {'keypoints0': kpts0, 'keypoints1': kpts1, + 'matches': matches, 'match_confidence': conf} + np.savez(str(matches_path), **out_matches) + + # Keep the matching keypoints. + valid = matches > -1 + mkpts0 = kpts0[valid] + mkpts1 = kpts1[matches[valid]] + mconf = conf[valid] + + if do_eval: + # Estimate the pose and compute the pose error. + assert len(pair) == 38, 'Pair does not have ground truth info' + K0 = np.array(pair[4:13]).astype(float).reshape(3, 3) + K1 = np.array(pair[13:22]).astype(float).reshape(3, 3) + T_0to1 = np.array(pair[22:]).astype(float).reshape(4, 4) + + # Scale the intrinsics to resized image. + K0 = scale_intrinsics(K0, scales0) + K1 = scale_intrinsics(K1, scales1) + + # Update the intrinsics + extrinsics if EXIF rotation was found. + if rot0 != 0 or rot1 != 0: + cam0_T_w = np.eye(4) + cam1_T_w = T_0to1 + if rot0 != 0: + K0 = rotate_intrinsics(K0, image0.shape, rot0) + cam0_T_w = rotate_pose_inplane(cam0_T_w, rot0) + if rot1 != 0: + K1 = rotate_intrinsics(K1, image1.shape, rot1) + cam1_T_w = rotate_pose_inplane(cam1_T_w, rot1) + cam1_T_cam0 = cam1_T_w @ np.linalg.inv(cam0_T_w) + T_0to1 = cam1_T_cam0 + + epi_errs = compute_epipolar_error(mkpts0, mkpts1, T_0to1, K0, K1) + correct = epi_errs < 5e-4 + num_correct = np.sum(correct) + precision = np.mean(correct) if len(correct) > 0 else 0 + matching_score = num_correct / len(kpts0) if len(kpts0) > 0 else 0 + + thresh = 1. # In pixels relative to resized image size. + ret = estimate_pose(mkpts0, mkpts1, K0, K1, thresh) + if ret is None: + err_t, err_R = np.inf, np.inf + else: + R, t, inliers = ret + err_t, err_R = compute_pose_error(T_0to1, R, t) + + # Write the evaluation results to disk. + out_eval = {'error_t': err_t, + 'error_R': err_R, + 'precision': precision, + 'matching_score': matching_score, + 'num_correct': num_correct, + 'epipolar_errors': epi_errs} + np.savez(str(eval_path), **out_eval) + timer.update('eval') + + if do_viz: + # Visualize the matches. + color = cm.jet(mconf) + text = [ + 'SuperGlue', + 'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)), + 'Matches: {}'.format(len(mkpts0)), + ] + if rot0 != 0 or rot1 != 0: + text.append('Rotation: {}:{}'.format(rot0, rot1)) + + # Display extra parameter info. + k_thresh = matching.superpoint.config['keypoint_threshold'] + m_thresh = matching.superglue.config['match_threshold'] + small_text = [ + 'Keypoint Threshold: {:.4f}'.format(k_thresh), + 'Match Threshold: {:.2f}'.format(m_thresh), + 'Image Pair: {}:{}'.format(stem0, stem1), + ] + + make_matching_plot( + image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, + text, viz_path, opt.show_keypoints, + opt.fast_viz, opt.opencv_display, 'Matches', small_text) + + timer.update('viz_match') + + if do_viz_eval: + # Visualize the evaluation results for the image pair. + color = np.clip((epi_errs - 0) / (1e-3 - 0), 0, 1) + color = error_colormap(1 - color) + deg, delta = ' deg', 'Delta ' + if not opt.fast_viz: + deg, delta = '°', '$\\Delta$' + e_t = 'FAIL' if np.isinf(err_t) else '{:.1f}{}'.format(err_t, deg) + e_R = 'FAIL' if np.isinf(err_R) else '{:.1f}{}'.format(err_R, deg) + text = [ + 'SuperGlue', + '{}R: {}'.format(delta, e_R), '{}t: {}'.format(delta, e_t), + 'inliers: {}/{}'.format(num_correct, (matches > -1).sum()), + ] + if rot0 != 0 or rot1 != 0: + text.append('Rotation: {}:{}'.format(rot0, rot1)) + + # Display extra parameter info (only works with --fast_viz). 
+ k_thresh = matching.superpoint.config['keypoint_threshold'] + m_thresh = matching.superglue.config['match_threshold'] + small_text = [ + 'Keypoint Threshold: {:.4f}'.format(k_thresh), + 'Match Threshold: {:.2f}'.format(m_thresh), + 'Image Pair: {}:{}'.format(stem0, stem1), + ] + + make_matching_plot( + image0, image1, kpts0, kpts1, mkpts0, + mkpts1, color, text, viz_eval_path, + opt.show_keypoints, opt.fast_viz, + opt.opencv_display, 'Relative Pose', small_text) + + timer.update('viz_eval') + + timer.print('Finished pair {:5} of {:5}'.format(i, len(pairs))) + + if opt.eval: + # Collate the results into a final table and print to terminal. + pose_errors = [] + precisions = [] + matching_scores = [] + for pair in pairs: + name0, name1 = pair[:2] + stem0, stem1 = Path(name0).stem, Path(name1).stem + eval_path = output_dir / \ + '{}_{}_evaluation.npz'.format(stem0, stem1) + results = np.load(eval_path) + pose_error = np.maximum(results['error_t'], results['error_R']) + pose_errors.append(pose_error) + precisions.append(results['precision']) + matching_scores.append(results['matching_score']) + thresholds = [5, 10, 20] + aucs = pose_auc(pose_errors, thresholds) + aucs = [100.*yy for yy in aucs] + prec = 100.*np.mean(precisions) + ms = 100.*np.mean(matching_scores) + print('Evaluation Results (mean over {} pairs):'.format(len(pairs))) + print('AUC@5\t AUC@10\t AUC@20\t Prec\t MScore\t') + print('{:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t'.format( + aucs[0], aucs[1], aucs[2], prec, ms)) diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/__init__.py b/imcui/third_party/SuperGluePretrainedNetwork/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/matching.py b/imcui/third_party/SuperGluePretrainedNetwork/models/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..5d174208d146373230a8a68dd1420fc59c180633 --- /dev/null +++ b/imcui/third_party/SuperGluePretrainedNetwork/models/matching.py @@ -0,0 +1,84 @@ +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. 
THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. +# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +import torch + +from .superpoint import SuperPoint +from .superglue import SuperGlue + + +class Matching(torch.nn.Module): + """ Image Matching Frontend (SuperPoint + SuperGlue) """ + def __init__(self, config={}): + super().__init__() + self.superpoint = SuperPoint(config.get('superpoint', {})) + self.superglue = SuperGlue(config.get('superglue', {})) + + def forward(self, data): + """ Run SuperPoint (optionally) and SuperGlue + SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input + Args: + data: dictionary with minimal keys: ['image0', 'image1'] + """ + pred = {} + + # Extract SuperPoint (keypoints, scores, descriptors) if not provided + if 'keypoints0' not in data: + pred0 = self.superpoint({'image': data['image0']}) + pred = {**pred, **{k+'0': v for k, v in pred0.items()}} + if 'keypoints1' not in data: + pred1 = self.superpoint({'image': data['image1']}) + pred = {**pred, **{k+'1': v for k, v in pred1.items()}} + + # Batch all features + # We should either have i) one image per batch, or + # ii) the same number of local features for all images in the batch. + data = {**data, **pred} + + for k in data: + if isinstance(data[k], (list, tuple)): + data[k] = torch.stack(data[k]) + + # Perform the matching + pred = {**pred, **self.superglue(data)} + + return pred diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/superglue.py b/imcui/third_party/SuperGluePretrainedNetwork/models/superglue.py new file mode 100644 index 0000000000000000000000000000000000000000..5a89b0348075bcb918eab123bc988c7102137a3d --- /dev/null +++ b/imcui/third_party/SuperGluePretrainedNetwork/models/superglue.py @@ -0,0 +1,285 @@ +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. 
ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. +# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +from copy import deepcopy +from pathlib import Path +from typing import List, Tuple + +import torch +from torch import nn + + +def MLP(channels: List[int], do_bn: bool = True) -> nn.Module: + """ Multi-layer perceptron """ + n = len(channels) + layers = [] + for i in range(1, n): + layers.append( + nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) + if i < (n-1): + if do_bn: + layers.append(nn.BatchNorm1d(channels[i])) + layers.append(nn.ReLU()) + return nn.Sequential(*layers) + + +def normalize_keypoints(kpts, image_shape): + """ Normalize keypoints locations based on image image_shape""" + _, _, height, width = image_shape + one = kpts.new_tensor(1) + size = torch.stack([one*width, one*height])[None] + center = size / 2 + scaling = size.max(1, keepdim=True).values * 0.7 + return (kpts - center[:, None, :]) / scaling[:, None, :] + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + def __init__(self, feature_dim: int, layers: List[int]) -> None: + super().__init__() + self.encoder = MLP([3] + layers + [feature_dim]) + nn.init.constant_(self.encoder[-1].bias, 0.0) + + def forward(self, kpts, scores): + inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] + return self.encoder(torch.cat(inputs, dim=1)) + + +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]: + dim = query.shape[1] + scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5 + prob = torch.nn.functional.softmax(scores, dim=-1) + return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob + + +class MultiHeadedAttention(nn.Module): + """ Multi-head attention to increase model expressivitiy """ + def __init__(self, num_heads: int, d_model: int): + super().__init__() + assert d_model % num_heads == 0 + self.dim = d_model // num_heads + self.num_heads = num_heads + self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) + self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + batch_dim = query.size(0) + query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) + for l, x in zip(self.proj, (query, key, value))] + x, _ = attention(query, key, value) + return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1)) + + +class AttentionalPropagation(nn.Module): + def __init__(self, feature_dim: int, num_heads: int): + super().__init__() + self.attn = MultiHeadedAttention(num_heads, feature_dim) + self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim]) + nn.init.constant_(self.mlp[-1].bias, 0.0) + + def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor: + 
message = self.attn(x, source, source) + return self.mlp(torch.cat([x, message], dim=1)) + + +class AttentionalGNN(nn.Module): + def __init__(self, feature_dim: int, layer_names: List[str]) -> None: + super().__init__() + self.layers = nn.ModuleList([ + AttentionalPropagation(feature_dim, 4) + for _ in range(len(layer_names))]) + self.names = layer_names + + def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]: + for layer, name in zip(self.layers, self.names): + if name == 'cross': + src0, src1 = desc1, desc0 + else: # if name == 'self': + src0, src1 = desc0, desc1 + delta0, delta1 = layer(desc0, src0), layer(desc1, src1) + desc0, desc1 = (desc0 + delta0), (desc1 + delta1) + return desc0, desc1 + + +def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor: + """ Perform Sinkhorn Normalization in Log-space for stability""" + u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) + for _ in range(iters): + u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2) + v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1) + return Z + u.unsqueeze(2) + v.unsqueeze(1) + + +def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor: + """ Perform Differentiable Optimal Transport in Log-space for stability""" + b, m, n = scores.shape + one = scores.new_tensor(1) + ms, ns = (m*one).to(scores), (n*one).to(scores) + + bins0 = alpha.expand(b, m, 1) + bins1 = alpha.expand(b, 1, n) + alpha = alpha.expand(b, 1, 1) + + couplings = torch.cat([torch.cat([scores, bins0], -1), + torch.cat([bins1, alpha], -1)], 1) + + norm = - (ms + ns).log() + log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) + log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm]) + log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1) + + Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) + Z = Z - norm # multiply probabilities by M+N + return Z + + +def arange_like(x, dim: int): + return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 + + +class SuperGlue(nn.Module): + """SuperGlue feature matching middle-end + + Given two sets of keypoints and locations, we determine the + correspondences by: + 1. Keypoint Encoding (normalization + visual feature and location fusion) + 2. Graph Neural Network with multiple self and cross-attention layers + 3. Final projection layer + 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm) + 5. Thresholding matrix based on mutual exclusivity and a match_threshold + + The correspondence ids use -1 to indicate non-matching points. + + Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew + Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural + Networks. In CVPR, 2020. 
https://arxiv.org/abs/1911.11763 + + """ + default_config = { + 'descriptor_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, + 'sinkhorn_iterations': 100, + 'match_threshold': 0.2, + } + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + + self.kenc = KeypointEncoder( + self.config['descriptor_dim'], self.config['keypoint_encoder']) + + self.gnn = AttentionalGNN( + feature_dim=self.config['descriptor_dim'], layer_names=self.config['GNN_layers']) + + self.final_proj = nn.Conv1d( + self.config['descriptor_dim'], self.config['descriptor_dim'], + kernel_size=1, bias=True) + + bin_score = torch.nn.Parameter(torch.tensor(1.)) + self.register_parameter('bin_score', bin_score) + + assert self.config['weights'] in ['indoor', 'outdoor'] + path = Path(__file__).parent + path = path / 'weights/superglue_{}.pth'.format(self.config['weights']) + self.load_state_dict(torch.load(str(path))) + print('Loaded SuperGlue model (\"{}\" weights)'.format( + self.config['weights'])) + + def forward(self, data): + """Run SuperGlue on a pair of keypoints and descriptors""" + desc0, desc1 = data['descriptors0'], data['descriptors1'] + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + + if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints + shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] + return { + 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int), + 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int), + 'matching_scores0': kpts0.new_zeros(shape0), + 'matching_scores1': kpts1.new_zeros(shape1), + } + + # Keypoint normalization. + kpts0 = normalize_keypoints(kpts0, data['image0'].shape) + kpts1 = normalize_keypoints(kpts1, data['image1'].shape) + + # Keypoint MLP encoder. + desc0 = desc0 + self.kenc(kpts0, data['scores0']) + desc1 = desc1 + self.kenc(kpts1, data['scores1']) + + # Multi-layer Transformer network. + desc0, desc1 = self.gnn(desc0, desc1) + + # Final MLP projection. + mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) + + # Compute matching descriptor distance. + scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) + scores = scores / self.config['descriptor_dim']**.5 + + # Run the optimal transport. + scores = log_optimal_transport( + scores, self.bin_score, + iters=self.config['sinkhorn_iterations']) + + # Get the matches with score above "match_threshold". 
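+ # Mutual-nearest-neighbor assignment on the score matrix (dustbin row/column excluded): + # keypoint i in image 0 is matched to j in image 1 only if each is the other's argmax. + # Confidences are the exp() of the log-domain transport scores; pairs whose confidence + # is not above "match_threshold" are discarded, and unmatched keypoints get index -1.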
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + indices0, indices1 = max0.indices, max1.indices + mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) + mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) + zero = scores.new_tensor(0) + mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) + valid0 = mutual0 & (mscores0 > self.config['match_threshold']) + valid1 = mutual1 & valid0.gather(1, indices1) + indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) + indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + + return { + 'matches0': indices0, # use -1 for invalid match + 'matches1': indices1, # use -1 for invalid match + 'matching_scores0': mscores0, + 'matching_scores1': mscores1, + } diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/superpoint.py b/imcui/third_party/SuperGluePretrainedNetwork/models/superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..0577e1ec47c3397e45bc9a3cf2e47f211c32877e --- /dev/null +++ b/imcui/third_party/SuperGluePretrainedNetwork/models/superpoint.py @@ -0,0 +1,206 @@ +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
+# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +from pathlib import Path +import torch +from torch import nn + +def simple_nms(scores, nms_radius: int): + """ Fast Non-maximum suppression to remove nearby points """ + assert(nms_radius >= 0) + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def remove_borders(keypoints, scores, border: int, height: int, width: int): + """ Removes keypoints too close to the border """ + mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) + mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) + mask = mask_h & mask_w + return keypoints[mask], scores[mask] + + +def top_k_keypoints(keypoints, scores, k: int): + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0) + return keypoints[indices], scores + + +def sample_descriptors(keypoints, descriptors, s: int = 8): + """ Interpolate descriptors at keypoint locations """ + b, c, h, w = descriptors.shape + keypoints = keypoints - s / 2 + 0.5 + keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], + ).to(keypoints)[None] + keypoints = keypoints*2 - 1 # normalize to (-1, 1) + args = {'align_corners': True} if torch.__version__ >= '1.3' else {} + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + return descriptors + + +class SuperPoint(nn.Module): + """SuperPoint Convolutional Detector and Descriptor + + SuperPoint: Self-Supervised Interest Point Detection and + Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew + Rabinovich. In CVPRW, 2019. 
https://arxiv.org/abs/1712.07629 + + """ + default_config = { + 'descriptor_dim': 256, + 'nms_radius': 4, + 'keypoint_threshold': 0.005, + 'max_keypoints': -1, + 'remove_borders': 4, + } + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 + + self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) + self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) + + self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) + + self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convDb = nn.Conv2d( + c5, self.config['descriptor_dim'], + kernel_size=1, stride=1, padding=0) + + path = Path(__file__).parent / 'weights/superpoint_v1.pth' + self.load_state_dict(torch.load(str(path))) + + mk = self.config['max_keypoints'] + if mk == 0 or mk < -1: + raise ValueError('\"max_keypoints\" must be positive or \"-1\"') + + print('Loaded SuperPoint model') + + def forward(self, data, cfg={}): + """Compute keypoints, scores, descriptors for image""" + self.config = { + **self.config, + **cfg, + } + # Shared Encoder + x = self.relu(self.conv1a(data['image'])) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + b, _, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) + scores = simple_nms(scores, self.config['nms_radius']) + + # Extract keypoints + keypoints = [ + torch.nonzero(s > self.config['keypoint_threshold']) + for s in scores] + scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] + + # Discard keypoints near the image borders + keypoints, scores = list(zip(*[ + remove_borders(k, s, self.config['remove_borders'], h*8, w*8) + for k, s in zip(keypoints, scores)])) + + # Keep the k keypoints with highest score + if self.config['max_keypoints'] >= 0: + keypoints, scores = list(zip(*[ + top_k_keypoints(k, s, self.config['max_keypoints']) + for k, s in zip(keypoints, scores)])) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) + + # Extract descriptors + descriptors = [sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, descriptors)] + + return { + 'keypoints': keypoints, + 'scores': scores, + 'descriptors': descriptors, + } diff --git 
a/imcui/third_party/SuperGluePretrainedNetwork/models/utils.py b/imcui/third_party/SuperGluePretrainedNetwork/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1206244aa2a004d9f653782de798bfef9e5e726b --- /dev/null +++ b/imcui/third_party/SuperGluePretrainedNetwork/models/utils.py @@ -0,0 +1,555 @@ +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. +# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# Daniel DeTone +# Tomasz Malisiewicz +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +from pathlib import Path +import time +from collections import OrderedDict +from threading import Thread +import numpy as np +import cv2 +import torch +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use('Agg') + + +class AverageTimer: + """ Class to help manage printing simple timing of code execution. """ + + def __init__(self, smoothing=0.3, newline=False): + self.smoothing = smoothing + self.newline = newline + self.times = OrderedDict() + self.will_print = OrderedDict() + self.reset() + + def reset(self): + now = time.time() + self.start = now + self.last_time = now + for name in self.will_print: + self.will_print[name] = False + + def update(self, name='default'): + now = time.time() + dt = now - self.last_time + if name in self.times: + dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name] + self.times[name] = dt + self.will_print[name] = True + self.last_time = now + + def print(self, text='Timer'): + total = 0. 
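+ # Print each stage's smoothed time, accumulate the running total, and report the implied FPS.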
+ print('[{}]'.format(text), end=' ') + for key in self.times: + val = self.times[key] + if self.will_print[key]: + print('%s=%.3f' % (key, val), end=' ') + total += val + print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ') + if self.newline: + print(flush=True) + else: + print(end='\r', flush=True) + self.reset() + + +class VideoStreamer: + """ Class to help process image streams. Four types of possible inputs:" + 1.) USB Webcam. + 2.) An IP camera + 3.) A directory of images (files in directory matching 'image_glob'). + 4.) A video file, such as an .mp4 or .avi file. + """ + def __init__(self, basedir, resize, skip, image_glob, max_length=1000000): + self._ip_grabbed = False + self._ip_running = False + self._ip_camera = False + self._ip_image = None + self._ip_index = 0 + self.cap = [] + self.camera = True + self.video_file = False + self.listing = [] + self.resize = resize + self.interp = cv2.INTER_AREA + self.i = 0 + self.skip = skip + self.max_length = max_length + if isinstance(basedir, int) or basedir.isdigit(): + print('==> Processing USB webcam input: {}'.format(basedir)) + self.cap = cv2.VideoCapture(int(basedir)) + self.listing = range(0, self.max_length) + elif basedir.startswith(('http', 'rtsp')): + print('==> Processing IP camera input: {}'.format(basedir)) + self.cap = cv2.VideoCapture(basedir) + self.start_ip_camera_thread() + self._ip_camera = True + self.listing = range(0, self.max_length) + elif Path(basedir).is_dir(): + print('==> Processing image directory input: {}'.format(basedir)) + self.listing = list(Path(basedir).glob(image_glob[0])) + for j in range(1, len(image_glob)): + image_path = list(Path(basedir).glob(image_glob[j])) + self.listing = self.listing + image_path + self.listing.sort() + self.listing = self.listing[::self.skip] + self.max_length = np.min([self.max_length, len(self.listing)]) + if self.max_length == 0: + raise IOError('No images found (maybe bad \'image_glob\' ?)') + self.listing = self.listing[:self.max_length] + self.camera = False + elif Path(basedir).exists(): + print('==> Processing video input: {}'.format(basedir)) + self.cap = cv2.VideoCapture(basedir) + self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) + num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) + self.listing = range(0, num_frames) + self.listing = self.listing[::self.skip] + self.video_file = True + self.max_length = np.min([self.max_length, len(self.listing)]) + self.listing = self.listing[:self.max_length] + else: + raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir)) + if self.camera and not self.cap.isOpened(): + raise IOError('Could not read camera') + + def load_image(self, impath): + """ Read image as grayscale and resize to img_size. + Inputs + impath: Path to input image. + Returns + grayim: uint8 numpy array sized H x W. + """ + grayim = cv2.imread(impath, 0) + if grayim is None: + raise Exception('Error reading image %s' % impath) + w, h = grayim.shape[1], grayim.shape[0] + w_new, h_new = process_resize(w, h, self.resize) + grayim = cv2.resize( + grayim, (w_new, h_new), interpolation=self.interp) + return grayim + + def next_frame(self): + """ Return the next frame, and increment internal counter. + Returns + image: Next H x W image. + status: True or False depending whether image was loaded. 
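+ Note: frames from camera or video inputs are converted to grayscale and resized before being returned.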
+ """ + + if self.i == self.max_length: + return (None, False) + if self.camera: + + if self._ip_camera: + #Wait for first image, making sure we haven't exited + while self._ip_grabbed is False and self._ip_exited is False: + time.sleep(.001) + + ret, image = self._ip_grabbed, self._ip_image.copy() + if ret is False: + self._ip_running = False + else: + ret, image = self.cap.read() + if ret is False: + print('VideoStreamer: Cannot get image from camera') + return (None, False) + w, h = image.shape[1], image.shape[0] + if self.video_file: + self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i]) + + w_new, h_new = process_resize(w, h, self.resize) + image = cv2.resize(image, (w_new, h_new), + interpolation=self.interp) + image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + else: + image_file = str(self.listing[self.i]) + image = self.load_image(image_file) + self.i = self.i + 1 + return (image, True) + + def start_ip_camera_thread(self): + self._ip_thread = Thread(target=self.update_ip_camera, args=()) + self._ip_running = True + self._ip_thread.start() + self._ip_exited = False + return self + + def update_ip_camera(self): + while self._ip_running: + ret, img = self.cap.read() + if ret is False: + self._ip_running = False + self._ip_exited = True + self._ip_grabbed = False + return + + self._ip_image = img + self._ip_grabbed = ret + self._ip_index += 1 + #print('IPCAMERA THREAD got frame {}'.format(self._ip_index)) + + + def cleanup(self): + self._ip_running = False + +# --- PREPROCESSING --- + +def process_resize(w, h, resize): + assert(len(resize) > 0 and len(resize) <= 2) + if len(resize) == 1 and resize[0] > -1: + scale = resize[0] / max(h, w) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + elif len(resize) == 1 and resize[0] == -1: + w_new, h_new = w, h + else: # len(resize) == 2: + w_new, h_new = resize[0], resize[1] + + # Issue warning if resolution is too small or too large. 
+ if max(w_new, h_new) < 160: + print('Warning: input resolution is very small, results may vary') + elif max(w_new, h_new) > 2000: + print('Warning: input resolution is very large, results may vary') + + return w_new, h_new + + +def frame2tensor(frame, device): + return torch.from_numpy(frame/255.).float()[None, None].to(device) + + +def read_image(path, device, resize, rotation, resize_float): + image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE) + if image is None: + return None, None, None + w, h = image.shape[1], image.shape[0] + w_new, h_new = process_resize(w, h, resize) + scales = (float(w) / float(w_new), float(h) / float(h_new)) + + if resize_float: + image = cv2.resize(image.astype('float32'), (w_new, h_new)) + else: + image = cv2.resize(image, (w_new, h_new)).astype('float32') + + if rotation != 0: + image = np.rot90(image, k=rotation) + if rotation % 2: + scales = scales[::-1] + + inp = frame2tensor(image, device) + return image, inp, scales + + +# --- GEOMETRY --- + + +def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): + if len(kpts0) < 5: + return None + + f_mean = np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]]) + norm_thresh = thresh / f_mean + + kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, + method=cv2.RANSAC) + + assert E is not None + + best_num_inliers = 0 + ret = None + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose( + _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t[:, 0], mask.ravel() > 0) + return ret + + +def rotate_intrinsics(K, image_shape, rot): + """image_shape is the shape of the image after rotation""" + assert rot <= 3 + h, w = image_shape[:2][::-1 if (rot % 2) else 1] + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + rot = rot % 4 + if rot == 1: + return np.array([[fy, 0., cy], + [0., fx, w-1-cx], + [0., 0., 1.]], dtype=K.dtype) + elif rot == 2: + return np.array([[fx, 0., w-1-cx], + [0., fy, h-1-cy], + [0., 0., 1.]], dtype=K.dtype) + else: # if rot == 3: + return np.array([[fy, 0., h-1-cy], + [0., fx, cx], + [0., 0., 1.]], dtype=K.dtype) + + +def rotate_pose_inplane(i_T_w, rot): + rotation_matrices = [ + np.array([[np.cos(r), -np.sin(r), 0., 0.], + [np.sin(r), np.cos(r), 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.]], dtype=np.float32) + for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] + ] + return np.dot(rotation_matrices[rot], i_T_w) + + +def scale_intrinsics(K, scales): + scales = np.diag([1./scales[0], 1./scales[1], 1.]) + return np.dot(scales, K) + + +def to_homogeneous(points): + return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) + + +def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1): + kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + kpts0 = to_homogeneous(kpts0) + kpts1 = to_homogeneous(kpts1) + + t0, t1, t2 = T_0to1[:3, 3] + t_skew = np.array([ + [0, -t2, t1], + [t2, 0, -t0], + [-t1, t0, 0] + ]) + E = t_skew @ T_0to1[:3, :3] + + Ep0 = kpts0 @ E.T # N x 3 + p1Ep0 = np.sum(kpts1 * Ep0, -1) # N + Etp1 = kpts1 @ E # N x 3 + d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) + return d + + +def angle_error_mat(R1, R2): + cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 + cos = np.clip(cos, -1., 1.) 
# numercial errors can make it out of bounds + return np.rad2deg(np.abs(np.arccos(cos))) + + +def angle_error_vec(v1, v2): + n = np.linalg.norm(v1) * np.linalg.norm(v2) + return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) + + +def compute_pose_error(T_0to1, R, t): + R_gt = T_0to1[:3, :3] + t_gt = T_0to1[:3, 3] + error_t = angle_error_vec(t, t_gt) + error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation + error_R = angle_error_mat(R, R_gt) + return error_t, error_R + + +def pose_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 1) / len(errors) + errors = np.r_[0., errors] + recall = np.r_[0., recall] + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index-1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.trapz(r, x=e)/t) + return aucs + + +# --- VISUALIZATION --- + + +def plot_image_pair(imgs, dpi=100, size=6, pad=.5): + n = len(imgs) + assert n == 2, 'number of images must be two' + figsize = (size*n, size*3/4) if size is not None else None + _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi) + for i in range(n): + ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255) + ax[i].get_yaxis().set_ticks([]) + ax[i].get_xaxis().set_ticks([]) + for spine in ax[i].spines.values(): # remove frame + spine.set_visible(False) + plt.tight_layout(pad=pad) + + +def plot_keypoints(kpts0, kpts1, color='w', ps=2): + ax = plt.gcf().axes + ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) + ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) + + +def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4): + fig = plt.gcf() + ax = fig.axes + fig.canvas.draw() + + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0)) + fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1)) + + fig.lines = [matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1, + transform=fig.transFigure, c=color[i], linewidth=lw) + for i in range(len(kpts0))] + ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) + ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) + + +def make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1, + color, text, path, show_keypoints=False, + fast_viz=False, opencv_display=False, + opencv_title='matches', small_text=[]): + + if fast_viz: + make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, mkpts1, + color, text, path, show_keypoints, 10, + opencv_display, opencv_title, small_text) + return + + plot_image_pair([image0, image1]) + if show_keypoints: + plot_keypoints(kpts0, kpts1, color='k', ps=4) + plot_keypoints(kpts0, kpts1, color='w', ps=2) + plot_matches(mkpts0, mkpts1, color) + + fig = plt.gcf() + txt_color = 'k' if image0[:100, :150].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + txt_color = 'k' if image0[-100:, :150].mean() > 200 else 'w' + fig.text( + 0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes, + fontsize=5, va='bottom', ha='left', color=txt_color) + + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + + +def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, + mkpts1, color, text, path=None, + show_keypoints=False, margin=10, + opencv_display=False, opencv_title='', + small_text=[]): + H0, W0 = 
image0.shape + H1, W1 = image1.shape + H, W = max(H0, H1), W0 + W1 + margin + + out = 255*np.ones((H, W), np.uint8) + out[:H0, :W0] = image0 + out[:H1, W0+margin:] = image1 + out = np.stack([out]*3, -1) + + if show_keypoints: + kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int) + white = (255, 255, 255) + black = (0, 0, 0) + for x, y in kpts0: + cv2.circle(out, (x, y), 2, black, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x, y), 1, white, -1, lineType=cv2.LINE_AA) + for x, y in kpts1: + cv2.circle(out, (x + margin + W0, y), 2, black, -1, + lineType=cv2.LINE_AA) + cv2.circle(out, (x + margin + W0, y), 1, white, -1, + lineType=cv2.LINE_AA) + + mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int) + color = (np.array(color[:, :3])*255).astype(int)[:, ::-1] + for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color): + c = c.tolist() + cv2.line(out, (x0, y0), (x1 + margin + W0, y1), + color=c, thickness=1, lineType=cv2.LINE_AA) + # display line end-points as circles + cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1, + lineType=cv2.LINE_AA) + + # Scale factor for consistent visualization across scales. + sc = min(H / 640., 2.0) + + # Big text. + Ht = int(30 * sc) # text height + txt_color_fg = (255, 255, 255) + txt_color_bg = (0, 0, 0) + for i, t in enumerate(text): + cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, + 1.0*sc, txt_color_bg, 2, cv2.LINE_AA) + cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX, + 1.0*sc, txt_color_fg, 1, cv2.LINE_AA) + + # Small text. + Ht = int(18 * sc) # text height + for i, t in enumerate(reversed(small_text)): + cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, + 0.5*sc, txt_color_bg, 2, cv2.LINE_AA) + cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX, + 0.5*sc, txt_color_fg, 1, cv2.LINE_AA) + + if path is not None: + cv2.imwrite(str(path), out) + + if opencv_display: + cv2.imshow(opencv_title, out) + cv2.waitKey(1) + + return out + + +def error_colormap(x): + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)], -1), 0, 1) diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_indoor.pth b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_indoor.pth new file mode 100644 index 0000000000000000000000000000000000000000..969252133f802cb03256c15a3881b8b39c1867d4 --- /dev/null +++ b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_indoor.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e710469be25ebe1e2ccf68edcae8b2945b0617c8e7e68412251d9d47f5052b1 +size 48233807 diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_outdoor.pth b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_outdoor.pth new file mode 100644 index 0000000000000000000000000000000000000000..79db4b5340b02afca3cdd419672300bb009975af --- /dev/null +++ b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_outdoor.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f5f5e9bb3febf07b69df633c4c3ff7a17f8af26a023aae2b9303d22339195bd +size 48233807 diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superpoint_v1.pth b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superpoint_v1.pth new file mode 100644 index 
0000000000000000000000000000000000000000..7648726e3a3dfa2581e86bfa9c5a2a05cfb9bf74 --- /dev/null +++ b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superpoint_v1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52b6708629640ca883673b5d5c097c4ddad37d8048b33f09c8ca0d69db12c40e +size 5206086 diff --git a/imcui/third_party/TopicFM/.github/workflows/sync.yml b/imcui/third_party/TopicFM/.github/workflows/sync.yml new file mode 100644 index 0000000000000000000000000000000000000000..efbf881c64bdeac6916473e4391e23e87af5b69d --- /dev/null +++ b/imcui/third_party/TopicFM/.github/workflows/sync.yml @@ -0,0 +1,39 @@ +name: Upstream Sync + +permissions: + contents: write + +on: + schedule: + - cron: "0 0 * * *" # every day + workflow_dispatch: + +jobs: + sync_latest_from_upstream: + name: Sync latest commits from upstream repo + runs-on: ubuntu-latest + if: ${{ github.event.repository.fork }} + + steps: + # Step 1: run a standard checkout action + - name: Checkout target repo + uses: actions/checkout@v3 + + # Step 2: run the sync action + - name: Sync upstream changes + id: sync + uses: aormsby/Fork-Sync-With-Upstream-action@v3.4 + with: + upstream_sync_repo: TruongKhang/TopicFM + upstream_sync_branch: main + target_sync_branch: main + target_repo_token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, no need to set + + # Set test_mode true to run tests instead of the true action!! + test_mode: false + + - name: Sync check + if: failure() + run: | + echo "::error::Due to insufficient permissions, synchronization failed (as expected). Please go to the repository homepage and manually perform [Sync fork]." + exit 1 diff --git a/imcui/third_party/TopicFM/configs/data/__init__.py b/imcui/third_party/TopicFM/configs/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/TopicFM/configs/data/base.py b/imcui/third_party/TopicFM/configs/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6cab7e67019a6fee2657c1a28609c8aca5b2a1d8 --- /dev/null +++ b/imcui/third_party/TopicFM/configs/data/base.py @@ -0,0 +1,37 @@ +""" +The data config will be the last one merged into the main config. +Setups in data configs will override all existed setups! 
+""" + +from yacs.config import CfgNode as CN +_CN = CN() +_CN.DATASET = CN() +_CN.TRAINER = CN() + +# training data config +_CN.DATASET.TRAIN_DATA_ROOT = None +_CN.DATASET.TRAIN_POSE_ROOT = None +_CN.DATASET.TRAIN_NPZ_ROOT = None +_CN.DATASET.TRAIN_LIST_PATH = None +_CN.DATASET.TRAIN_INTRINSIC_PATH = None +# validation set config +_CN.DATASET.VAL_DATA_ROOT = None +_CN.DATASET.VAL_POSE_ROOT = None +_CN.DATASET.VAL_NPZ_ROOT = None +_CN.DATASET.VAL_LIST_PATH = None +_CN.DATASET.VAL_INTRINSIC_PATH = None + +# testing data config +_CN.DATASET.TEST_DATA_SOURCE = None +_CN.DATASET.TEST_DATA_ROOT = None +_CN.DATASET.TEST_POSE_ROOT = None +_CN.DATASET.TEST_NPZ_ROOT = None +_CN.DATASET.TEST_LIST_PATH = None +_CN.DATASET.TEST_INTRINSIC_PATH = None +_CN.DATASET.TEST_IMGSIZE = None + +# dataset config +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 +_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val + +cfg = _CN diff --git a/imcui/third_party/TopicFM/configs/data/megadepth_test_1500.py b/imcui/third_party/TopicFM/configs/data/megadepth_test_1500.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd107fc07ecd464f793d13282939ddb26032922 --- /dev/null +++ b/imcui/third_party/TopicFM/configs/data/megadepth_test_1500.py @@ -0,0 +1,11 @@ +from configs.data.base import cfg + +TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info" + +cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" +cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" +cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" +cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt" + +cfg.DATASET.MGDPT_IMG_RESIZE = 1200 +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 diff --git a/imcui/third_party/TopicFM/configs/data/megadepth_trainval.py b/imcui/third_party/TopicFM/configs/data/megadepth_trainval.py new file mode 100644 index 0000000000000000000000000000000000000000..215b5c34cc41d36aa4444a58ca0cb69afbc11952 --- /dev/null +++ b/imcui/third_party/TopicFM/configs/data/megadepth_trainval.py @@ -0,0 +1,22 @@ +from configs.data.base import cfg + + +TRAIN_BASE_PATH = "data/megadepth/index" +cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth" +cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train" +cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" +cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0 + +TEST_BASE_PATH = "data/megadepth/index" +cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" +cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" +cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500" +cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val + +# 368 scenes in total for MegaDepth +# (with difficulty balanced (further split each scene to 3 sub-scenes)) +cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100 + +cfg.DATASET.MGDPT_IMG_RESIZE = 800 # for training on 11GB mem GPUs diff --git a/imcui/third_party/TopicFM/configs/data/scannet_test_1500.py b/imcui/third_party/TopicFM/configs/data/scannet_test_1500.py new file mode 100644 index 0000000000000000000000000000000000000000..ce3b0846b61c567b053d12fb636982ce02e21a5c --- /dev/null +++ b/imcui/third_party/TopicFM/configs/data/scannet_test_1500.py @@ -0,0 +1,12 @@ +from configs.data.base import cfg + +TEST_BASE_PATH = "assets/scannet_test_1500" + +cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" +cfg.DATASET.TEST_DATA_ROOT = 
"data/scannet/test" +cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" +cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt" +cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" +cfg.DATASET.TEST_IMGSIZE = (640, 480) + +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 diff --git a/imcui/third_party/TopicFM/configs/model/indoor/model_cfg_test.py b/imcui/third_party/TopicFM/configs/model/indoor/model_cfg_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8872d3b79de529aa375127ea5beb7e81d9d5b1 --- /dev/null +++ b/imcui/third_party/TopicFM/configs/model/indoor/model_cfg_test.py @@ -0,0 +1,4 @@ +from src.config.default import _CN as cfg + +cfg.MODEL.COARSE.N_SAMPLES = 5 +cfg.MODEL.MATCH_COARSE.THR = 0.3 diff --git a/imcui/third_party/TopicFM/configs/model/outdoor/model_cfg_test.py b/imcui/third_party/TopicFM/configs/model/outdoor/model_cfg_test.py new file mode 100644 index 0000000000000000000000000000000000000000..692497457c2a7b9ad823f94546e38f15732ca632 --- /dev/null +++ b/imcui/third_party/TopicFM/configs/model/outdoor/model_cfg_test.py @@ -0,0 +1,4 @@ +from src.config.default import _CN as cfg + +cfg.MODEL.COARSE.N_SAMPLES = 10 +cfg.MODEL.MATCH_COARSE.THR = 0.2 diff --git a/imcui/third_party/TopicFM/configs/model/outdoor/model_ds.py b/imcui/third_party/TopicFM/configs/model/outdoor/model_ds.py new file mode 100644 index 0000000000000000000000000000000000000000..2c090edbfbdcd66cea225c39af6f62da8feb50b9 --- /dev/null +++ b/imcui/third_party/TopicFM/configs/model/outdoor/model_ds.py @@ -0,0 +1,16 @@ +from src.config.default import _CN as cfg + +cfg.MODEL.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +cfg.MODEL.COARSE.N_SAMPLES = 8 + +cfg.TRAINER.CANONICAL_LR = 1e-2 +cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +cfg.TRAINER.WARMUP_RATIO = 0.1 +cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 16, 20, 24, 28] + +# pose estimation +cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 + +cfg.TRAINER.OPTIMIZER = "adamw" +cfg.TRAINER.ADAMW_DECAY = 0.1 +cfg.MODEL.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 diff --git a/imcui/third_party/TopicFM/flop_counter.py b/imcui/third_party/TopicFM/flop_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..ea87fa0139897434ca52b369450aa82203311181 --- /dev/null +++ b/imcui/third_party/TopicFM/flop_counter.py @@ -0,0 +1,55 @@ +import torch +from fvcore.nn import FlopCountAnalysis +from einops.einops import rearrange + +from src import get_model_cfg +from src.models.backbone import FPN as topicfm_featnet +from src.models.modules import TopicFormer +from src.utils.dataset import read_scannet_gray + +from third_party.loftr.src.loftr.utils.cvpr_ds_config import default_cfg +from third_party.loftr.src.loftr.backbone import ResNetFPN_8_2 as loftr_featnet +from third_party.loftr.src.loftr.loftr_module import LocalFeatureTransformer + + +def feat_net_flops(feat_net, config, input): + model = feat_net(config) + model.eval() + flops = FlopCountAnalysis(model, input) + feat_c, _ = model(input) + return feat_c, flops.total() / 1e9 + + +def coarse_model_flops(coarse_model, config, inputs): + model = coarse_model(config) + model.eval() + flops = FlopCountAnalysis(model, inputs) + return flops.total() / 1e9 + + +if __name__ == '__main__': + path_img0 = "assets/scannet_sample_images/scene0711_00_frame-001680.jpg" + path_img1 = "assets/scannet_sample_images/scene0711_00_frame-001995.jpg" + img0, img1 = read_scannet_gray(path_img0), read_scannet_gray(path_img1) + img0, img1 = img0.unsqueeze(0), img1.unsqueeze(0) + + # LoFTR + loftr_conf = 
dict(default_cfg) + feat_c0, loftr_featnet_flops0 = feat_net_flops(loftr_featnet, loftr_conf["resnetfpn"], img0) + feat_c1, loftr_featnet_flops1 = feat_net_flops(loftr_featnet, loftr_conf["resnetfpn"], img1) + print("FLOPs of feature extraction in LoFTR: {} GFLOPs".format((loftr_featnet_flops0 + loftr_featnet_flops1)/2)) + feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c') + feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c') + loftr_coarse_model_flops = coarse_model_flops(LocalFeatureTransformer, loftr_conf["coarse"], (feat_c0, feat_c1)) + print("FLOPs of coarse matching model in LoFTR: {} GFLOPs".format(loftr_coarse_model_flops)) + + # TopicFM + topicfm_conf = get_model_cfg() + feat_c0, topicfm_featnet_flops0 = feat_net_flops(topicfm_featnet, topicfm_conf["fpn"], img0) + feat_c1, topicfm_featnet_flops1 = feat_net_flops(topicfm_featnet, topicfm_conf["fpn"], img1) + print("FLOPs of feature extraction in TopicFM: {} GFLOPs".format((topicfm_featnet_flops0 + topicfm_featnet_flops1) / 2)) + feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c') + feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c') + topicfm_coarse_model_flops = coarse_model_flops(TopicFormer, topicfm_conf["coarse"], (feat_c0, feat_c1)) + print("FLOPs of coarse matching model in TopicFM: {} GFLOPs".format(topicfm_coarse_model_flops)) + diff --git a/imcui/third_party/TopicFM/src/__init__.py b/imcui/third_party/TopicFM/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..30caef94f911f99e0c12510d8181b3c1537daf1a --- /dev/null +++ b/imcui/third_party/TopicFM/src/__init__.py @@ -0,0 +1,11 @@ +from yacs.config import CfgNode +from .config.default import _CN + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CfgNode): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + +def get_model_cfg(): + cfg = lower_config(lower_config(_CN)) + return cfg["model"] \ No newline at end of file diff --git a/imcui/third_party/TopicFM/src/config/default.py b/imcui/third_party/TopicFM/src/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..591558b3f358cdce0e9e72e94acba702b2a4e896 --- /dev/null +++ b/imcui/third_party/TopicFM/src/config/default.py @@ -0,0 +1,171 @@ +from yacs.config import CfgNode as CN +_CN = CN() + +############## ↓ MODEL Pipeline ↓ ############## +_CN.MODEL = CN() +_CN.MODEL.BACKBONE_TYPE = 'FPN' +_CN.MODEL.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] +_CN.MODEL.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.MODEL.FINE_CONCAT_COARSE_FEAT = False + +# 1. MODEL-backbone (local feature CNN) config +_CN.MODEL.FPN = CN() +_CN.MODEL.FPN.INITIAL_DIM = 128 +_CN.MODEL.FPN.BLOCK_DIMS = [128, 192, 256, 384] # s1, s2, s3 + +# 2. MODEL-coarse module config +_CN.MODEL.COARSE = CN() +_CN.MODEL.COARSE.D_MODEL = 256 +_CN.MODEL.COARSE.D_FFN = 256 +_CN.MODEL.COARSE.NHEAD = 8 +_CN.MODEL.COARSE.LAYER_NAMES = ['seed', 'seed', 'seed', 'seed', 'seed'] +_CN.MODEL.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] +_CN.MODEL.COARSE.TEMP_BUG_FIX = True +_CN.MODEL.COARSE.N_TOPICS = 100 +_CN.MODEL.COARSE.N_SAMPLES = 6 +_CN.MODEL.COARSE.N_TOPIC_TRANSFORMERS = 1 + +# 3. 
Coarse-Matching config +_CN.MODEL.MATCH_COARSE = CN() +_CN.MODEL.MATCH_COARSE.THR = 0.2 +_CN.MODEL.MATCH_COARSE.BORDER_RM = 2 +_CN.MODEL.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +_CN.MODEL.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.MODEL.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory +_CN.MODEL.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock +_CN.MODEL.MATCH_COARSE.SPARSE_SPVS = True + +# 4. MODEL-fine module config +_CN.MODEL.FINE = CN() +_CN.MODEL.FINE.D_MODEL = 128 +_CN.MODEL.FINE.D_FFN = 128 +_CN.MODEL.FINE.NHEAD = 4 +_CN.MODEL.FINE.LAYER_NAMES = ['cross'] * 1 +_CN.MODEL.FINE.ATTENTION = 'linear' +_CN.MODEL.FINE.N_TOPICS = 1 + +# 5. MODEL Losses +# -- # coarse-level +_CN.MODEL.LOSS = CN() +_CN.MODEL.LOSS.COARSE_WEIGHT = 1.0 +# _CN.MODEL.LOSS.SPARSE_SPVS = False +# -- - -- # focal loss (coarse) +_CN.MODEL.LOSS.FOCAL_ALPHA = 0.25 +_CN.MODEL.LOSS.POS_WEIGHT = 1.0 +_CN.MODEL.LOSS.NEG_WEIGHT = 1.0 +# _CN.MODEL.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not. +# use `_CN.MODEL.MATCH_COARSE.MATCH_TYPE` + +# -- # fine-level +_CN.MODEL.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2'] +_CN.MODEL.LOSS.FINE_WEIGHT = 1.0 +_CN.MODEL.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window) + + +############## Dataset ############## +_CN.DATASET = CN() +# 1. data config +# training and validating +_CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth'] +_CN.DATASET.TRAIN_DATA_ROOT = None +_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TRAIN_NPZ_ROOT = None +_CN.DATASET.TRAIN_LIST_PATH = None +_CN.DATASET.TRAIN_INTRINSIC_PATH = None +_CN.DATASET.VAL_DATA_ROOT = None +_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.VAL_NPZ_ROOT = None +_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file +_CN.DATASET.VAL_INTRINSIC_PATH = None +# testing +_CN.DATASET.TEST_DATA_SOURCE = None +_CN.DATASET.TEST_DATA_ROOT = None +_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TEST_NPZ_ROOT = None +_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file +_CN.DATASET.TEST_INTRINSIC_PATH = None +_CN.DATASET.TEST_IMGSIZE = None + +# 2. dataset config +# general options +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score +_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 +_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile'] + +# MegaDepth options +_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. +_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE +_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 +_CN.DATASET.MGDPT_DF = 8 + +############## Trainer ############## +_CN.TRAINER = CN() +_CN.TRAINER.WORLD_SIZE = 1 +_CN.TRAINER.CANONICAL_BS = 64 +_CN.TRAINER.CANONICAL_LR = 6e-3 +_CN.TRAINER.SCALING = None # this will be calculated automatically +_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning + +# optimizer +_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] +_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime +_CN.TRAINER.ADAM_DECAY = 0. 
# ADAM: for adam +_CN.TRAINER.ADAMW_DECAY = 0.01 + +# step-based warm-up +_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] +_CN.TRAINER.WARMUP_RATIO = 0. +_CN.TRAINER.WARMUP_STEP = 4800 + +# learning rate scheduler +_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR] +_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] +_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR +_CN.TRAINER.MSLR_GAMMA = 0.5 +_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing +_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval + +# plotting related +_CN.TRAINER.ENABLE_PLOTTING = True +_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test pairs for plotting +_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence'] +_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic' + +# geometric metrics and pose solver +_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) +_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] +_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] +_CN.TRAINER.RANSAC_PIXEL_THR = 0.5 +_CN.TRAINER.RANSAC_CONF = 0.99999 +_CN.TRAINER.RANSAC_MAX_ITERS = 10000 +_CN.TRAINER.USE_MAGSACPP = False + +# data sampler for train_dataloader +_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] +# 'scene_balance' config +_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200 +_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether to sample each scene with replacement or not +_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether to shuffle within the epoch or not +_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data +# 'random' config +_CN.TRAINER.RDM_REPLACEMENT = True +_CN.TRAINER.RDM_NUM_SAMPLES = None + +# gradient clipping +_CN.TRAINER.GRADIENT_CLIPPING = 0.5 + +# reproducibility +# This seed affects the data sampling. With the same seed, the data sampling is guaranteed +# to be the same. When resuming training from a checkpoint, it is better to use a different +# seed; otherwise the sampled data will be exactly the same as before resuming, which will +# result in fewer unique data items being sampled over the entire training. +# Using different seed values might affect the final training result, since not all data items +# are used during training on ScanNet. (60M image pairs are sampled during training out of 230M pairs in total.)
+_CN.TRAINER.SEED = 66 + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _CN.clone() diff --git a/imcui/third_party/TopicFM/src/datasets/aachen.py b/imcui/third_party/TopicFM/src/datasets/aachen.py new file mode 100644 index 0000000000000000000000000000000000000000..ebfeee4dbfbd78770976ec027ceee8ef333a4574 --- /dev/null +++ b/imcui/third_party/TopicFM/src/datasets/aachen.py @@ -0,0 +1,29 @@ +import os +from torch.utils.data import Dataset + +from src.utils.dataset import read_img_gray + + +class AachenDataset(Dataset): + def __init__(self, img_path, match_list_path, img_resize=None, down_factor=16): + self.img_path = img_path + self.img_resize = img_resize + self.down_factor = down_factor + with open(match_list_path, 'r') as f: + self.raw_pairs = f.readlines() + print("number of matching pairs: ", len(self.raw_pairs)) + + def __len__(self): + return len(self.raw_pairs) + + def __getitem__(self, idx): + raw_pair = self.raw_pairs[idx] + image_name0, image_name1 = raw_pair.strip('\n').split(' ') + path_img0 = os.path.join(self.img_path, image_name0) + path_img1 = os.path.join(self.img_path, image_name1) + img0, scale0 = read_img_gray(path_img0, resize=self.img_resize, down_factor=self.down_factor) + img1, scale1 = read_img_gray(path_img1, resize=self.img_resize, down_factor=self.down_factor) + return {"image0": img0, "image1": img1, + "scale0": scale0, "scale1": scale1, + "pair_names": (image_name0, image_name1), + "dataset_name": "AachenDayNight"} \ No newline at end of file diff --git a/imcui/third_party/TopicFM/src/datasets/custom_dataloader.py b/imcui/third_party/TopicFM/src/datasets/custom_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..46d55d4f4d56d2c96cd42b6597834f945a5eb20d --- /dev/null +++ b/imcui/third_party/TopicFM/src/datasets/custom_dataloader.py @@ -0,0 +1,126 @@ +from tqdm import tqdm +from os import path as osp +from torch.utils.data import Dataset, DataLoader, ConcatDataset + +from src.datasets.megadepth import MegaDepthDataset +from src.datasets.scannet import ScanNetDataset +from src.datasets.aachen import AachenDataset +from src.datasets.inloc import InLocDataset + + +class TestDataLoader(DataLoader): + """ + For distributed training, each training process is assgined + only a part of the training scenes to reduce memory overhead. + """ + + def __init__(self, config): + + # 1. data config + self.test_data_source = config.DATASET.TEST_DATA_SOURCE + dataset_name = str(self.test_data_source).lower() + # testing + self.test_data_root = config.DATASET.TEST_DATA_ROOT + self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional) + self.test_npz_root = config.DATASET.TEST_NPZ_ROOT + self.test_list_path = config.DATASET.TEST_LIST_PATH + self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH + + # 2. 
dataset config + # general options + self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score + + # MegaDepth options + if dataset_name == 'megadepth': + self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 800 + self.mgdpt_img_pad = True + self.mgdpt_depth_pad = True + self.mgdpt_df = 8 + self.coarse_scale = 0.125 + if dataset_name == 'scannet': + self.img_resize = config.DATASET.TEST_IMGSIZE + + if (dataset_name == 'megadepth') or (dataset_name == 'scannet'): + test_dataset = self._setup_dataset( + self.test_data_root, + self.test_npz_root, + self.test_list_path, + self.test_intrinsic_path, + mode='test', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.test_pose_root) + elif dataset_name == 'aachen_v1.1': + test_dataset = AachenDataset(self.test_data_root, self.test_list_path, + img_resize=config.DATASET.TEST_IMGSIZE) + elif dataset_name == 'inloc': + test_dataset = InLocDataset(self.test_data_root, self.test_list_path, + img_resize=config.DATASET.TEST_IMGSIZE) + else: + raise ValueError("unknown dataset") + + self.test_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': 4, + 'pin_memory': True + } + + # sampler = Seq(self.test_dataset, shuffle=False) + super(TestDataLoader, self).__init__(test_dataset, **self.test_loader_params) + + def _setup_dataset(self, + data_root, + split_npz_root, + scene_list_path, + intri_path, + mode='train', + min_overlap_score=0., + pose_dir=None): + """ Setup train / val / test set""" + with open(scene_list_path, 'r') as f: + npz_names = [name.split()[0] for name in f.readlines()] + local_npz_names = npz_names + + return self._build_concat_dataset(data_root, local_npz_names, split_npz_root, intri_path, + mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir) + + def _build_concat_dataset( + self, + data_root, + npz_names, + npz_dir, + intrinsic_path, + mode, + min_overlap_score=0., + pose_dir=None + ): + datasets = [] + # augment_fn = self.augment_fn if mode == 'train' else None + data_source = self.test_data_source + if str(data_source).lower() == 'megadepth': + npz_names = [f'{n}.npz' for n in npz_names] + for npz_name in tqdm(npz_names): + # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
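+ # One dataset object is built per scene npz file; all of them are merged into a single ConcatDataset at the end of this method.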
+ npz_path = osp.join(npz_dir, npz_name) + if data_source == 'ScanNet': + datasets.append( + ScanNetDataset(data_root, + npz_path, + intrinsic_path, + mode=mode, img_resize=self.img_resize, + min_overlap_score=min_overlap_score, + pose_dir=pose_dir)) + elif data_source == 'MegaDepth': + datasets.append( + MegaDepthDataset(data_root, + npz_path, + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + coarse_scale=self.coarse_scale)) + else: + raise NotImplementedError() + return ConcatDataset(datasets) diff --git a/imcui/third_party/TopicFM/src/datasets/inloc.py b/imcui/third_party/TopicFM/src/datasets/inloc.py new file mode 100644 index 0000000000000000000000000000000000000000..5421099d11b4dbbea8c09568c493d844d5c6a1b0 --- /dev/null +++ b/imcui/third_party/TopicFM/src/datasets/inloc.py @@ -0,0 +1,29 @@ +import os +from torch.utils.data import Dataset + +from src.utils.dataset import read_img_gray + + +class InLocDataset(Dataset): + def __init__(self, img_path, match_list_path, img_resize=None, down_factor=16): + self.img_path = img_path + self.img_resize = img_resize + self.down_factor = down_factor + with open(match_list_path, 'r') as f: + self.raw_pairs = f.readlines() + print("number of matching pairs: ", len(self.raw_pairs)) + + def __len__(self): + return len(self.raw_pairs) + + def __getitem__(self, idx): + raw_pair = self.raw_pairs[idx] + image_name0, image_name1 = raw_pair.strip('\n').split(' ') + path_img0 = os.path.join(self.img_path, image_name0) + path_img1 = os.path.join(self.img_path, image_name1) + img0, scale0 = read_img_gray(path_img0, resize=self.img_resize, down_factor=self.down_factor) + img1, scale1 = read_img_gray(path_img1, resize=self.img_resize, down_factor=self.down_factor) + return {"image0": img0, "image1": img1, + "scale0": scale0, "scale1": scale1, + "pair_names": (image_name0, image_name1), + "dataset_name": "InLoc"} \ No newline at end of file diff --git a/imcui/third_party/TopicFM/src/datasets/megadepth.py b/imcui/third_party/TopicFM/src/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..e92768e72e373c2a8ebeaf1158f9710fb1bfb5f1 --- /dev/null +++ b/imcui/third_party/TopicFM/src/datasets/megadepth.py @@ -0,0 +1,129 @@ +import os.path as osp +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset +from loguru import logger + +from src.utils.dataset import read_megadepth_gray, read_megadepth_depth + + +class MegaDepthDataset(Dataset): + def __init__(self, + root_dir, + npz_path, + mode='train', + min_overlap_score=0.4, + img_resize=None, + df=None, + img_padding=False, + depth_padding=False, + augment_fn=None, + **kwargs): + """ + Manage one scene(npz_path) of MegaDepth dataset. + + Args: + root_dir (str): megadepth root directory that has `phoenix`. + npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. + mode (str): options are ['train', 'val', 'test'] + min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing. + img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. + This is useful during training with batches and testing with memory intensive algorithms. + df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. 
+ img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. + depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training. + augment_fn (callable, optional): augments images with pre-defined visual effects. + """ + super().__init__() + self.root_dir = root_dir + self.mode = mode + self.scene_id = npz_path.split('.')[0] + + # prepare scene_info and pair_info + if mode == 'test' and min_overlap_score != 0: + logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.") + min_overlap_score = 0 + self.scene_info = np.load(npz_path, allow_pickle=True) + self.pair_infos = self.scene_info['pair_infos'].copy() + del self.scene_info['pair_infos'] + self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] + + # parameters for image resizing, padding and depthmap padding + if mode == 'train': + assert img_resize is not None and img_padding and depth_padding + self.img_resize = img_resize + if mode == 'val': + self.img_resize = 864 + self.df = df + self.img_padding = img_padding + self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. + + # for training LoFTR + self.augment_fn = augment_fn if mode == 'train' else None + self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) + + def __len__(self): + return len(self.pair_infos) + + def __getitem__(self, idx): + (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx] + + # read grayscale image and mask. (1, h, w) and (h, w) + img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) + img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) + + # TODO: Support augmentation & handle seeds for each worker correctly. + image0, mask0, scale0 = read_megadepth_gray( + img_name0, self.img_resize, self.df, self.img_padding, None) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + image1, mask1, scale1 = read_megadepth_gray( + img_name1, self.img_resize, self.df, self.img_padding, None) + # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + + # read depth. 
shape: (h, w) + if self.mode in ['train', 'val']: + depth0 = read_megadepth_depth( + osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) + depth1 = read_megadepth_depth( + osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) + else: + depth0 = depth1 = torch.tensor([]) + + # read intrinsics of original size + K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) + K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T0 = self.scene_info['poses'][idx0] + T1 = self.scene_info['poses'][idx1] + T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) + T_1to0 = T_0to1.inverse() + + data = { + 'image0': image0, # (1, h, w) + 'depth0': depth0, # (h, w) + 'image1': image1, + 'depth1': depth1, + 'T_0to1': T_0to1, # (4, 4) + 'T_1to0': T_1to0, + 'K0': K_0, # (3, 3) + 'K1': K_1, + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'MegaDepth', + 'scene_id': self.scene_id, + 'pair_id': idx, + 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), + } + + # for LoFTR training + if mask0 is not None: # img_padding is True + if self.coarse_scale: + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.coarse_scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/TopicFM/src/datasets/sampler.py b/imcui/third_party/TopicFM/src/datasets/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..81b6f435645632a013476f9a665a0861ab7fcb61 --- /dev/null +++ b/imcui/third_party/TopicFM/src/datasets/sampler.py @@ -0,0 +1,77 @@ +import torch +from torch.utils.data import Sampler, ConcatDataset + + +class RandomConcatSampler(Sampler): + """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset + in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. + However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase. + + For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not. + Args: + shuffle (bool): shuffle the random sampled indices across all sub-datsets. + repeat (int): repeatedly use the sampled indices multiple times for training. + [arXiv:1902.05509, arXiv:1901.09335] + NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples) + NOTE: This sampler behaves differently with DistributedSampler. + It assume the dataset is splitted across ranks instead of replicated. + TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs. 
+ ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 + """ + def __init__(self, + data_source: ConcatDataset, + n_samples_per_subset: int, + subset_replacement: bool=True, + shuffle: bool=True, + repeat: int=1, + seed: int=None): + if not isinstance(data_source, ConcatDataset): + raise TypeError("data_source should be torch.utils.data.ConcatDataset") + + self.data_source = data_source + self.n_subset = len(self.data_source.datasets) + self.n_samples_per_subset = n_samples_per_subset + self.n_samples = self.n_subset * self.n_samples_per_subset * repeat + self.subset_replacement = subset_replacement + self.repeat = repeat + self.shuffle = shuffle + self.generator = torch.manual_seed(seed) + assert self.repeat >= 1 + + def __len__(self): + return self.n_samples + + def __iter__(self): + indices = [] + # sample from each sub-dataset + for d_idx in range(self.n_subset): + low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1] + high = self.data_source.cumulative_sizes[d_idx] + if self.subset_replacement: + rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ), + generator=self.generator, dtype=torch.int64) + else: # sample without replacement + len_subset = len(self.data_source.datasets[d_idx]) + rand_tensor = torch.randperm(len_subset, generator=self.generator) + low + if len_subset >= self.n_samples_per_subset: + rand_tensor = rand_tensor[:self.n_samples_per_subset] + else: # padding with replacement + rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ), + generator=self.generator, dtype=torch.int64) + rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) + indices.append(rand_tensor) + indices = torch.cat(indices) + if self.shuffle: # shuffle the sampled dataset (from multiple subsets) + rand_tensor = torch.randperm(len(indices), generator=self.generator) + indices = indices[rand_tensor] + + # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling) + if self.repeat > 1: + repeat_indices = [indices.clone() for _ in range(self.repeat - 1)] + if self.shuffle: + _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)] + repeat_indices = map(_choice, repeat_indices) + indices = torch.cat([indices, *repeat_indices], 0) + + assert indices.shape[0] == self.n_samples + return iter(indices.tolist()) diff --git a/imcui/third_party/TopicFM/src/datasets/scannet.py b/imcui/third_party/TopicFM/src/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5dab7b150a3c6f54eb07b0459bbf3e9ba58fbf --- /dev/null +++ b/imcui/third_party/TopicFM/src/datasets/scannet.py @@ -0,0 +1,115 @@ +from os import path as osp +from typing import Dict +from unicodedata import name + +import numpy as np +import torch +import torch.utils as utils +from numpy.linalg import inv +from src.utils.dataset import ( + read_scannet_gray, + read_scannet_depth, + read_scannet_pose, + read_scannet_intrinsic +) + + +class ScanNetDataset(utils.data.Dataset): + def __init__(self, + root_dir, + npz_path, + intrinsic_path, + mode='train', + min_overlap_score=0.4, + augment_fn=None, + pose_dir=None, + **kwargs): + """Manage one scene of ScanNet Dataset. + Args: + root_dir (str): ScanNet root directory that contains scene folders. + npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. + intrinsic_path (str): path to depth-camera intrinsic file. 
+ mode (str): options are ['train', 'val', 'test']. + augment_fn (callable, optional): augments images with pre-defined visual effects. + pose_dir (str): ScanNet root directory that contains all poses. + (we use a separate (optional) pose_dir since we store images and poses separately.) + """ + super().__init__() + self.root_dir = root_dir + self.pose_dir = pose_dir if pose_dir is not None else root_dir + self.mode = mode + self.img_resize = (640, 480) if 'img_resize' not in kwargs else kwargs['img_resize'] + + # prepare data_names, intrinsics and extrinsics(T) + with np.load(npz_path) as data: + self.data_names = data['name'] + if 'score' in data.keys() and mode not in ['val' or 'test']: + kept_mask = data['score'] > min_overlap_score + self.data_names = self.data_names[kept_mask] + self.intrinsics = dict(np.load(intrinsic_path)) + + # for training LoFTR + self.augment_fn = augment_fn if mode == 'train' else None + + def __len__(self): + return len(self.data_names) + + def _read_abs_pose(self, scene_name, name): + pth = osp.join(self.pose_dir, + scene_name, + 'pose', f'{name}.txt') + return read_scannet_pose(pth) + + def _compute_rel_pose(self, scene_name, name0, name1): + pose0 = self._read_abs_pose(scene_name, name0) + pose1 = self._read_abs_pose(scene_name, name1) + + return np.matmul(pose1, inv(pose0)) # (4, 4) + + def __getitem__(self, idx): + data_name = self.data_names[idx] + scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name + scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + + # read the grayscale image which will be resized to (1, 480, 640) + img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg') + img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg') + + # TODO: Support augmentation & handle seeds for each worker correctly. 
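The layout of the index `.npz` consumed by `ScanNetDataset.__init__` is only implied by how `data['name']` and `data['score']` are used; the toy sketch below reconstructs that layout with made-up values.

```python
# Toy sketch of the per-split index npz that ScanNetDataset reads (values are made up).
import numpy as np

name = np.array([
    # (scene_id, sub_id, frame_0, frame_1) -> scene{scene_id:04d}_{sub_id:02d}/color/{frame}.jpg
    [0, 0, 15, 45],
    [0, 0, 15, 90],
], dtype=np.int64)
score = np.array([0.62, 0.41])   # pairwise overlap; low-overlap rows are filtered out for training
np.savez("scannet_index.npz", name=name, score=score)
```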
+ image0 = read_scannet_gray(img_name0, resize=self.img_resize, augment_fn=None) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + image1 = read_scannet_gray(img_name1, resize=self.img_resize, augment_fn=None) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + + # read the depthmap which is stored as (480, 640) + if self.mode in ['train', 'val']: + depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png')) + depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png')) + else: + depth0 = depth1 = torch.tensor([]) + + # read the intrinsic of depthmap + K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), + dtype=torch.float32) + T_1to0 = T_0to1.inverse() + + data = { + 'image0': image0, # (1, h, w) + 'depth0': depth0, # (h, w) + 'image1': image1, + 'depth1': depth1, + 'T_0to1': T_0to1, # (4, 4) + 'T_1to0': T_1to0, + 'K0': K_0, # (3, 3) + 'K1': K_1, + 'dataset_name': 'ScanNet', + 'scene_id': scene_name, + 'pair_id': idx, + 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'), + osp.join(scene_name, 'color', f'{stem_name_1}.jpg')) + } + + return data diff --git a/imcui/third_party/TopicFM/src/lightning_trainer/data.py b/imcui/third_party/TopicFM/src/lightning_trainer/data.py new file mode 100644 index 0000000000000000000000000000000000000000..8deb713b6300e0e9e8a261e2230031174b452862 --- /dev/null +++ b/imcui/third_party/TopicFM/src/lightning_trainer/data.py @@ -0,0 +1,320 @@ +import os +import math +from collections import abc +from loguru import logger +from torch.utils.data.dataset import Dataset +from tqdm import tqdm +from os import path as osp +from pathlib import Path +from joblib import Parallel, delayed + +import pytorch_lightning as pl +from torch import distributed as dist +from torch.utils.data import ( + Dataset, + DataLoader, + ConcatDataset, + DistributedSampler, + RandomSampler, + dataloader +) + +from src.utils.augment import build_augmentor +from src.utils.dataloader import get_local_split +from src.utils.misc import tqdm_joblib +from src.utils import comm +from src.datasets.megadepth import MegaDepthDataset +from src.datasets.scannet import ScanNetDataset +from src.datasets.sampler import RandomConcatSampler + + +class MultiSceneDataModule(pl.LightningDataModule): + """ + For distributed training, each training process is assgined + only a part of the training scenes to reduce memory overhead. + """ + def __init__(self, args, config): + super().__init__() + + # 1. 
data config + # Train and Val should from the same data source + self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE + self.test_data_source = config.DATASET.TEST_DATA_SOURCE + # training and validating + self.train_data_root = config.DATASET.TRAIN_DATA_ROOT + self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional) + self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT + self.train_list_path = config.DATASET.TRAIN_LIST_PATH + self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH + self.val_data_root = config.DATASET.VAL_DATA_ROOT + self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional) + self.val_npz_root = config.DATASET.VAL_NPZ_ROOT + self.val_list_path = config.DATASET.VAL_LIST_PATH + self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH + # testing + self.test_data_root = config.DATASET.TEST_DATA_ROOT + self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional) + self.test_npz_root = config.DATASET.TEST_NPZ_ROOT + self.test_list_path = config.DATASET.TEST_LIST_PATH + self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH + + # 2. dataset config + # general options + self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score + self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN + self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile'] + + # MegaDepth options + self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840 + self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True + self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True + self.mgdpt_df = config.DATASET.MGDPT_DF # 8 + self.coarse_scale = 1 / config.MODEL.RESOLUTION[0] # 0.125. for training loftr. + + # 3.loader parameters + self.train_loader_params = { + 'batch_size': args.batch_size, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + self.val_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + self.test_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': args.num_workers, + 'pin_memory': True + } + + # 4. sampler + self.data_sampler = config.TRAINER.DATA_SAMPLER + self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET + self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT + self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE + self.repeat = config.TRAINER.SB_REPEAT + + # (optional) RandomSampler for debugging + + # misc configurations + self.parallel_load_data = getattr(args, 'parallel_load_data', False) + self.seed = config.TRAINER.SEED # 66 + + def setup(self, stage=None): + """ + Setup train / val / test dataset. This method will be called by PL automatically. + Args: + stage (str): 'fit' in training phase, and 'test' in testing phase. 
+ """ + + assert stage in ['fit', 'test'], "stage must be either fit or test" + + try: + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + logger.info(f"[rank:{self.rank}] world_size: {self.world_size}") + except AssertionError as ae: + self.world_size = 1 + self.rank = 0 + logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)") + + if stage == 'fit': + self.train_dataset = self._setup_dataset( + self.train_data_root, + self.train_npz_root, + self.train_list_path, + self.train_intrinsic_path, + mode='train', + min_overlap_score=self.min_overlap_score_train, + pose_dir=self.train_pose_root) + # setup multiple (optional) validation subsets + if isinstance(self.val_list_path, (list, tuple)): + self.val_dataset = [] + if not isinstance(self.val_npz_root, (list, tuple)): + self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))] + for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root): + self.val_dataset.append(self._setup_dataset( + self.val_data_root, + npz_root, + npz_list, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root)) + else: + self.val_dataset = self._setup_dataset( + self.val_data_root, + self.val_npz_root, + self.val_list_path, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root) + logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!') + else: # stage == 'test + self.test_dataset = self._setup_dataset( + self.test_data_root, + self.test_npz_root, + self.test_list_path, + self.test_intrinsic_path, + mode='test', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.test_pose_root) + logger.info(f'[rank:{self.rank}]: Test Dataset loaded!') + + def _setup_dataset(self, + data_root, + split_npz_root, + scene_list_path, + intri_path, + mode='train', + min_overlap_score=0., + pose_dir=None): + """ Setup train / val / test set""" + with open(scene_list_path, 'r') as f: + npz_names = [name.split()[0] for name in f.readlines()] + + if mode == 'train': + local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed) + else: + local_npz_names = npz_names + logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.') + + dataset_builder = self._build_concat_dataset_parallel \ + if self.parallel_load_data \ + else self._build_concat_dataset + return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path, + mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir) + + def _build_concat_dataset( + self, + data_root, + npz_names, + npz_dir, + intrinsic_path, + mode, + min_overlap_score=0., + pose_dir=None + ): + datasets = [] + augment_fn = self.augment_fn if mode == 'train' else None + data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source + if str(data_source).lower() == 'megadepth': + npz_names = [f'{n}.npz' for n in npz_names] + for npz_name in tqdm(npz_names, + desc=f'[rank:{self.rank}] loading {mode} datasets', + disable=int(self.rank) != 0): + # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time. 
+ npz_path = osp.join(npz_dir, npz_name) + if data_source == 'ScanNet': + datasets.append( + ScanNetDataset(data_root, + npz_path, + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir)) + elif data_source == 'MegaDepth': + datasets.append( + MegaDepthDataset(data_root, + npz_path, + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale)) + else: + raise NotImplementedError() + return ConcatDataset(datasets) + + def _build_concat_dataset_parallel( + self, + data_root, + npz_names, + npz_dir, + intrinsic_path, + mode, + min_overlap_score=0., + pose_dir=None, + ): + augment_fn = self.augment_fn if mode == 'train' else None + data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source + if str(data_source).lower() == 'megadepth': + npz_names = [f'{n}.npz' for n in npz_names] + with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets', + total=len(npz_names), disable=int(self.rank) != 0)): + if data_source == 'ScanNet': + datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( + delayed(lambda x: _build_dataset( + ScanNetDataset, + data_root, + osp.join(npz_dir, x), + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir))(name) + for name in npz_names) + elif data_source == 'MegaDepth': + # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers. + raise NotImplementedError() + datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( + delayed(lambda x: _build_dataset( + MegaDepthDataset, + data_root, + osp.join(npz_dir, x), + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale))(name) + for name in npz_names) + else: + raise ValueError(f'Unknown dataset: {data_source}') + return ConcatDataset(datasets) + + def train_dataloader(self): + """ Build training dataloader for ScanNet / MegaDepth. """ + assert self.data_sampler in ['scene_balance'] + logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).') + if self.data_sampler == 'scene_balance': + sampler = RandomConcatSampler(self.train_dataset, + self.n_samples_per_subset, + self.subset_replacement, + self.shuffle, self.repeat, self.seed) + else: + sampler = None + dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params) + return dataloader + + def val_dataloader(self): + """ Build validation dataloader for ScanNet / MegaDepth. 
""" + logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.') + if not isinstance(self.val_dataset, abc.Sequence): + sampler = DistributedSampler(self.val_dataset, shuffle=False) + return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params) + else: + dataloaders = [] + for dataset in self.val_dataset: + sampler = DistributedSampler(dataset, shuffle=False) + dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params)) + return dataloaders + + def test_dataloader(self, *args, **kwargs): + logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.') + sampler = DistributedSampler(self.test_dataset, shuffle=False) + return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params) + + +def _build_dataset(dataset: Dataset, *args, **kwargs): + return dataset(*args, **kwargs) diff --git a/imcui/third_party/TopicFM/src/lightning_trainer/trainer.py b/imcui/third_party/TopicFM/src/lightning_trainer/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..acf51f66130be66b7d3294ca5c081a2df3856d96 --- /dev/null +++ b/imcui/third_party/TopicFM/src/lightning_trainer/trainer.py @@ -0,0 +1,244 @@ + +from collections import defaultdict +import pprint +from loguru import logger +from pathlib import Path + +import torch +import numpy as np +import pytorch_lightning as pl +from matplotlib import pyplot as plt + +from src.models import TopicFM +from src.models.utils.supervision import compute_supervision_coarse, compute_supervision_fine +from src.losses.loss import TopicFMLoss +from src.optimizers import build_optimizer, build_scheduler +from src.utils.metrics import ( + compute_symmetrical_epipolar_errors, + compute_pose_errors, + aggregate_metrics +) +from src.utils.plotting import make_matching_figures +from src.utils.comm import gather, all_gather +from src.utils.misc import lower_config, flattenList +from src.utils.profiler import PassThroughProfiler + + +class PL_Trainer(pl.LightningModule): + def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None): + """ + TODO: + - use the new version of PL logging API. 
+ """ + super().__init__() + # Misc + self.config = config # full config + _config = lower_config(self.config) + self.model_cfg = lower_config(_config['model']) + self.profiler = profiler or PassThroughProfiler() + self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1) + + # Matcher: TopicFM + self.matcher = TopicFM(config=_config['model']) + self.loss = TopicFMLoss(_config) + + # Pretrained weights + if pretrained_ckpt: + state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict'] + self.matcher.load_state_dict(state_dict, strict=True) + logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") + + # Testing + self.dump_dir = dump_dir + + def configure_optimizers(self): + # FIXME: The scheduler did not work properly when `--resume_from_checkpoint` + optimizer = build_optimizer(self, self.config) + scheduler = build_scheduler(self.config, optimizer) + return [optimizer], [scheduler] + + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + # learning rate warm up + warmup_step = self.config.TRAINER.WARMUP_STEP + if self.trainer.global_step < warmup_step: + if self.config.TRAINER.WARMUP_TYPE == 'linear': + base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR + lr = base_lr + \ + (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \ + abs(self.config.TRAINER.TRUE_LR - base_lr) + for pg in optimizer.param_groups: + pg['lr'] = lr + elif self.config.TRAINER.WARMUP_TYPE == 'constant': + pass + else: + raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}') + + # update params + optimizer.step(closure=optimizer_closure) + optimizer.zero_grad() + + def _trainval_inference(self, batch): + with self.profiler.profile("Compute coarse supervision"): + compute_supervision_coarse(batch, self.config) + + with self.profiler.profile("TopicFM"): + self.matcher(batch) + + with self.profiler.profile("Compute fine supervision"): + compute_supervision_fine(batch, self.config) + + with self.profiler.profile("Compute losses"): + self.loss(batch) + + def _compute_metrics(self, batch): + with self.profiler.profile("Copmute metrics"): + compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match + compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair + + rel_pair_names = list(zip(*batch['pair_names'])) + bs = batch['image0'].size(0) + metrics = { + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)], + 'R_errs': batch['R_errs'], + 't_errs': batch['t_errs'], + 'inliers': batch['inliers']} + ret_dict = {'metrics': metrics} + return ret_dict, rel_pair_names + + def training_step(self, batch, batch_idx): + self._trainval_inference(batch) + + # logging + if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0: + # scalars + for k, v in batch['loss_scalars'].items(): + self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step) + + # figures + if self.config.TRAINER.ENABLE_PLOTTING: + compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match + figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE) + for k, v in figures.items(): + self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step) + + return {'loss': 
batch['loss']} + + def training_epoch_end(self, outputs): + avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + if self.trainer.global_rank == 0: + self.logger.experiment.add_scalar( + 'train/avg_loss_on_epoch', avg_loss, + global_step=self.current_epoch) + + def validation_step(self, batch, batch_idx): + self._trainval_inference(batch) + + ret_dict, _ = self._compute_metrics(batch) + + val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1) + figures = {self.config.TRAINER.PLOT_MODE: []} + if batch_idx % val_plot_interval == 0: + figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE) + + return { + **ret_dict, + 'loss_scalars': batch['loss_scalars'], + 'figures': figures, + } + + def validation_epoch_end(self, outputs): + # handle multiple validation sets + multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs + multi_val_metrics = defaultdict(list) + + for valset_idx, outputs in enumerate(multi_outputs): + # since pl performs sanity_check at the very begining of the training + cur_epoch = self.trainer.current_epoch + if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check: + cur_epoch = -1 + + # 1. loss_scalars: dict of list, on cpu + _loss_scalars = [o['loss_scalars'] for o in outputs] + loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]} + + # 2. val metrics: dict of list, numpy + _metrics = [o['metrics'] for o in outputs] + metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 + val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + for thr in [5, 10, 20]: + multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}']) + + # 3. 
figures + _figures = [o['figures'] for o in outputs] + figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]} + + # tensorboard records only on rank 0 + if self.trainer.global_rank == 0: + for k, v in loss_scalars.items(): + mean_v = torch.stack(v).mean() + self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch) + + for k, v in val_metrics_4tb.items(): + self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch) + + for k, v in figures.items(): + if self.trainer.global_rank == 0: + for plot_idx, fig in enumerate(v): + self.logger.experiment.add_figure( + f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True) + plt.close('all') + + for thr in [5, 10, 20]: + # log on all ranks for ModelCheckpoint callback to work properly + self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this + + def test_step(self, batch, batch_idx): + with self.profiler.profile("TopicFM"): + self.matcher(batch) + + ret_dict, rel_pair_names = self._compute_metrics(batch) + + with self.profiler.profile("dump_results"): + if self.dump_dir is not None: + # dump results for further analysis + keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'} + pair_names = list(zip(*batch['pair_names'])) + bs = batch['image0'].shape[0] + dumps = [] + for b_id in range(bs): + item = {} + mask = batch['m_bids'] == b_id + item['pair_names'] = pair_names[b_id] + item['identifier'] = '#'.join(rel_pair_names[b_id]) + for key in keys_to_save: + item[key] = batch[key][mask].cpu().numpy() + for key in ['R_errs', 't_errs', 'inliers']: + item[key] = batch[key][b_id] + dumps.append(item) + ret_dict['dumps'] = dumps + + return ret_dict + + def test_epoch_end(self, outputs): + # metrics: dict of list, numpy + _metrics = [o['metrics'] for o in outputs] + metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + + # [{key: [{...}, *#bs]}, *#batch] + if self.dump_dir is not None: + Path(self.dump_dir).mkdir(parents=True, exist_ok=True) + _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch] + dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch] + logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}') + + if self.trainer.global_rank == 0: + print(self.profiler.summary()) + val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + logger.info('\n' + pprint.pformat(val_metrics_4tb)) + if self.dump_dir is not None: + np.save(Path(self.dump_dir) / 'TopicFM_pred_eval', dumps) diff --git a/imcui/third_party/TopicFM/src/losses/loss.py b/imcui/third_party/TopicFM/src/losses/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4be58498579c9fe649ed0ce2d42f230e59cef581 --- /dev/null +++ b/imcui/third_party/TopicFM/src/losses/loss.py @@ -0,0 +1,182 @@ +from loguru import logger + +import torch +import torch.nn as nn + + +def sample_non_matches(pos_mask, match_ids=None, sampling_ratio=10): + # assert (pos_mask.shape == mask.shape) # [B, H*W, H*W] + if match_ids is not None: + HW = pos_mask.shape[1] + b_ids, i_ids, j_ids = match_ids + if len(b_ids) == 0: + return ~pos_mask + + neg_mask = torch.zeros_like(pos_mask) + probs = torch.ones((HW - 1)//3, device=pos_mask.device) + for _ in range(sampling_ratio): + d = torch.multinomial(probs, len(j_ids), replacement=True) + sampled_j_ids = (j_ids + d*3 + 1) % HW + neg_mask[b_ids, i_ids, 
sampled_j_ids] = True + # neg_mask = neg_matrix == 1 + else: + neg_mask = ~pos_mask + + return neg_mask + + +class TopicFMLoss(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config # config under the global namespace + self.loss_config = config['model']['loss'] + self.match_type = self.config['model']['match_coarse']['match_type'] + + # coarse-level + self.correct_thr = self.loss_config['fine_correct_thr'] + self.c_pos_w = self.loss_config['pos_weight'] + self.c_neg_w = self.loss_config['neg_weight'] + # fine-level + self.fine_type = self.loss_config['fine_type'] + + def compute_coarse_loss(self, conf, topic_mat, conf_gt, match_ids=None, weight=None): + """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt. + Args: + conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1) + conf_gt (torch.Tensor): (N, HW0, HW1) + weight (torch.Tensor): (N, HW0, HW1) + """ + pos_mask = conf_gt == 1 + neg_mask = sample_non_matches(pos_mask, match_ids=match_ids) + c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w + # corner case: no gt coarse-level match at all + if not pos_mask.any(): # assign a wrong gt + pos_mask[0, 0, 0] = True + if weight is not None: + weight[0, 0, 0] = 0. + c_pos_w = 0. + if not neg_mask.any(): + neg_mask[0, 0, 0] = True + if weight is not None: + weight[0, 0, 0] = 0. + c_neg_w = 0. + + conf = torch.clamp(conf, 1e-6, 1 - 1e-6) + alpha = self.loss_config['focal_alpha'] + + loss = 0.0 + if isinstance(topic_mat, torch.Tensor): + pos_topic = topic_mat[pos_mask] + loss_pos_topic = - alpha * (pos_topic + 1e-6).log() + neg_topic = topic_mat[neg_mask] + loss_neg_topic = - alpha * (1 - neg_topic + 1e-6).log() + if weight is not None: + loss_pos_topic = loss_pos_topic * weight[pos_mask] + loss_neg_topic = loss_neg_topic * weight[neg_mask] + loss = loss_pos_topic.mean() + loss_neg_topic.mean() + + pos_conf = conf[pos_mask] + loss_pos = - alpha * pos_conf.log() + # handle loss weights + if weight is not None: + # Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out, + # but only through manually setting corresponding regions in sim_matrix to '-inf'. + loss_pos = loss_pos * weight[pos_mask] + + loss = loss + c_pos_w * loss_pos.mean() + + return loss + + def compute_fine_loss(self, expec_f, expec_f_gt): + if self.fine_type == 'l2_with_std': + return self._compute_fine_loss_l2_std(expec_f, expec_f_gt) + elif self.fine_type == 'l2': + return self._compute_fine_loss_l2(expec_f, expec_f_gt) + else: + raise NotImplementedError() + + def _compute_fine_loss_l2(self, expec_f, expec_f_gt): + """ + Args: + expec_f (torch.Tensor): [M, 2] + expec_f_gt (torch.Tensor): [M, 2] + """ + correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr + if correct_mask.sum() == 0: + if self.training: # this seldomly happen when training, since we pad prediction with gt + logger.warning("assign a false supervision to avoid ddp deadlock") + correct_mask[0] = True + else: + return None + offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask]) ** 2).sum(-1) + return offset_l2.mean() + + def _compute_fine_loss_l2_std(self, expec_f, expec_f_gt): + """ + Args: + expec_f (torch.Tensor): [M, 3] + expec_f_gt (torch.Tensor): [M, 2] + """ + # correct_mask tells you which pair to compute fine-loss + correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr + + # use std as weight that measures uncertainty + std = expec_f[:, 2] + inverse_std = 1. 
/ torch.clamp(std, min=1e-10) + weight = (inverse_std / torch.mean(inverse_std)).detach() # avoid minizing loss through increase std + + # corner case: no correct coarse match found + if not correct_mask.any(): + if self.training: # this seldomly happen during training, since we pad prediction with gt + # sometimes there is not coarse-level gt at all. + logger.warning("assign a false supervision to avoid ddp deadlock") + correct_mask[0] = True + weight[0] = 0. + else: + return None + + # l2 loss with std + offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum(-1) + loss = (offset_l2 * weight[correct_mask]).mean() + + return loss + + @torch.no_grad() + def compute_c_weight(self, data): + """ compute element-wise weights for computing coarse-level loss. """ + if 'mask0' in data: + c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float() + else: + c_weight = None + return c_weight + + def forward(self, data): + """ + Update: + data (dict): update{ + 'loss': [1] the reduced loss across a batch, + 'loss_scalars' (dict): loss scalars for tensorboard_record + } + """ + loss_scalars = {} + # 0. compute element-wise loss weight + c_weight = self.compute_c_weight(data) + + # 1. coarse-level loss + loss_c = self.compute_coarse_loss(data['conf_matrix'], data['topic_matrix'], + data['conf_matrix_gt'], match_ids=(data['spv_b_ids'], data['spv_i_ids'], data['spv_j_ids']), + weight=c_weight) + loss = loss_c * self.loss_config['coarse_weight'] + loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()}) + + # 2. fine-level loss + loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt']) + if loss_f is not None: + loss += loss_f * self.loss_config['fine_weight'] + loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()}) + else: + assert self.training is False + loss_scalars.update({'loss_f': torch.tensor(1.)}) # 1 is the upper bound + + loss_scalars.update({'loss': loss.clone().detach().cpu()}) + data.update({"loss": loss, "loss_scalars": loss_scalars}) diff --git a/imcui/third_party/TopicFM/src/models/__init__.py b/imcui/third_party/TopicFM/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9abdbdaebbf6c91a6fdc24e23d62c73003b204bf --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/__init__.py @@ -0,0 +1 @@ +from .topic_fm import TopicFM diff --git a/imcui/third_party/TopicFM/src/models/backbone/__init__.py b/imcui/third_party/TopicFM/src/models/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53f98db4e910b46716bed7cfc6ebbf8c8bfad399 --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/backbone/__init__.py @@ -0,0 +1,5 @@ +from .fpn import FPN + + +def build_backbone(config): + return FPN(config['fpn']) diff --git a/imcui/third_party/TopicFM/src/models/backbone/fpn.py b/imcui/third_party/TopicFM/src/models/backbone/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..93cc475f57317f9dbb8132cdfe0297391972f9e2 --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/backbone/fpn.py @@ -0,0 +1,109 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 
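The two convolution helpers above feed the `FPN` backbone defined just below; a minimal shape check with placeholder channel dimensions (the 1/8 and 1/2 output scales follow the class docstring):

```python
# Illustrative only: FPN maps a grayscale image to coarse (1/8) and fine (1/2) feature maps.
import torch
from src.models.backbone.fpn import FPN

cfg = {"initial_dim": 128, "block_dims": [128, 196, 256, 512]}   # placeholder channel dims
net = FPN(cfg).eval()

x = torch.randn(1, 1, 480, 640)            # (N, 1, H, W) grayscale input
with torch.no_grad():
    feat_coarse, feat_fine = net(x)

print(feat_coarse.shape)   # torch.Size([1, 256, 60, 80])   -> H/8 x W/8, block_dims[2] channels
print(feat_fine.shape)     # torch.Size([1, 128, 240, 320]) -> H/2 x W/2, block_dims[0] channels
```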
+ + +class ConvBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1, bn=True): + super().__init__() + self.conv = conv3x3(in_planes, planes, stride) + self.bn = nn.BatchNorm2d(planes) if bn is True else None + self.act = nn.GELU() + + def forward(self, x): + y = self.conv(x) + if self.bn: + y = self.bn(y) #F.layer_norm(y, y.shape[1:]) + y = self.act(y) + return y + + +class FPN(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = ConvBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 + + # 3. FPN upsample + self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) + self.layer3_outconv2 = nn.Sequential( + ConvBlock(block_dims[3], block_dims[2]), + conv3x3(block_dims[2], block_dims[2]), + ) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + ConvBlock(block_dims[2], block_dims[1]), + conv3x3(block_dims[1], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + ConvBlock(block_dims[1], block_dims[0]), + conv3x3(block_dims[0], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + x4 = self.layer4(x3) # 1/16 + + # FPN + x4_out_2x = F.interpolate(x4, scale_factor=2., mode='bilinear', align_corners=True) + x3_out = self.layer3_outconv(x3) + x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] diff --git a/imcui/third_party/TopicFM/src/models/modules/__init__.py b/imcui/third_party/TopicFM/src/models/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59cf36da37104dcf080e1b2c119c8123fa8d147f --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/modules/__init__.py @@ -0,0 +1,2 @@ +from .transformer import LocalFeatureTransformer, TopicFormer +from .fine_preprocess import FinePreprocess diff --git a/imcui/third_party/TopicFM/src/models/modules/fine_preprocess.py 
b/imcui/third_party/TopicFM/src/models/modules/fine_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..4c8d264c1895be8f4e124fc3982d4e0d3b876af3 --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/modules/fine_preprocess.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange, repeat + + +class FinePreprocess(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.cat_c_feat = config['fine_concat_coarse_feat'] + self.W = self.config['fine_window_size'] + + d_model_c = self.config['coarse']['d_model'] + d_model_f = self.config['fine']['d_model'] + self.d_model_f = d_model_f + if self.cat_c_feat: + self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) + self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") + + def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): + W = self.W + stride = data['hw0_f'][0] // data['hw0_c'][0] + + data.update({'W': W}) + if data['b_ids'].shape[0] == 0: + feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + return feat0, feat1 + + # 1. unfold(crop) all local windows + feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + + # 2. 
select only the predicted matches + feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + + # option: use coarse-level feature as context: concat and linear + if self.cat_c_feat: + feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], + feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] + feat_cf_win = self.merge_feat(torch.cat([ + torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] + repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] + ], -1)) + feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) + + return feat_f0_unfold, feat_f1_unfold diff --git a/imcui/third_party/TopicFM/src/models/modules/linear_attention.py b/imcui/third_party/TopicFM/src/models/modules/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..af6cd825033e98b7be15cc694ce28110ef84cc93 --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/modules/linear_attention.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]).bool(), -1e9) + + # Compute the attention and the weighted average + softmax_temp = 1. 
/ queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() diff --git a/imcui/third_party/TopicFM/src/models/modules/transformer.py b/imcui/third_party/TopicFM/src/models/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..27ff8f6554844b1e14a7094fcbad40876f766db8 --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/modules/transformer.py @@ -0,0 +1,232 @@ +from loguru import logger +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .linear_attention import LinearAttention, FullAttention + + +class LoFTREncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + attention='linear'): + super(LoFTREncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.GELU(), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.shape[0] + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + + +class TopicFormer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(TopicFormer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + + self.topic_transformers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2*config['n_topic_transformers'])]) if config['n_samples'] > 0 else None #nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)]) + self.n_iter_topic_transformer = config['n_topic_transformers'] + + self.seed_tokens = nn.Parameter(torch.randn(config['n_topics'], config['d_model'])) + self.register_parameter('seed_tokens', self.seed_tokens) + self.n_samples = config['n_samples'] + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def 
sample_topic(self, prob_topics, topics, L): + """ + Args: + topics (torch.Tensor): [N, L+S, K] + """ + prob_topics0, prob_topics1 = prob_topics[:, :L], prob_topics[:, L:] + topics0, topics1 = topics[:, :L], topics[:, L:] + + theta0 = F.normalize(prob_topics0.sum(dim=1), p=1, dim=-1) # [N, K] + theta1 = F.normalize(prob_topics1.sum(dim=1), p=1, dim=-1) + theta = F.normalize(theta0 * theta1, p=1, dim=-1) + if self.n_samples == 0: + return None + if self.training: + sampled_inds = torch.multinomial(theta, self.n_samples) + sampled_values = torch.gather(theta, dim=-1, index=sampled_inds) + else: + sampled_values, sampled_inds = torch.topk(theta, self.n_samples, dim=-1) + sampled_topics0 = torch.gather(topics0, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics0.shape[1], 1)) + sampled_topics1 = torch.gather(topics1, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics1.shape[1], 1)) + return sampled_topics0, sampled_topics1 + + def reduce_feat(self, feat, topick, N, C): + len_topic = topick.sum(dim=-1).int() + max_len = len_topic.max().item() + selected_ids = topick.bool() + resized_feat = torch.zeros((N, max_len, C), dtype=torch.float32, device=feat.device) + new_mask = torch.zeros_like(resized_feat[..., 0]).bool() + for i in range(N): + new_mask[i, :len_topic[i]] = True + resized_feat[new_mask, :] = feat[selected_ids, :] + return resized_feat, new_mask, selected_ids + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal" + N, L, S, C, K = feat0.shape[0], feat0.shape[1], feat1.shape[1], feat0.shape[2], self.config['n_topics'] + + seeds = self.seed_tokens.unsqueeze(0).repeat(N, 1, 1) + + feat = torch.cat((feat0, feat1), dim=1) + if mask0 is not None: + mask = torch.cat((mask0, mask1), dim=-1) + else: + mask = None + + for layer, name in zip(self.layers, self.layer_names): + if name == 'seed': + # seeds = layer(seeds, feat0, None, mask0) + # seeds = layer(seeds, feat1, None, mask1) + seeds = layer(seeds, feat, None, mask) + elif name == 'feat': + feat0 = layer(feat0, seeds, mask0, None) + feat1 = layer(feat1, seeds, mask1, None) + + dmatrix = torch.einsum("nmd,nkd->nmk", feat, seeds) + prob_topics = F.softmax(dmatrix, dim=-1) + + feat_topics = torch.zeros_like(dmatrix).scatter_(-1, torch.argmax(dmatrix, dim=-1, keepdim=True), 1.0) + + if mask is not None: + feat_topics = feat_topics * mask.unsqueeze(-1) + prob_topics = prob_topics * mask.unsqueeze(-1) + + if (feat_topics.detach().sum(dim=1).sum(dim=0) > 100).sum() <= 3: + logger.warning("topic distribution is highly sparse!") + sampled_topics = self.sample_topic(prob_topics.detach(), feat_topics, L) + if sampled_topics is not None: + updated_feat0, updated_feat1 = torch.zeros_like(feat0), torch.zeros_like(feat1) + s_topics0, s_topics1 = sampled_topics + for k in range(s_topics0.shape[-1]): + topick0, topick1 = s_topics0[..., k], s_topics1[..., k] # [N, L+S] + if (topick0.sum() > 0) and (topick1.sum() > 0): + new_feat0, new_mask0, selected_ids0 = self.reduce_feat(feat0, topick0, N, C) + new_feat1, new_mask1, selected_ids1 = self.reduce_feat(feat1, topick1, N, C) + for idt in range(self.n_iter_topic_transformer): + new_feat0 = self.topic_transformers[idt*2](new_feat0, new_feat0, new_mask0, new_mask0) + new_feat1 = self.topic_transformers[idt*2](new_feat1, 
new_feat1, new_mask1, new_mask1) + new_feat0 = self.topic_transformers[idt*2+1](new_feat0, new_feat1, new_mask0, new_mask1) + new_feat1 = self.topic_transformers[idt*2+1](new_feat1, new_feat0, new_mask1, new_mask0) + updated_feat0[selected_ids0, :] = new_feat0[new_mask0, :] + updated_feat1[selected_ids1, :] = new_feat1[new_mask1, :] + + feat0 = (1 - s_topics0.sum(dim=-1, keepdim=True)) * feat0 + updated_feat0 + feat1 = (1 - s_topics1.sum(dim=-1, keepdim=True)) * feat1 + updated_feat1 + + conf_matrix = torch.einsum("nlc,nsc->nls", feat0, feat1) / C**.5 #(C * temperature) + if self.training: + topic_matrix = torch.einsum("nlk,nsk->nls", prob_topics[:, :L], prob_topics[:, L:]) + outlier_mask = torch.einsum("nlk,nsk->nls", feat_topics[:, :L], feat_topics[:, L:]) + else: + topic_matrix = {"img0": feat_topics[:, :L], "img1": feat_topics[:, L:]} + outlier_mask = torch.ones_like(conf_matrix) + if mask0 is not None: + outlier_mask = (outlier_mask * mask0[..., None] * mask1[:, None]) #.bool() + conf_matrix.masked_fill_(~outlier_mask.bool(), -1e9) + conf_matrix = F.softmax(conf_matrix, 1) * F.softmax(conf_matrix, 2) # * topic_matrix + + return feat0, feat1, conf_matrix, topic_matrix + + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)]) #len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal" + + feat0 = self.layers[0](feat0, feat1, mask0, mask1) + feat1 = self.layers[1](feat1, feat0, mask1, mask0) + + return feat0, feat1 diff --git a/imcui/third_party/TopicFM/src/models/topic_fm.py b/imcui/third_party/TopicFM/src/models/topic_fm.py new file mode 100644 index 0000000000000000000000000000000000000000..95cd22f9b66d08760382fe4cd22c4df918cc9f68 --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/topic_fm.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange + +from .backbone import build_backbone +from .modules import LocalFeatureTransformer, FinePreprocess, TopicFormer +from .utils.coarse_matching import CoarseMatching +from .utils.fine_matching import FineMatching + + +class TopicFM(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + + # Modules + self.backbone = build_backbone(config) + + self.loftr_coarse = TopicFormer(config['coarse']) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_preprocess = FinePreprocess(config) + self.loftr_fine = LocalFeatureTransformer(config["fine"]) + self.fine_matching = FineMatching() + + def forward(self, data): + """ + Update: + data (dict): { + 'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded 
position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + # 1. Local Feature CNN + data.update({ + 'bs': data['image0'].size(0), + 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] + }) + + if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) + (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs']) + else: # handle different input shapes + (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1']) + + data.update({ + 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], + 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] + }) + + # 2. coarse-level loftr module + feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c') + feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c') + + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) + + feat_c0, feat_c1, conf_matrix, topic_matrix = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) + data.update({"conf_matrix": conf_matrix, "topic_matrix": topic_matrix}) ###### + + # 3. match coarse-level + self.coarse_matching(data) + + # 4. fine-level refinement + feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0.detach(), feat_c1.detach(), data) + if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted + feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold) + + # 5. match fine-level + self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) diff --git a/imcui/third_party/TopicFM/src/models/utils/coarse_matching.py b/imcui/third_party/TopicFM/src/models/utils/coarse_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..75adbb5cc465220e759a044f96f86c08da2d7a50 --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/utils/coarse_matching.py @@ -0,0 +1,217 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange + +INF = 1e9 + +def mask_border(m, b: int, v): + """ Mask borders with value + Args: + m (torch.Tensor): [N, H0, W0, H1, W1] + b (int) + v (m.dtype) + """ + if b <= 0: + return + + m[:, :b] = v + m[:, :, :b] = v + m[:, :, :, :b] = v + m[:, :, :, :, :b] = v + m[:, -b:] = v + m[:, :, -b:] = v + m[:, :, :, -b:] = v + m[:, :, :, :, -b:] = v + + +def mask_border_with_padding(m, bd, v, p_m0, p_m1): + if bd <= 0: + return + + m[:, :bd] = v + m[:, :, :bd] = v + m[:, :, :, :bd] = v + m[:, :, :, :, :bd] = v + + h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() + h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() + for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): + m[b_idx, h0 - bd:] = v + m[b_idx, :, w0 - bd:] = v + m[b_idx, :, :, h1 - bd:] = v + m[b_idx, :, :, :, w1 - bd:] = v + + +def compute_max_candidates(p_m0, p_m1): + """Compute the max candidates of all pairs within a batch + + Args: + p_m0, p_m1 (torch.Tensor): padded masks + """ + h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] + h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] + max_cand = torch.sum( + 
torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + return max_cand + + +class CoarseMatching(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # general config + self.thr = config['thr'] + self.border_rm = config['border_rm'] + # -- # for trainig fine-level LoFTR + self.train_coarse_percent = config['train_coarse_percent'] + self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + + # we provide 2 options for differentiable matching + self.match_type = config['match_type'] + if self.match_type == 'dual_softmax': + self.temperature = config['dsmax_temperature'] + elif self.match_type == 'sinkhorn': + try: + from .superglue import log_optimal_transport + except ImportError: + raise ImportError("download superglue.py first!") + self.log_optimal_transport = log_optimal_transport + self.bin_score = nn.Parameter( + torch.tensor(config['skh_init_bin_score'], requires_grad=True)) + self.skh_iters = config['skh_iters'] + self.skh_prefilter = config['skh_prefilter'] + else: + raise NotImplementedError() + + def forward(self, data): + """ + Args: + data (dict) + Update: + data (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + NOTE: M' != M during training. + """ + conf_matrix = data['conf_matrix'] + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match(conf_matrix, data)) + + @torch.no_grad() + def get_coarse_match(self, conf_matrix, data): + """ + Args: + conf_matrix (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + _device = conf_matrix.device + # 1. confidence thresholding + mask = conf_matrix > self.thr + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + **axes_lengths) + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # 2. mutual nearest + mask = mask \ + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + + # 3. find all valid coarse matches + # this only works when at most one `True` in each row + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + mconf = conf_matrix[b_ids, i_ids, j_ids] + + # 4. Random sampling of training samples for fine-level LoFTR + # (optional) pad samples with gt coarse-level matches + if self.training: + # NOTE: + # The sampling is performed across all pairs in a batch without manually balancing + # #samples for fine-level increases w.r.t. 
batch_size + if 'mask0' not in data: + num_candidates_max = mask.size(0) * max( + mask.size(1), mask.size(2)) + else: + num_candidates_max = compute_max_candidates( + data['mask0'], data['mask1']) + num_matches_train = int(num_candidates_max * + self.train_coarse_percent) + num_matches_pred = len(b_ids) + assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + + # pred_indices is to select from prediction + if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: + pred_indices = torch.arange(num_matches_pred, device=_device) + else: + pred_indices = torch.randint( + num_matches_pred, + (num_matches_train - self.train_pad_num_gt_min, ), + device=_device) + + # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) + gt_pad_indices = torch.randint( + len(data['spv_b_ids']), + (max(num_matches_train - num_matches_pred, + self.train_pad_num_gt_min), ), + device=_device) + mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + + b_ids, i_ids, j_ids, mconf = map( + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], + dim=0), + *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], + [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + + # These matches select patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # 4. Update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], + dim=1) * scale1 + + # These matches is the current prediction (for visualization) + coarse_matches.update({ + 'gt_mask': mconf == 0, + 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c[mconf != 0], + 'mkpts1_c': mkpts1_c[mconf != 0], + 'mconf': mconf[mconf != 0] + }) + + return coarse_matches diff --git a/imcui/third_party/TopicFM/src/models/utils/fine_matching.py b/imcui/third_party/TopicFM/src/models/utils/fine_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..018f2fe475600b319998c263a97237ce135c3aaf --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/utils/fine_matching.py @@ -0,0 +1,80 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + + +class FineMatching(nn.Module): + """FineMatching with s2d paradigm""" + + def __init__(self): + super().__init__() + + def forward(self, feat_f0, feat_f1, data): + """ + Args: + feat0 (torch.Tensor): [M, WW, C] + feat1 (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + M, WW, C = feat_f0.shape + W = int(math.sqrt(WW)) + scale = data['hw0_i'][0] / data['hw0_f'][0] + self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale + + # corner case: if no coarse matches found + if M == 0: + assert self.training == False, "M is always >0, when training, see coarse_matching.py" + # logger.warning('No matches found in coarse-level.') + data.update({ + 'expec_f': torch.empty(0, 3, device=feat_f0.device), + 
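# (descriptive note, grounded in the surrounding code) no coarse matches to refine, so the coarse keypoints are passed through unchanged +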
'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + feat_f0_picked = feat_f0[:, WW//2, :] + + sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) + softmax_temp = 1. / C**.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) + feat_f1_picked = (feat_f1 * heatmap.unsqueeze(-1)).sum(dim=1) # [M, C] + heatmap = heatmap.view(-1, W, W) + + # compute coordinates from heatmap + coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + + # compute std over + var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords1_normalized**2 # [M, 2] + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability + + # for fine-level supervision + data.update({'expec_f': torch.cat([coords1_normalized, std.unsqueeze(1)], -1), + 'descriptors0': feat_f0_picked.detach(), 'descriptors1': feat_f1_picked.detach()}) + + # compute absolute kpt coords + self.get_fine_match(coords1_normalized, data) + + @torch.no_grad() + def get_fine_match(self, coords1_normed, data): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + + # mkpts0_f and mkpts1_f + # scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale + mkpts0_f = data['mkpts0_c'] # + (coords0_normed * (W // 2) * scale0 )[:len(data['mconf'])] + scale1 = scale * data['scale1'][data['b_ids']] if 'scale1' in data else scale + mkpts1_f = data['mkpts1_c'] + (coords1_normed * (W // 2) * scale1)[:len(data['mconf'])] + + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f + }) diff --git a/imcui/third_party/TopicFM/src/models/utils/geometry.py b/imcui/third_party/TopicFM/src/models/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116 --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/utils/geometry.py @@ -0,0 +1,54 @@ +import torch + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """ Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). 
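Each keypoint is back-projected with depth0 and K0, transformed by T_0to1, and re-projected with K1; points that fall outside image1 or fail the depth check are masked out.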
+ + Args: + kpts0 (torch.Tensor): [N, L, 2] - , + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + kpts0_long = kpts0.round().long() + + # Sample depth, get calculable_mask on depth != 0 + kpts0_depth = torch.stack( + [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + ) # (N, L) + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ + (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + w_kpts0_long = w_kpts0.long() + w_kpts0_long[~covisible_mask, :] = 0 + + w_kpts0_depth = torch.stack( + [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + ) # (N, L) + consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 diff --git a/imcui/third_party/TopicFM/src/models/utils/supervision.py b/imcui/third_party/TopicFM/src/models/utils/supervision.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1f0478fdcbe7f8ceffbc4aff4d507cec55bbd2 --- /dev/null +++ b/imcui/third_party/TopicFM/src/models/utils/supervision.py @@ -0,0 +1,151 @@ +from math import log +from loguru import logger + +import torch +from einops import repeat +from kornia.utils import create_meshgrid + +from .geometry import warp_kpts + +############## ↓ Coarse-Level supervision ↓ ############## + + +@torch.no_grad() +def mask_pts_at_padded_regions(grid_pt, mask): + """For megadepth dataset, zero-padding exists in images""" + mask = repeat(mask, 'n h w -> n (h w) c', c=2) + grid_pt[~mask.bool()] = 0 + return grid_pt + + +@torch.no_grad() +def spvs_coarse(data, config): + """ + Update: + data (dict): { + "conf_matrix_gt": [N, hw0, hw1], + 'spv_b_ids': [M] + 'spv_i_ids': [M] + 'spv_j_ids': [M] + 'spv_w_pt0_i': [N, hw0, 2], in original image resolution + 'spv_pt1_i': [N, hw1, 2], in original image resolution + } + + NOTE: + - for scannet dataset, there're 3 kinds of resolution {i, c, f} + - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} + """ + # 1. misc + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + scale = config['MODEL']['RESOLUTION'][0] + scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale + scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + + # 2. 
warp grids + # create kpts in meshgrid and resize them to image resolution + grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_i = scale0 * grid_pt0_c + grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_i = scale1 * grid_pt1_c + + # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt + if 'mask0' in data: + grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) + grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) + + # warp kpts bi-directionally and resize them to coarse-level resolution + # (no depth consistency check, since it leads to worse results experimentally) + # (unhandled edge case: points with 0-depth will be warped to the left-up corner) + _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) + _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) + w_pt0_c = w_pt0_i / scale1 + w_pt1_c = w_pt1_i / scale0 + + # 3. check if mutual nearest neighbor + w_pt0_c_round = w_pt0_c[:, :, :].round().long() + nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 + w_pt1_c_round = w_pt1_c[:, :, :].round().long() + nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0 + + # corner case: out of boundary + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 + nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 + + loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) + correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1) + correct_0to1[:, 0] = False # ignore the top-left corner + + # 4. construct a gt conf_matrix + conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) + b_ids, i_ids = torch.where(correct_0to1 != 0) + j_ids = nearest_index1[b_ids, i_ids] + + conf_matrix_gt[b_ids, i_ids, j_ids] = 1 + data.update({'conf_matrix_gt': conf_matrix_gt}) + + # 5. save coarse matches(gt) for training fine level + if len(b_ids) == 0: + logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}") + # this won't affect fine-level loss calculation + b_ids = torch.tensor([0], device=device) + i_ids = torch.tensor([0], device=device) + j_ids = torch.tensor([0], device=device) + + data.update({ + 'spv_b_ids': b_ids, + 'spv_i_ids': i_ids, + 'spv_j_ids': j_ids + }) + + # 6. save intermediate results (for fast fine-level computation) + data.update({ + 'spv_w_pt0_i': w_pt0_i, + 'spv_pt1_i': grid_pt1_i + }) + + +def compute_supervision_coarse(data, config): + assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_coarse(data, config) + else: + raise ValueError(f'Unknown data source: {data_source}') + + +############## ↓ Fine-Level supervision ↓ ############## + +@torch.no_grad() +def spvs_fine(data, config): + """ + Update: + data (dict):{ + "expec_f_gt": [M, 2]} + """ + # 1. misc + # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i') + w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i'] + scale = config['MODEL']['RESOLUTION'][1] + radius = config['MODEL']['FINE_WINDOW_SIZE'] // 2 + + # 2. 
get coarse prediction + b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + + # 3. compute gt + scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale + # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later + expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2] + data.update({"expec_f_gt": expec_f_gt}) + + +def compute_supervision_fine(data, config): + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_fine(data, config) + else: + raise NotImplementedError diff --git a/imcui/third_party/TopicFM/src/optimizers/__init__.py b/imcui/third_party/TopicFM/src/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1db2285352586c250912bdd2c4ae5029620ab5f --- /dev/null +++ b/imcui/third_party/TopicFM/src/optimizers/__init__.py @@ -0,0 +1,42 @@ +import torch +from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR + + +def build_optimizer(model, config): + name = config.TRAINER.OPTIMIZER + lr = config.TRAINER.TRUE_LR + + if name == "adam": + return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY) + elif name == "adamw": + return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY) + else: + raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") + + +def build_scheduler(config, optimizer): + """ + Returns: + scheduler (dict):{ + 'scheduler': lr_scheduler, + 'interval': 'step', # or 'epoch' + 'monitor': 'val_f1', (optional) + 'frequency': x, (optional) + } + """ + scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} + name = config.TRAINER.SCHEDULER + + if name == 'MultiStepLR': + scheduler.update( + {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) + elif name == 'CosineAnnealing': + scheduler.update( + {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) + elif name == 'ExponentialLR': + scheduler.update( + {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) + else: + raise NotImplementedError() + + return scheduler diff --git a/imcui/third_party/TopicFM/src/utils/augment.py b/imcui/third_party/TopicFM/src/utils/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d --- /dev/null +++ b/imcui/third_party/TopicFM/src/utils/augment.py @@ -0,0 +1,55 @@ +import albumentations as A + + +class DarkAug(object): + """ + Extreme dark augmentation aiming at Aachen Day-Night + """ + + def __init__(self) -> None: + self.augmentor = A.Compose([ + A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), + A.Blur(p=0.1, blur_limit=(3, 9)), + A.MotionBlur(p=0.2, blur_limit=(3, 25)), + A.RandomGamma(p=0.1, gamma_limit=(15, 65)), + A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) + ], p=0.75) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + + +class MobileAug(object): + """ + Random augmentations aiming at images of mobile/handhold devices. 
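Includes motion blur, color jitter, simulated rain and sun flare, JPEG compression, and ISO noise.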
+ """ + + def __init__(self): + self.augmentor = A.Compose([ + A.MotionBlur(p=0.25), + A.ColorJitter(p=0.5), + A.RandomRain(p=0.1), # random occlusion + A.RandomSunFlare(p=0.1), + A.JpegCompression(p=0.25), + A.ISONoise(p=0.25) + ], p=1.0) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + + +def build_augmentor(method=None, **kwargs): + if method is not None: + raise NotImplementedError('Using of augmentation functions are not supported yet!') + if method == 'dark': + return DarkAug() + elif method == 'mobile': + return MobileAug() + elif method is None: + return None + else: + raise ValueError(f'Invalid augmentation method: {method}') + + +if __name__ == '__main__': + augmentor = build_augmentor('FDA') diff --git a/imcui/third_party/TopicFM/src/utils/comm.py b/imcui/third_party/TopicFM/src/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167 --- /dev/null +++ b/imcui/third_party/TopicFM/src/utils/comm.py @@ -0,0 +1,265 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +[Copied from detectron2] +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import functools +import logging +import numpy as np +import pickle +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. 
+ """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" + local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. 
+ """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group=group) == 1: + return [data] + rank = dist.get_rank(group=group) + + tensor = _serialize_to_tensor(data, group) + size_list, tensor = _pad_to_largest_tensor(tensor, group) + + # receiving Tensor from all ranks + if rank == dst: + max_size = max(size_list) + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.gather(tensor, tensor_list, dst=dst, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + return data_list + else: + dist.gather(tensor, [], dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2 ** 31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + + Returns: + a dict with the same keys as input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/imcui/third_party/TopicFM/src/utils/dataloader.py b/imcui/third_party/TopicFM/src/utils/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..6da37b880a290c2bb3ebb028d0c8dab592acc5c1 --- /dev/null +++ b/imcui/third_party/TopicFM/src/utils/dataloader.py @@ -0,0 +1,23 @@ +import numpy as np + + +# --- PL-DATAMODULE --- + +def get_local_split(items: list, world_size: int, rank: int, seed: int): + """ The local rank only loads a split of the dataset. 
""" + n_items = len(items) + items_permute = np.random.RandomState(seed).permutation(items) + if n_items % world_size == 0: + padded_items = items_permute + else: + padding = np.random.RandomState(seed).choice( + items, + world_size - (n_items % world_size), + replace=True) + padded_items = np.concatenate([items_permute, padding]) + assert len(padded_items) % world_size == 0, \ + f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' + n_per_rank = len(padded_items) // world_size + local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] + + return local_items diff --git a/imcui/third_party/TopicFM/src/utils/dataset.py b/imcui/third_party/TopicFM/src/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..647bbadd821b6c90736ed45462270670b1017b0b --- /dev/null +++ b/imcui/third_party/TopicFM/src/utils/dataset.py @@ -0,0 +1,201 @@ +import io +from loguru import logger + +import cv2 +import numpy as np +import h5py +import torch +from numpy.linalg import inv + + +MEGADEPTH_CLIENT = SCANNET_CLIENT = None + +# --- DATA IO --- + +def load_array_from_s3( + path, client, cv_type, + use_h5py=False, +): + byte_str = client.Get(path) + try: + if not use_h5py: + raw_array = np.fromstring(byte_str, np.uint8) + data = cv2.imdecode(raw_array, cv_type) + else: + f = io.BytesIO(byte_str) + data = np.array(h5py.File(f, 'r')['/depth']) + except Exception as ex: + print(f"==> Data loading failure: {path}") + raise ex + + assert data is not None + return data + + +def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): + cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ + else cv2.IMREAD_COLOR + if str(path).startswith('s3://'): + image = load_array_from_s3(str(path), client, cv_type) + else: + image = cv2.imread(str(path), cv_type) + + if augment_fn is not None: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = augment_fn(image) + image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + return image # (h, w) + + +def get_resized_wh(w, h, resize=None): + if (resize is not None) and (max(h,w) > resize): # resize the longer edge + scale = resize / max(h, w) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + else: + w_new, h_new = w, h + return w_new, h_new + + +def get_divisible_wh(w, h, df=None): + if df is not None: + w_new, h_new = map(lambda x: int(x // df * df), [w, h]) + else: + w_new, h_new = w, h + return w_new, h_new + + +def pad_bottom_right(inp, pad_size, ret_mask=False): + assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + mask = None + if inp.ndim == 2: + padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + if ret_mask: + mask = np.zeros((pad_size, pad_size), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + elif inp.ndim == 3: + padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) + padded[:, :inp.shape[1], :inp.shape[2]] = inp + if ret_mask: + mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) + mask[:, :inp.shape[1], :inp.shape[2]] = True + else: + raise NotImplementedError() + return padded, mask + + +# --- MEGADEPTH --- + +def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): + """ + Args: + resize (int, optional): the longer edge of resized images. None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. 
+ augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) + + # resize image + w, h = image.shape[1], image.shape[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + + if padding: # padding + pad_to = resize #max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + else: + mask = None + + image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + mask = torch.from_numpy(mask) if mask is not None else None + + return image, mask, scale + + +def read_megadepth_depth(path, pad_to=None): + if str(path).startswith('s3://'): + depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) + else: + depth = np.array(h5py.File(path, 'r')['depth']) + if pad_to is not None: + depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +# --- ScanNet --- + +def read_scannet_gray(path, resize=(640, 480), augment_fn=None): + """ + Args: + resize (tuple): align image to depthmap, in (w, h). + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read and resize image + image = imread_gray(path, augment_fn) + image = cv2.resize(image, resize) + + # (h, w) -> (1, h, w) and normalized + image = torch.from_numpy(image).float()[None] / 255 + return image + + +# ---- evaluation datasets: HLoc, Aachen, InLoc + +def read_img_gray(path, resize=None, down_factor=16): + # read and resize image + image = imread_gray(path, None) + w, h = image.shape[1], image.shape[0] + if (resize is not None) and (max(h, w) > resize): + scale = float(resize / max(h, w)) + w_new, h_new = int(round(w * scale)), int(round(h * scale)) + else: + w_new, h_new = w, h + w_new, h_new = get_divisible_wh(w_new, h_new, down_factor) + image = cv2.resize(image, (w_new, h_new)) + + # (h, w) -> (1, h, w) and normalized + image = torch.from_numpy(image).float()[None] / 255 + scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float) + return image, scale + + +def read_scannet_depth(path): + if str(path).startswith('s3://'): + depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) + else: + depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) + depth = depth / 1000 + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +def read_scannet_pose(path): + """ Read ScanNet's Camera2World pose and transform it to World2Camera. + + Returns: + pose_w2c (np.ndarray): (4, 4) + """ + cam2world = np.loadtxt(path, delimiter=' ') + world2cam = inv(cam2world) + return world2cam + + +def read_scannet_intrinsic(path): + """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 
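The stored matrix is loaded with np.loadtxt and its last row and column are dropped.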
+ """ + intrinsic = np.loadtxt(path, delimiter=' ') + return intrinsic[:-1, :-1] diff --git a/imcui/third_party/TopicFM/src/utils/metrics.py b/imcui/third_party/TopicFM/src/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a93c31ed1d151cd41e2449a19be2d6abc5f9d419 --- /dev/null +++ b/imcui/third_party/TopicFM/src/utils/metrics.py @@ -0,0 +1,193 @@ +import torch +import cv2 +import numpy as np +from collections import OrderedDict +from loguru import logger +from kornia.geometry.epipolar import numeric +from kornia.geometry.conversions import convert_points_to_homogeneous + + +# --- METRICS --- + +def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): + # angle error between 2 vectors + t_gt = T_0to1[:3, 3] + n = np.linalg.norm(t) * np.linalg.norm(t_gt) + t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) + t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity + if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging + t_err = 0 + + # angle error between 2 rotation matrices + R_gt = T_0to1[:3, :3] + cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 + cos = np.clip(cos, -1., 1.) # handle numercial errors + R_err = np.rad2deg(np.abs(np.arccos(cos))) + + return t_err, R_err + + +def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): + """Squared symmetric epipolar distance. + This can be seen as a biased estimation of the reprojection error. + Args: + pts0 (torch.Tensor): [N, 2] + E (torch.Tensor): [3, 3] + """ + pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + pts0 = convert_points_to_homogeneous(pts0) + pts1 = convert_points_to_homogeneous(pts1) + + Ep0 = pts0 @ E.T # [N, 3] + p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] + Etp1 = pts1 @ E # [N, 3] + + d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N + return d + + +def compute_symmetrical_epipolar_errors(data): + """ + Update: + data (dict):{"epi_errs": [M]} + """ + Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) + E_mat = Tx @ data['T_0to1'][:, :3, :3] + + m_bids = data['m_bids'] + pts0 = data['mkpts0_f'] + pts1 = data['mkpts1_f'] + + epi_errs = [] + for bs in range(Tx.size(0)): + mask = m_bids == bs + epi_errs.append( + symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs])) + epi_errs = torch.cat(epi_errs, dim=0) + + data.update({'epi_errs': epi_errs}) + + +def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): + if len(kpts0) < 5: + return None + # normalize keypoints + kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + + # normalize ransac threshold + ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]]) + + # compute pose with cv2 + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC) + if E is None: + print("\nE is None while trying to recover pose.\n") + return None + + # recover pose from E + best_num_inliers = 0 + ret = None + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + ret = (R, t[:, 0], mask.ravel() > 0) + best_num_inliers = n + + return ret + + +def compute_pose_errors(data, config=None, ransac_thr=0.5, ransac_conf=0.99999): + """ + Update: + data (dict):{ + "R_errs" List[float]: [N] + "t_errs" 
List[float]: [N] + "inliers" List[np.ndarray]: [N] + } + """ + pixel_thr = config.TRAINER.RANSAC_PIXEL_THR if config is not None else ransac_thr # 0.5 + conf = config.TRAINER.RANSAC_CONF if config is not None else ransac_conf # 0.99999 + data.update({'R_errs': [], 't_errs': [], 'inliers': []}) + + m_bids = data['m_bids'].cpu().numpy() + pts0 = data['mkpts0_f'].cpu().numpy() + pts1 = data['mkpts1_f'].cpu().numpy() + K0 = data['K0'].cpu().numpy() + K1 = data['K1'].cpu().numpy() + T_0to1 = data['T_0to1'].cpu().numpy() + + for bs in range(K0.shape[0]): + mask = m_bids == bs + ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf) + + if ret is None: + data['R_errs'].append(np.inf) + data['t_errs'].append(np.inf) + data['inliers'].append(np.array([]).astype(np.bool)) + else: + R, t, inliers = ret + t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) + data['R_errs'].append(R_err) + data['t_errs'].append(t_err) + data['inliers'].append(inliers) + + +# --- METRIC AGGREGATION --- + +def error_auc(errors, thresholds): + """ + Args: + errors (list): [N,] + thresholds (list) + """ + errors = [0] + sorted(list(errors)) + recall = list(np.linspace(0, 1, len(errors))) + + aucs = [] + thresholds = [5, 10, 20] + for thr in thresholds: + last_index = np.searchsorted(errors, thr) + y = recall[:last_index] + [recall[last_index-1]] + x = errors[:last_index] + [thr] + aucs.append(np.trapz(y, x) / thr) + + return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} + + +def epidist_prec(errors, thresholds, ret_dict=False): + precs = [] + for thr in thresholds: + prec_ = [] + for errs in errors: + correct_mask = errs < thr + prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) + precs.append(np.mean(prec_) if len(prec_) > 0 else 0) + if ret_dict: + return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} + else: + return precs + + +def aggregate_metrics(metrics, epi_err_thr=5e-4): + """ Aggregate metrics for the whole dataset: + (This method should be called once per dataset) + 1. AUC of the pose error (angular) at the threshold [5, 10, 20] + 2. 
Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) + """ + # filter duplicates + unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) + unq_ids = list(unq_ids.values()) + logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') + + # pose auc + angular_thresholds = [5, 10, 20] + pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] + aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) + + # matching precision + dist_thresholds = [epi_err_thr] + precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) + + return {**aucs, **precs} diff --git a/imcui/third_party/TopicFM/src/utils/misc.py b/imcui/third_party/TopicFM/src/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8db04666519753ea2df43903ab6c47ec00a9a1 --- /dev/null +++ b/imcui/third_party/TopicFM/src/utils/misc.py @@ -0,0 +1,101 @@ +import os +import contextlib +import joblib +from typing import Union +from loguru import _Logger, logger +from itertools import chain + +import torch +from yacs.config import CfgNode as CN +from pytorch_lightning.utilities import rank_zero_only + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +def upper_config(dict_cfg): + if not isinstance(dict_cfg, dict): + return dict_cfg + return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} + + +def log_on(condition, message, level): + if condition: + assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] + logger.log(level, message) + + +def get_rank_zero_only_logger(logger: _Logger): + if rank_zero_only.rank == 0: + return logger + else: + for _level in logger._core.levels.keys(): + level = _level.lower() + setattr(logger, level, + lambda x: None) + logger._log = lambda x: None + return logger + + +def setup_gpus(gpus: Union[str, int]) -> int: + """ A temporary fix for pytorch-lighting 1.3.x """ + gpus = str(gpus) + gpu_ids = [] + + if ',' not in gpus: + n_gpus = int(gpus) + return n_gpus if n_gpus != -1 else torch.cuda.device_count() + else: + gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] + + # setup environment variables + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + if visible_devices is None: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') + else: + logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') + return len(gpu_ids) + + +def flattenList(x): + return list(chain(*x)) + + +@contextlib.contextmanager +def tqdm_joblib(tqdm_object): + """Context manager to patch joblib to report into tqdm progress bar given as argument + + Usage: + with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: + Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) + + When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) + ret_vals = Parallel(n_jobs=args.world_size)( + delayed(lambda x: _compute_cov_score(pid, *x))(param) + for param in tqdm(combinations(image_ids, 2), + desc=f'Computing cov_score of [{pid}]', + 
total=len(image_ids)*(len(image_ids)-1)/2)) + Src: https://stackoverflow.com/a/58936697 + """ + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() + diff --git a/imcui/third_party/TopicFM/src/utils/plotting.py b/imcui/third_party/TopicFM/src/utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..89b22ef27e6152225d07ab24bb3e62718d180b59 --- /dev/null +++ b/imcui/third_party/TopicFM/src/utils/plotting.py @@ -0,0 +1,313 @@ +import bisect +import numpy as np +import matplotlib.pyplot as plt +import matplotlib, os, cv2 +import matplotlib.cm as cm +from PIL import Image +import torch.nn.functional as F +import torch + + +def _compute_conf_thresh(data): + dataset_name = data['dataset_name'][0].lower() + if dataset_name == 'scannet': + thr = 5e-4 + elif dataset_name == 'megadepth': + thr = 1e-4 + else: + raise ValueError(f'Unknown dataset: {dataset_name}') + return thr + + +# --- VISUALIZATION --- # + +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0) # , cmap='gray') + axes[1].imshow(img1) # , cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=5) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=5) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=2) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + + +def _make_evaluation_figure(data, b_id, alpha='dynamic'): + b_mask = data['m_bids'] == b_id + conf_thr = _compute_conf_thresh(data) + + img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() + kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() + + # for megadepth, we visualize matches 
on the resized image + if 'scale0' in data: + kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]] + kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]] + + epi_errs = data['epi_errs'][b_mask].cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) + recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + ] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text) + return figure + +def _make_confidence_figure(data, b_id): + # TODO: Implement confidence figure + raise NotImplementedError() + + +def make_matching_figures(data, config, mode='evaluation'): + """ Make matching figures for a batch. + + Args: + data (Dict): a batch updated by PL_LoFTR. + config (Dict): matcher config + Returns: + figures (Dict[str, List[plt.figure]] + """ + assert mode in ['evaluation', 'confidence'] # 'confidence' + figures = {mode: []} + for b_id in range(data['image0'].size(0)): + if mode == 'evaluation': + fig = _make_evaluation_figure( + data, b_id, + alpha=config.TRAINER.PLOT_MATCHES_ALPHA) + elif mode == 'confidence': + fig = _make_confidence_figure(data, b_id) + else: + raise ValueError(f'Unknown plot mode: {mode}') + figures[mode].append(fig) + return figures + + +def dynamic_alpha(n_matches, + milestones=[0, 300, 1000, 2000], + alphas=[1.0, 0.8, 0.4, 0.2]): + if n_matches == 0: + return 1.0 + ranges = list(zip(alphas, alphas[1:] + [None])) + loc = bisect.bisect_right(milestones, n_matches) - 1 + _range = ranges[loc] + if _range[1] is None: + return _range[0] + return _range[1] + (milestones[loc + 1] - n_matches) / ( + milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) + + +def error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) + + +np.random.seed(1995) +color_map = np.arange(100) +np.random.shuffle(color_map) + + +def draw_topics(data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None): + + topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"] + hw0_c, hw1_c = data["hw0_c"], data["hw1_c"] + hw0_i, hw1_i = data["hw0_i"], data["hw1_i"] + # print(hw0_i, hw1_i) + scale0, scale1 = hw0_i[0] // hw0_c[0], hw1_i[0] // hw1_c[0] + if "scale0" in data: + scale0 *= data["scale0"][0] + else: + scale0 = (scale0, scale0) + if "scale1" in data: + scale1 *= data["scale1"][0] + else: + scale1 = (scale1, scale1) + + n_topics = topic0.shape[-1] + # mask0_nonzero = topic0[0].sum(dim=-1, keepdim=True) > 0 + # mask1_nonzero = topic1[0].sum(dim=-1, keepdim=True) > 0 + theta0 = topic0[0].sum(dim=0) + theta0 /= theta0.sum().float() + theta1 = topic1[0].sum(dim=0) + theta1 /= theta1.sum().float() + # top_topic0 = torch.argsort(theta0, descending=True)[:show_n_topics] + 
# top_topic1 = torch.argsort(theta1, descending=True)[:show_n_topics] + top_topics = torch.argsort(theta0*theta1, descending=True)[:show_n_topics] + # print(sum_topic0, sum_topic1) + + topic0 = topic0[0].argmax(dim=-1, keepdim=True) #.float() / (n_topics - 1) #* 255 + 1 # + # topic0[~mask0_nonzero] = -1 + topic1 = topic1[0].argmax(dim=-1, keepdim=True) #.float() / (n_topics - 1) #* 255 + 1 + # topic1[~mask1_nonzero] = -1 + label_img0, label_img1 = torch.zeros_like(topic0) - 1, torch.zeros_like(topic1) - 1 + for i, k in enumerate(top_topics): + label_img0[topic0 == k] = color_map[k] + label_img1[topic1 == k] = color_map[k] + +# print(hw0_c, scale0) +# print(hw1_c, scale1) + # map_topic0 = F.fold(label_img0.unsqueeze(0), hw0_i, kernel_size=scale0, stride=scale0) + map_topic0 = label_img0.float().view(hw0_c).cpu().numpy() #map_topic0.squeeze(0).squeeze(0).cpu().numpy() + map_topic0 = cv2.resize(map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1]))) + # map_topic1 = F.fold(label_img1.unsqueeze(0), hw1_i, kernel_size=scale1, stride=scale1) + map_topic1 = label_img1.float().view(hw1_c).cpu().numpy() #map_topic1.squeeze(0).squeeze(0).cpu().numpy() + map_topic1 = cv2.resize(map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1]))) + + + # show image0 + if saved_name is None: + return map_topic0, map_topic1 + + if not os.path.exists(saved_folder): + os.makedirs(saved_folder) + path_saved_img0 = os.path.join(saved_folder, "{}_0.png".format(saved_name)) + plt.imshow(img0) + masked_map_topic0 = np.ma.masked_where(map_topic0 < 0, map_topic0) + plt.imshow(masked_map_topic0, cmap=plt.cm.jet, vmin=0, vmax=n_topics-1, alpha=.3, interpolation='bilinear') + # plt.show() + plt.axis('off') + plt.savefig(path_saved_img0, bbox_inches='tight', pad_inches=0, dpi=250) + plt.close() + + path_saved_img1 = os.path.join(saved_folder, "{}_1.png".format(saved_name)) + plt.imshow(img1) + masked_map_topic1 = np.ma.masked_where(map_topic1 < 0, map_topic1) + plt.imshow(masked_map_topic1, cmap=plt.cm.jet, vmin=0, vmax=n_topics-1, alpha=.3, interpolation='bilinear') + plt.axis('off') + plt.savefig(path_saved_img1, bbox_inches='tight', pad_inches=0, dpi=250) + plt.close() + + +def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_topics=8, + topic_alpha=0.3, margin=5, path=None, opencv_display=False, opencv_title=''): + topic_map0, topic_map1 = draw_topics(data, img0, img1, show_n_topics=show_n_topics) + + mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(topic_map1 >= 0, axis=-1) + + topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.), cm.jet(topic_map1 / 99.) 
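    # color_map shuffles the integers 0..99, so each selected topic id was mapped to a
    # distinct value in [0, 99] above; dividing by 99 normalizes it into [0, 1] for the
    # jet colormap. Cells whose topic was not selected keep the value -1 and are dropped
    # below via the mask_tm0 / mask_tm1 masks before blending with the input images.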
+ topic_cm0 = cv2.cvtColor(topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR) + topic_cm1 = cv2.cvtColor(topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR) + overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32) + overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32) + + cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0) + cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1) + + overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype(np.uint8) + + h0, w0 = img0.shape[:2] + h1, w1 = img1.shape[:2] + h, w = h0 * 2 + margin * 2, w0 * 2 + margin + out_fig = 255 * np.ones((h, w, 3), dtype=np.uint8) + out_fig[:h0, :w0] = overlay0 + if h0 >= h1: + start = (h0 - h1) // 2 + out_fig[start:(start+h1), (w0+margin):(w0+margin+w1)] = overlay1 + else: + start = (h1 - h0) // 2 + out_fig[:h0, (w0+margin):(w0+margin+w1)] = overlay1[start:(start+h0)] + + step_h = h0 + margin * 2 + out_fig[step_h:step_h+h0, :w0] = (img0 * 255).astype(np.uint8) + if h0 >= h1: + start = step_h + (h0 - h1) // 2 + out_fig[start:start+h1, (w0+margin):(w0+margin+w1)] = (img1 * 255).astype(np.uint8) + else: + start = (h1 - h0) // 2 + out_fig[step_h:step_h+h0, (w0+margin):(w0+margin+w1)] = (img1[start:start+h0] * 255).astype(np.uint8) + + # draw matching lines, this is inspried from https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py + mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int) + mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int) + + for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, mcolor): + c = c.tolist() + cv2.line(out_fig, (x0, y0+step_h), (x1+margin+w0, y1+step_h+(h0-h1)//2), + color=c, thickness=1, lineType=cv2.LINE_AA) + # display line end-points as circles + cv2.circle(out_fig, (x0, y0+step_h), 2, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out_fig, (x1+margin+w0, y1+step_h+(h0-h1)//2), 2, c, -1, lineType=cv2.LINE_AA) + + # Scale factor for consistent visualization across scales. + sc = min(h / 960., 2.0) + + # Big text. + Ht = int(30 * sc) # text height + txt_color_fg = (255, 255, 255) + txt_color_bg = (0, 0, 0) + for i, t in enumerate(text): + cv2.putText(out_fig, t, (int(8 * sc), Ht + step_h*i), cv2.FONT_HERSHEY_DUPLEX, + 1.0 * sc, txt_color_bg, 2, cv2.LINE_AA) + cv2.putText(out_fig, t, (int(8 * sc), Ht + step_h*i), cv2.FONT_HERSHEY_DUPLEX, + 1.0 * sc, txt_color_fg, 1, cv2.LINE_AA) + + if path is not None: + cv2.imwrite(str(path), out_fig) + + if opencv_display: + cv2.imshow(opencv_title, out_fig) + cv2.waitKey(1) + + return out_fig + + + + + + diff --git a/imcui/third_party/TopicFM/src/utils/profiler.py b/imcui/third_party/TopicFM/src/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..6d21ed79fb506ef09c75483355402c48a195aaa9 --- /dev/null +++ b/imcui/third_party/TopicFM/src/utils/profiler.py @@ -0,0 +1,39 @@ +import torch +from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler +from contextlib import contextmanager +from pytorch_lightning.utilities import rank_zero_only + + +class InferenceProfiler(SimpleProfiler): + """ + This profiler records duration of actions with cuda.synchronize() + Use this in test time. 
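    Each profiled action is bracketed by torch.cuda.synchronize() calls, so the reported
    durations include asynchronous CUDA kernels rather than only the Python-side launch time.

    Usage (a minimal sketch; `model` and `batch` are placeholders, not part of this module):

        profiler = build_profiler('inference')
        with profiler.profile('matcher_forward'):
            output = model(batch)
        print(profiler.summary())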
+ """ + + def __init__(self): + super().__init__() + self.start = rank_zero_only(self.start) + self.stop = rank_zero_only(self.stop) + self.summary = rank_zero_only(self.summary) + + @contextmanager + def profile(self, action_name: str) -> None: + try: + torch.cuda.synchronize() + self.start(action_name) + yield action_name + finally: + torch.cuda.synchronize() + self.stop(action_name) + + +def build_profiler(name): + if name == 'inference': + return InferenceProfiler() + elif name == 'pytorch': + from pytorch_lightning.profiler import PyTorchProfiler + return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) + elif name is None: + return PassThroughProfiler() + else: + raise ValueError(f'Invalid profiler: {name}') diff --git a/imcui/third_party/TopicFM/test.py b/imcui/third_party/TopicFM/test.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb451cde3674b70b0d2e02f37ff1fd391004d30 --- /dev/null +++ b/imcui/third_party/TopicFM/test.py @@ -0,0 +1,68 @@ +import pytorch_lightning as pl +import argparse +import pprint +from loguru import logger as loguru_logger + +from src.config.default import get_cfg_defaults +from src.utils.profiler import build_profiler + +from src.lightning_trainer.data import MultiSceneDataModule +from src.lightning_trainer.trainer import PL_Trainer + + +def parse_args(): + # init a costum parser which will be added into pl.Trainer parser + # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'data_cfg_path', type=str, help='data config path') + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint') + parser.add_argument( + '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir") + parser.add_argument( + '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset') + parser.add_argument( + '--batch_size', type=int, default=1, help='batch_size per gpu') + parser.add_argument( + '--num_workers', type=int, default=2) + parser.add_argument( + '--thr', type=float, default=None, help='modify the coarse-level matching threshold.') + + parser = pl.Trainer.add_argparse_args(parser) + return parser.parse_args() + + +if __name__ == '__main__': + # parse arguments + args = parse_args() + pprint.pprint(vars(args)) + + # init default-cfg and merge it with the main- and data-cfg + config = get_cfg_defaults() + config.merge_from_file(args.main_cfg_path) + config.merge_from_file(args.data_cfg_path) + pl.seed_everything(config.TRAINER.SEED) # reproducibility + + # tune when testing + if args.thr is not None: + config.MODEL.MATCH_COARSE.THR = args.thr + + loguru_logger.info(f"Args and config initialized!") + + # lightning module + profiler = build_profiler(args.profiler_name) + model = PL_Trainer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir) + loguru_logger.info(f"Model-lightning initialized!") + + # lightning data + data_module = MultiSceneDataModule(args, config) + loguru_logger.info(f"DataModule initialized!") + + # lightning trainer + trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False) + + loguru_logger.info(f"Start testing!") + trainer.test(model, datamodule=data_module, verbose=False) diff 
--git a/imcui/third_party/TopicFM/train.py b/imcui/third_party/TopicFM/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a552c23718b81ddcb282cedbfe3ceb45e50b3f29 --- /dev/null +++ b/imcui/third_party/TopicFM/train.py @@ -0,0 +1,123 @@ +import math +import argparse +import pprint +from distutils.util import strtobool +from pathlib import Path +from loguru import logger as loguru_logger + +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.plugins import DDPPlugin + +from src.config.default import get_cfg_defaults +from src.utils.misc import get_rank_zero_only_logger, setup_gpus +from src.utils.profiler import build_profiler +from src.lightning_trainer.data import MultiSceneDataModule +from src.lightning_trainer.trainer import PL_Trainer + +loguru_logger = get_rank_zero_only_logger(loguru_logger) + + +def parse_args(): + # init a costum parser which will be added into pl.Trainer parser + # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'data_cfg_path', type=str, help='data config path') + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--exp_name', type=str, default='default_exp_name') + parser.add_argument( + '--batch_size', type=int, default=4, help='batch_size per gpu') + parser.add_argument( + '--num_workers', type=int, default=4) + parser.add_argument( + '--pin_memory', type=lambda x: bool(strtobool(x)), + nargs='?', default=True, help='whether loading data to pinned memory or not') + parser.add_argument( + '--ckpt_path', type=str, default=None, + help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR') + parser.add_argument( + '--disable_ckpt', action='store_true', + help='disable checkpoint saving (useful for debugging).') + parser.add_argument( + '--profiler_name', type=str, default=None, + help='options: [inference, pytorch], or leave it unset') + parser.add_argument( + '--parallel_load_data', action='store_true', + help='load datasets in with multiple processes.') + + parser = pl.Trainer.add_argparse_args(parser) + return parser.parse_args() + + +def main(): + # parse arguments + args = parse_args() + rank_zero_only(pprint.pprint)(vars(args)) + + # init default-cfg and merge it with the main- and data-cfg + config = get_cfg_defaults() + config.merge_from_file(args.main_cfg_path) + config.merge_from_file(args.data_cfg_path) + pl.seed_everything(config.TRAINER.SEED) # reproducibility + # TODO: Use different seeds for each dataloader workers + # This is needed for data augmentation + + # scale lr and warmup-step automatically + args.gpus = _n_gpus = setup_gpus(args.gpus) + config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes + config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size + _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS + config.TRAINER.SCALING = _scaling + config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling + config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling) + + # lightning module + profiler = build_profiler(args.profiler_name) + model = PL_Trainer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler) + 
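    # The scaling block above applies a linear scaling rule. A worked example with
    # hypothetical values (canonical_bs=64, canonical_lr=6e-3, warmup_step=4800,
    # 4 GPUs on 1 node, batch_size=4 per GPU):
    #   WORLD_SIZE      = 4 * 1              = 4
    #   TRUE_BATCH_SIZE = 4 * 4              = 16
    #   SCALING         = 16 / 64            = 0.25
    #   TRUE_LR         = 6e-3 * 0.25        = 1.5e-3
    #   WARMUP_STEP     = floor(4800 / 0.25) = 19200
    # i.e. the learning rate shrinks and the warmup is stretched when the effective
    # batch size is smaller than the canonical one.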
loguru_logger.info(f"Model LightningModule initialized!") + + # lightning data + data_module = MultiSceneDataModule(args, config) + loguru_logger.info(f"Model DataModule initialized!") + + # TensorBoard Logger + logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False) + ckpt_dir = Path(logger.log_dir) / 'checkpoints' + + # Callbacks + # TODO: update ModelCheckpoint to monitor multiple metrics + ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max', + save_last=True, + dirpath=str(ckpt_dir), + filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}') + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks = [lr_monitor] + if not args.disable_ckpt: + callbacks.append(ckpt_callback) + + # Lightning Trainer + trainer = pl.Trainer.from_argparse_args( + args, + plugins=DDPPlugin(find_unused_parameters=False, + num_nodes=args.num_nodes, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), + gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING, + callbacks=callbacks, + logger=logger, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, + replace_sampler_ddp=False, # use custom sampler + reload_dataloaders_every_epoch=False, # avoid repeated samples! + weights_summary='full', + profiler=profiler) + loguru_logger.info(f"Trainer initialized!") + loguru_logger.info(f"Start training!") + trainer.fit(model, datamodule=data_module) + + +if __name__ == '__main__': + main() diff --git a/imcui/third_party/TopicFM/visualization.py b/imcui/third_party/TopicFM/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..279b41cd88f61ce3414e2f3077fec642b2c8333a --- /dev/null +++ b/imcui/third_party/TopicFM/visualization.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os, glob, cv2 +import argparse +from argparse import Namespace +import yaml +from tqdm import tqdm +import torch +from torch.utils.data import Dataset, DataLoader, SequentialSampler + +from src.datasets.custom_dataloader import TestDataLoader +from src.utils.dataset import read_img_gray +from configs.data.base import cfg as data_cfg +import viz + + +def get_model_config(method_name, dataset_name, root_dir='viz'): + config_file = f'{root_dir}/configs/{method_name}.yml' + with open(config_file, 'r') as f: + model_conf = yaml.load(f, Loader=yaml.FullLoader)[dataset_name] + return model_conf + + +class DemoDataset(Dataset): + def __init__(self, dataset_dir, img_file=None, resize=0, down_factor=16): + self.dataset_dir = dataset_dir + if img_file is None: + self.list_img_files = glob.glob(os.path.join(dataset_dir, "*.*")) + self.list_img_files.sort() + else: + with open(img_file) as f: + self.list_img_files = [os.path.join(dataset_dir, img_file.strip()) for img_file in f.readlines()] + self.resize = resize + self.down_factor = down_factor + + def __len__(self): + return len(self.list_img_files) + + def __getitem__(self, idx): + img_path = self.list_img_files[idx] #os.path.join(self.dataset_dir, self.list_img_files[idx]) + img, scale = read_img_gray(img_path, resize=self.resize, down_factor=self.down_factor) + return {"img": img, "id": idx, "img_path": img_path} + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Visualize matches') + parser.add_argument('--gpu', '-gpu', type=str, default='0') + parser.add_argument('--method', type=str, default=None) + parser.add_argument('--dataset_dir', type=str, default='data/aachen-day-night') + parser.add_argument('--pair_dir', type=str, default=None) + 
parser.add_argument( + '--dataset_name', type=str, choices=['megadepth', 'scannet', 'aachen_v1.1', 'inloc'], default='megadepth' + ) + parser.add_argument('--measure_time', action="store_true") + parser.add_argument('--no_viz', action="store_true") + parser.add_argument('--compute_eval_metrics', action="store_true") + parser.add_argument('--run_demo', action="store_true") + + args = parser.parse_args() + + model_cfg = get_model_config(args.method, args.dataset_name) + class_name = model_cfg["class"] + model = viz.__dict__[class_name](model_cfg) + # all_args = Namespace(**vars(args), **model_cfg) + if not args.run_demo: + if args.dataset_name == 'megadepth': + from configs.data.megadepth_test_1500 import cfg + + data_cfg.merge_from_other_cfg(cfg) + elif args.dataset_name == 'scannet': + from configs.data.scannet_test_1500 import cfg + + data_cfg.merge_from_other_cfg(cfg) + elif args.dataset_name == 'aachen_v1.1': + data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "aachen_v1.1", + "DATASET.TEST_DATA_ROOT", os.path.join(args.dataset_dir, "images/images_upright"), + "DATASET.TEST_LIST_PATH", args.pair_dir, + "DATASET.TEST_IMGSIZE", model_cfg["imsize"]]) + elif args.dataset_name == 'inloc': + data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "inloc", + "DATASET.TEST_DATA_ROOT", args.dataset_dir, + "DATASET.TEST_LIST_PATH", args.pair_dir, + "DATASET.TEST_IMGSIZE", model_cfg["imsize"]]) + + has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in ["megadepth", "scannet"] + dataloader = TestDataLoader(data_cfg) + with torch.no_grad(): + for data_dict in tqdm(dataloader): + for k, v in data_dict.items(): + if isinstance(v, torch.Tensor): + data_dict[k] = v.cuda() if torch.cuda.is_available() else v + img_root_dir = data_cfg.DATASET.TEST_DATA_ROOT + model.match_and_draw(data_dict, root_dir=img_root_dir, ground_truth=has_ground_truth, + measure_time=args.measure_time, viz_matches=(not args.no_viz)) + + if args.measure_time: + print("Running time for each image is {} miliseconds".format(model.measure_time())) + if args.compute_eval_metrics and has_ground_truth: + model.compute_eval_metrics() + else: + demo_dataset = DemoDataset(args.dataset_dir, img_file=args.pair_dir, resize=640) + sampler = SequentialSampler(demo_dataset) + dataloader = DataLoader(demo_dataset, batch_size=1, sampler=sampler) + + writer = cv2.VideoWriter('topicfm_demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, (640 * 2 + 5, 480 * 2 + 10)) + + model.run_demo(iter(dataloader), writer) #, output_dir="demo", no_display=True) diff --git a/imcui/third_party/TopicFM/viz/__init__.py b/imcui/third_party/TopicFM/viz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0efac33299da6fb8195ce70bcb9eb210f6cf658 --- /dev/null +++ b/imcui/third_party/TopicFM/viz/__init__.py @@ -0,0 +1,3 @@ +from .methods.patch2pix import VizPatch2Pix +from .methods.loftr import VizLoFTR +from .methods.topicfm import VizTopicFM diff --git a/imcui/third_party/TopicFM/viz/configs/__init__.py b/imcui/third_party/TopicFM/viz/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/TopicFM/viz/configs/loftr.yml b/imcui/third_party/TopicFM/viz/configs/loftr.yml new file mode 100644 index 0000000000000000000000000000000000000000..776d625ac8ad5a0b4e4a4e65e2b99f62662bc3fc --- /dev/null +++ b/imcui/third_party/TopicFM/viz/configs/loftr.yml @@ -0,0 +1,18 @@ +default: &default + class: 'VizLoFTR' + ckpt: 
'third_party/loftr/pretrained/outdoor_ds.ckpt' + match_threshold: 0.2 +megadepth: + <<: *default +scannet: + <<: *default +hpatch: + <<: *default +inloc: + <<: *default + imsize: 1024 + match_threshold: 0.3 +aachen_v1.1: + <<: *default + imsize: 1024 + match_threshold: 0.3 diff --git a/imcui/third_party/TopicFM/viz/configs/patch2pix.yml b/imcui/third_party/TopicFM/viz/configs/patch2pix.yml new file mode 100644 index 0000000000000000000000000000000000000000..5e3efa7889098425aaf586bd7b88fc28feb74778 --- /dev/null +++ b/imcui/third_party/TopicFM/viz/configs/patch2pix.yml @@ -0,0 +1,19 @@ +default: &default + class: 'VizPatch2Pix' + ckpt: 'third_party/patch2pix/pretrained/patch2pix_pretrained.pth' + ksize: 2 + imsize: 1024 + match_threshold: 0.25 +megadepth: + <<: *default + imsize: 1200 +scannet: + <<: *default + imsize: [640, 480] +hpatch: + <<: *default +inloc: + <<: *default +aachen_v1.1: + <<: *default + imsize: 1024 diff --git a/imcui/third_party/TopicFM/viz/configs/topicfm.yml b/imcui/third_party/TopicFM/viz/configs/topicfm.yml new file mode 100644 index 0000000000000000000000000000000000000000..7a8071a6fcd8def21dbfec5b9b2b10200f494eee --- /dev/null +++ b/imcui/third_party/TopicFM/viz/configs/topicfm.yml @@ -0,0 +1,29 @@ +default: &default + class: 'VizTopicFM' + ckpt: 'pretrained/model_best.ckpt' + match_threshold: 0.2 + n_sampling_topics: 4 + show_n_topics: 4 +megadepth: + <<: *default + n_sampling_topics: 10 + show_n_topics: 6 +scannet: + <<: *default + match_threshold: 0.3 + n_sampling_topics: 5 + show_n_topics: 4 +hpatch: + <<: *default +inloc: + <<: *default + imsize: 1024 + match_threshold: 0.3 + n_sampling_topics: 8 + show_n_topics: 4 +aachen_v1.1: + <<: *default + imsize: 1024 + match_threshold: 0.3 + n_sampling_topics: 6 + show_n_topics: 6 diff --git a/imcui/third_party/TopicFM/viz/methods/__init__.py b/imcui/third_party/TopicFM/viz/methods/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/TopicFM/viz/methods/base.py b/imcui/third_party/TopicFM/viz/methods/base.py new file mode 100644 index 0000000000000000000000000000000000000000..377e95134f339459bff3c5a0d30b3bfbc122d978 --- /dev/null +++ b/imcui/third_party/TopicFM/viz/methods/base.py @@ -0,0 +1,59 @@ +import pprint +from abc import ABCMeta, abstractmethod +import torch +from itertools import chain + +from src.utils.plotting import make_matching_figure, error_colormap +from src.utils.metrics import aggregate_metrics + + +def flatten_list(x): + return list(chain(*x)) + + +class Viz(metaclass=ABCMeta): + def __init__(self): + super().__init__() + self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu') + torch.set_grad_enabled(False) + + # for evaluation metrics of MegaDepth and ScanNet + self.eval_stats = [] + self.time_stats = [] + + def draw_matches(self, mkpts0, mkpts1, img0, img1, conf, path=None, **kwargs): + thr = 5e-4 + # mkpts0 = pe['mkpts0_f'].cpu().numpy() + # mkpts1 = pe['mkpts1_f'].cpu().numpy() + if "conf_thr" in kwargs: + thr = kwargs["conf_thr"] + color = error_colormap(conf, thr, alpha=0.1) + + text = [ + f"{self.name}", + f"#Matches: {len(mkpts0)}", + ] + if 'R_errs' in kwargs: + text.append(f"$\\Delta$R:{kwargs['R_errs']:.2f}°, $\\Delta$t:{kwargs['t_errs']:.2f}°",) + + if path: + make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text, path=path, dpi=150) + else: + return make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text) + + @abstractmethod + def 
match_and_draw(self, data_dict, **kwargs): + pass + + def compute_eval_metrics(self, epi_err_thr=5e-4): + # metrics: dict of list, numpy + _metrics = [o['metrics'] for o in self.eval_stats] + metrics = {k: flatten_list([_me[k] for _me in _metrics]) for k in _metrics[0]} + + val_metrics_4tb = aggregate_metrics(metrics, epi_err_thr) + print('\n' + pprint.pformat(val_metrics_4tb)) + + def measure_time(self): + if len(self.time_stats) == 0: + return 0 + return sum(self.time_stats) / len(self.time_stats) diff --git a/imcui/third_party/TopicFM/viz/methods/loftr.py b/imcui/third_party/TopicFM/viz/methods/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..53d0c00c1a067cee10bf1587197e4780ac8b2eda --- /dev/null +++ b/imcui/third_party/TopicFM/viz/methods/loftr.py @@ -0,0 +1,85 @@ +from argparse import Namespace +import os +import torch +import cv2 + +from .base import Viz +from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors + +from third_party.loftr.src.loftr import LoFTR, default_cfg + + +class VizLoFTR(Viz): + def __init__(self, args): + super().__init__() + if type(args) == dict: + args = Namespace(**args) + + self.match_threshold = args.match_threshold + + # Load model + conf = dict(default_cfg) + conf['match_coarse']['thr'] = self.match_threshold + print(conf) + self.model = LoFTR(config=conf) + ckpt_dict = torch.load(args.ckpt) + self.model.load_state_dict(ckpt_dict['state_dict']) + self.model = self.model.eval().to(self.device) + + # Name the method + # self.ckpt_name = args.ckpt.split('/')[-1].split('.')[0] + self.name = 'LoFTR' + + print(f'Initialize {self.name}') + + def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True): + if measure_time: + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + self.model(data_dict) + if measure_time: + torch.cuda.synchronize() + end.record() + torch.cuda.synchronize() + self.time_stats.append(start.elapsed_time(end)) + + kpts0 = data_dict['mkpts0_f'].cpu().numpy() + kpts1 = data_dict['mkpts1_f'].cpu().numpy() + + img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0] + img0 = cv2.imread(os.path.join(root_dir, img_name0)) + img1 = cv2.imread(os.path.join(root_dir, img_name1)) + if str(data_dict["dataset_name"][0]).lower() == 'scannet': + img0 = cv2.resize(img0, (640, 480)) + img1 = cv2.resize(img1, (640, 480)) + + if viz_matches: + saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]]) + folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name)) + if not os.path.exists(folder_matches): + os.makedirs(folder_matches) + path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name)) + if ground_truth: + compute_symmetrical_epipolar_errors(data_dict) # compute epi_errs for each match + compute_pose_errors(data_dict) # compute R_errs, t_errs, pose_errs for each pair + epi_errors = data_dict['epi_errs'].cpu().numpy() + R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0] + + self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches, + R_errs=R_errors, t_errs=t_errors) + + rel_pair_names = list(zip(*data_dict['pair_names'])) + bs = data_dict['image0'].size(0) + metrics = { + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': 
[data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)], + 'R_errs': data_dict['R_errs'], + 't_errs': data_dict['t_errs'], + 'inliers': data_dict['inliers']} + self.eval_stats.append({'metrics': metrics}) + else: + m_conf = 1 - data_dict["mconf"].cpu().numpy() + self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4) diff --git a/imcui/third_party/TopicFM/viz/methods/patch2pix.py b/imcui/third_party/TopicFM/viz/methods/patch2pix.py new file mode 100644 index 0000000000000000000000000000000000000000..14a1d345881e2021be97dc5dde91d8bbe1cd18fa --- /dev/null +++ b/imcui/third_party/TopicFM/viz/methods/patch2pix.py @@ -0,0 +1,80 @@ +from argparse import Namespace +import os, sys +import torch +import cv2 +from pathlib import Path + +from .base import Viz +from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors + +patch2pix_path = Path(__file__).parent / '../../third_party/patch2pix' +sys.path.append(str(patch2pix_path)) +from third_party.patch2pix.utils.eval.model_helper import load_model, estimate_matches + + +class VizPatch2Pix(Viz): + def __init__(self, args): + super().__init__() + + if type(args) == dict: + args = Namespace(**args) + self.imsize = args.imsize + self.match_threshold = args.match_threshold + self.ksize = args.ksize + self.model = load_model(args.ckpt, method='patch2pix') + self.name = 'Patch2Pix' + print(f'Initialize {self.name} with image size {self.imsize}') + + def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True): + img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0] + path_img0 = os.path.join(root_dir, img_name0) + path_img1 = os.path.join(root_dir, img_name1) + img0, img1 = cv2.imread(path_img0), cv2.imread(path_img1) + return_m_upscale = True + if str(data_dict["dataset_name"][0]).lower() == 'scannet': + # self.imsize = 640 + img0 = cv2.resize(img0, tuple(self.imsize)) # (640, 480)) + img1 = cv2.resize(img1, tuple(self.imsize)) # (640, 480)) + return_m_upscale = False + outputs = estimate_matches(self.model, path_img0, path_img1, + ksize=self.ksize, io_thres=self.match_threshold, + eval_type='fine', imsize=self.imsize, + return_upscale=return_m_upscale, measure_time=measure_time) + if measure_time: + self.time_stats.append(outputs[-1]) + matches, mconf = outputs[0], outputs[1] + kpts0 = matches[:, :2] + kpts1 = matches[:, 2:4] + + if viz_matches: + saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]]) + folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name)) + if not os.path.exists(folder_matches): + os.makedirs(folder_matches) + path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name)) + + if ground_truth: + data_dict["mkpts0_f"] = torch.from_numpy(matches[:, :2]).float().to(self.device) + data_dict["mkpts1_f"] = torch.from_numpy(matches[:, 2:4]).float().to(self.device) + data_dict["m_bids"] = torch.zeros(matches.shape[0], device=self.device, dtype=torch.float32) + compute_symmetrical_epipolar_errors(data_dict) # compute epi_errs for each match + compute_pose_errors(data_dict) # compute R_errs, t_errs, pose_errs for each pair + epi_errors = data_dict['epi_errs'].cpu().numpy() + R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0] + + self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches, + R_errs=R_errors, t_errs=t_errors) + + rel_pair_names = 
list(zip(*data_dict['pair_names'])) + bs = data_dict['image0'].size(0) + metrics = { + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)], + 'R_errs': data_dict['R_errs'], + 't_errs': data_dict['t_errs'], + 'inliers': data_dict['inliers']} + self.eval_stats.append({'metrics': metrics}) + else: + m_conf = 1 - mconf + self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4) diff --git a/imcui/third_party/TopicFM/viz/methods/topicfm.py b/imcui/third_party/TopicFM/viz/methods/topicfm.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8b1485d5296947a38480cc031c5d7439bf163d --- /dev/null +++ b/imcui/third_party/TopicFM/viz/methods/topicfm.py @@ -0,0 +1,198 @@ +from argparse import Namespace +import os +import torch +import cv2 +from time import time +from pathlib import Path +import matplotlib.cm as cm +import numpy as np + +from src.models.topic_fm import TopicFM +from src import get_model_cfg +from .base import Viz +from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors +from src.utils.plotting import draw_topics, draw_topicfm_demo, error_colormap + + +class VizTopicFM(Viz): + def __init__(self, args): + super().__init__() + if type(args) == dict: + args = Namespace(**args) + + self.match_threshold = args.match_threshold + self.n_sampling_topics = args.n_sampling_topics + self.show_n_topics = args.show_n_topics + + # Load model + conf = dict(get_model_cfg()) + conf['match_coarse']['thr'] = self.match_threshold + conf['coarse']['n_samples'] = self.n_sampling_topics + print("model config: ", conf) + self.model = TopicFM(config=conf) + ckpt_dict = torch.load(args.ckpt) + self.model.load_state_dict(ckpt_dict['state_dict']) + self.model = self.model.eval().to(self.device) + + # Name the method + # self.ckpt_name = args.ckpt.split('/')[-1].split('.')[0] + self.name = 'TopicFM' + + print(f'Initialize {self.name}') + + def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True): + if measure_time: + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + self.model(data_dict) + if measure_time: + torch.cuda.synchronize() + end.record() + torch.cuda.synchronize() + self.time_stats.append(start.elapsed_time(end)) + + kpts0 = data_dict['mkpts0_f'].cpu().numpy() + kpts1 = data_dict['mkpts1_f'].cpu().numpy() + + img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0] + img0 = cv2.imread(os.path.join(root_dir, img_name0)) + img1 = cv2.imread(os.path.join(root_dir, img_name1)) + if str(data_dict["dataset_name"][0]).lower() == 'scannet': + img0 = cv2.resize(img0, (640, 480)) + img1 = cv2.resize(img1, (640, 480)) + + if viz_matches: + saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]]) + folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name)) + if not os.path.exists(folder_matches): + os.makedirs(folder_matches) + path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name)) + + if ground_truth: + compute_symmetrical_epipolar_errors(data_dict) # compute epi_errs for each match + compute_pose_errors(data_dict) # compute R_errs, t_errs, pose_errs for each pair + epi_errors = data_dict['epi_errs'].cpu().numpy() + R_errors, t_errors = 
data_dict['R_errs'][0], data_dict['t_errs'][0] + + self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches, + R_errs=R_errors, t_errs=t_errors) + + # compute evaluation metrics + rel_pair_names = list(zip(*data_dict['pair_names'])) + bs = data_dict['image0'].size(0) + metrics = { + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)], + 'R_errs': data_dict['R_errs'], + 't_errs': data_dict['t_errs'], + 'inliers': data_dict['inliers']} + self.eval_stats.append({'metrics': metrics}) + else: + m_conf = 1 - data_dict["mconf"].cpu().numpy() + self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4) + if self.show_n_topics > 0: + folder_topics = os.path.join(root_dir, "{}_viz_topics".format(self.name)) + if not os.path.exists(folder_topics): + os.makedirs(folder_topics) + draw_topics(data_dict, img0, img1, saved_folder=folder_topics, show_n_topics=self.show_n_topics, + saved_name=saved_name) + + def run_demo(self, dataloader, writer=None, output_dir=None, no_display=False, skip_frames=1): + data_dict = next(dataloader) + + frame_id = 0 + last_image_id = 0 + img0 = np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255 + frame_tensor = data_dict["img"].to(self.device) + pair_data = {'image0': frame_tensor} + last_frame = cv2.resize(img0, (frame_tensor.shape[-1], frame_tensor.shape[-2]), cv2.INTER_LINEAR) + + if output_dir is not None: + print('==> Will write outputs to {}'.format(output_dir)) + Path(output_dir).mkdir(exist_ok=True) + + # Create a window to display the demo. + if not no_display: + window_name = 'Topic-assisted Feature Matching' + cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) + cv2.resizeWindow(window_name, (640 * 2, 480 * 2)) + else: + print('Skipping visualization, will not show a GUI.') + + # Print the keyboard help menu. + print('==> Keyboard control:\n' + '\tn: select the current frame as the reference image (left)\n' + '\tq: quit') + + # vis_range = [kwargs["bottom_k"], kwargs["top_k"]] + + while True: + frame_id += 1 + if frame_id == len(dataloader): + print('Finished demo_loftr.py') + break + data_dict = next(dataloader) + if frame_id % skip_frames != 0: + # print("Skipping frame.") + continue + + stem0, stem1 = last_image_id, data_dict["id"][0].item() - 1 + frame = np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255 + + frame_tensor = data_dict["img"].to(self.device) + frame = cv2.resize(frame, (frame_tensor.shape[-1], frame_tensor.shape[-2]), interpolation=cv2.INTER_LINEAR) + pair_data = {**pair_data, 'image1': frame_tensor} + self.model(pair_data) + + total_n_matches = len(pair_data['mkpts0_f']) + mkpts0 = pair_data['mkpts0_f'].cpu().numpy() # [vis_range[0]:vis_range[1]] + mkpts1 = pair_data['mkpts1_f'].cpu().numpy() # [vis_range[0]:vis_range[1]] + mconf = pair_data['mconf'].cpu().numpy() # [vis_range[0]:vis_range[1]] + + # Normalize confidence. 
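            # mconf is inverted (1 - confidence) so that error_colormap, which maps
            # values near zero to green, thr to yellow, and 2*thr or more to red,
            # draws high-confidence matches in green with thr=0.4 acting as the cut-off.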
+ if len(mconf) > 0: + mconf = 1 - mconf + + # alpha = 0 + # color = cm.jet(mconf, alpha=alpha) + color = error_colormap(mconf, thr=0.4, alpha=0.1) + + text = [ + f'Topics', + '#Matches: {}'.format(total_n_matches), + ] + + out = draw_topicfm_demo(pair_data, last_frame, frame, mkpts0, mkpts1, color, text, + show_n_topics=4, path=None) + + if not no_display: + if writer is not None: + writer.write(out) + cv2.imshow('TopicFM Matches', out) + key = chr(cv2.waitKey(10) & 0xFF) + if key == 'q': + if writer is not None: + writer.release() + print('Exiting...') + break + elif key == 'n': + pair_data['image0'] = frame_tensor + last_frame = frame + last_image_id = (data_dict["id"][0].item() - 1) + frame_id_left = frame_id + + elif output_dir is not None: + stem = 'matches_{:06}_{:06}'.format(stem0, stem1) + out_file = str(Path(output_dir, stem + '.png')) + print('\nWriting image to {}'.format(out_file)) + cv2.imwrite(out_file, out) + else: + raise ValueError("output_dir is required when no display is given.") + + cv2.destroyAllWindows() + if writer is not None: + writer.release() + diff --git a/imcui/third_party/XoFTR/configs/data/__init__.py b/imcui/third_party/XoFTR/configs/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/XoFTR/configs/data/base.py b/imcui/third_party/XoFTR/configs/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..03aab160fa4137ccc04380f94854a56fbb549074 --- /dev/null +++ b/imcui/third_party/XoFTR/configs/data/base.py @@ -0,0 +1,35 @@ +""" +The data config will be the last one merged into the main config. +Setups in data configs will override all existed setups! +""" + +from yacs.config import CfgNode as CN +_CN = CN() +_CN.DATASET = CN() +_CN.TRAINER = CN() + +# training data config +_CN.DATASET.TRAIN_DATA_ROOT = None +_CN.DATASET.TRAIN_POSE_ROOT = None +_CN.DATASET.TRAIN_NPZ_ROOT = None +_CN.DATASET.TRAIN_LIST_PATH = None +_CN.DATASET.TRAIN_INTRINSIC_PATH = None +# validation set config +_CN.DATASET.VAL_DATA_ROOT = None +_CN.DATASET.VAL_POSE_ROOT = None +_CN.DATASET.VAL_NPZ_ROOT = None +_CN.DATASET.VAL_LIST_PATH = None +_CN.DATASET.VAL_INTRINSIC_PATH = None + +# testing data config +_CN.DATASET.TEST_DATA_ROOT = None +_CN.DATASET.TEST_POSE_ROOT = None +_CN.DATASET.TEST_NPZ_ROOT = None +_CN.DATASET.TEST_LIST_PATH = None +_CN.DATASET.TEST_INTRINSIC_PATH = None + +# dataset config +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 +_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val + +cfg = _CN diff --git a/imcui/third_party/XoFTR/configs/data/megadepth_trainval_840.py b/imcui/third_party/XoFTR/configs/data/megadepth_trainval_840.py new file mode 100644 index 0000000000000000000000000000000000000000..130212c3e7d55310cb822a37e026655aff9c346f --- /dev/null +++ b/imcui/third_party/XoFTR/configs/data/megadepth_trainval_840.py @@ -0,0 +1,22 @@ +from configs.data.base import cfg + + +TRAIN_BASE_PATH = "data/megadepth/index" +cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth" +cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train" +cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" +cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0 + +TEST_BASE_PATH = "data/megadepth/index" +cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" +cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" +cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = 
f"{TEST_BASE_PATH}/scene_info_val_1500" +cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val + +# 368 scenes in total for MegaDepth +# (with difficulty balanced (further split each scene to 3 sub-scenes)) +cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100 + +cfg.DATASET.MGDPT_IMG_RESIZE = 840 # for training on 32GB meme GPUs diff --git a/imcui/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py b/imcui/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py new file mode 100644 index 0000000000000000000000000000000000000000..ab33112a8abc8a6c89993bed240f98bd51d38e2d --- /dev/null +++ b/imcui/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py @@ -0,0 +1,23 @@ +from configs.data.base import cfg + + +TRAIN_BASE_PATH = "data/megadepth/index" +cfg.DATASET.TRAIN_DATA_SOURCE = "MegaDepth" +cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train" +cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" +cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0 + +VAL_BASE_PATH = "data/METU_VisTIR/index" +cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" +cfg.DATASET.VAL_DATA_SOURCE = "VisTir" +cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/METU_VisTIR" +cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{VAL_BASE_PATH}/scene_info_val" +cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{VAL_BASE_PATH}/val_test_list/val_list.txt" +cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val + +# 368 scenes in total for MegaDepth +# (with difficulty balanced (further split each scene to 3 sub-scenes)) +cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100 + +cfg.DATASET.MGDPT_IMG_RESIZE = 640 # for training on 11GB mem GPUs diff --git a/imcui/third_party/XoFTR/configs/data/pretrain.py b/imcui/third_party/XoFTR/configs/data/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..0dba89ca2623985e51890e04a462f46a492395f9 --- /dev/null +++ b/imcui/third_party/XoFTR/configs/data/pretrain.py @@ -0,0 +1,8 @@ +from configs.data.base import cfg + +cfg.DATASET.TRAIN_DATA_SOURCE = "KAIST" +cfg.DATASET.TRAIN_DATA_ROOT = "data/kaist-cvpr15" +cfg.DATASET.VAL_DATA_SOURCE = "KAIST" +cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/kaist-cvpr15" + +cfg.DATASET.PRETRAIN_IMG_RESIZE = 640 diff --git a/imcui/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py b/imcui/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py new file mode 100644 index 0000000000000000000000000000000000000000..61aa1a96853671638223ecb65075ea32675e78e9 --- /dev/null +++ b/imcui/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py @@ -0,0 +1,17 @@ +from src.config.default import _CN as cfg + +cfg.XOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' + +cfg.TRAINER.CANONICAL_LR = 8e-3 +cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +cfg.TRAINER.WARMUP_RATIO = 0.1 +cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24, 30, 36, 42] + +# pose estimation +cfg.TRAINER.RANSAC_PIXEL_THR = 1.5 + +cfg.TRAINER.OPTIMIZER = "adamw" +cfg.TRAINER.ADAMW_DECAY = 0.1 +cfg.XOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 + +cfg.TRAINER.USE_WANDB = True # use weight and biases diff --git a/imcui/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py b/imcui/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py new file mode 100644 index 
0000000000000000000000000000000000000000..252146af93224a4449dbcaf13b7274573c3aba16 --- /dev/null +++ b/imcui/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py @@ -0,0 +1,12 @@ +from src.config.default import _CN as cfg + +cfg.TRAINER.CANONICAL_LR = 4e-3 +cfg.TRAINER.WARMUP_STEP = 1250 # 2 epochs +cfg.TRAINER.WARMUP_RATIO = 0.1 +cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16, 18] + +cfg.TRAINER.OPTIMIZER = "adamw" +cfg.TRAINER.ADAMW_DECAY = 0.1 + +cfg.TRAINER.USE_WANDB = True # use weight and biases + diff --git a/imcui/third_party/XoFTR/environment.yaml b/imcui/third_party/XoFTR/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c6a25b9b6d1bc38d38010da55efb309901a0a1d2 --- /dev/null +++ b/imcui/third_party/XoFTR/environment.yaml @@ -0,0 +1,14 @@ +name: xoftr +channels: + # - https://dx-mirrors.sensetime.com/anaconda/cloud/pytorch + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - python=3.8 + - pytorch=2.0.1 + - pytorch-cuda=11.8 + - pip + - pip: + - -r requirements.txt diff --git a/imcui/third_party/XoFTR/pretrain.py b/imcui/third_party/XoFTR/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..d861315931714044e6915c5b2043218e06aed913 --- /dev/null +++ b/imcui/third_party/XoFTR/pretrain.py @@ -0,0 +1,125 @@ +import math +import argparse +import pprint +from distutils.util import strtobool +from pathlib import Path +from loguru import logger as loguru_logger +from datetime import datetime + +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.plugins import DDPPlugin + +from src.config.default import get_cfg_defaults +from src.utils.misc import get_rank_zero_only_logger, setup_gpus +from src.utils.profiler import build_profiler +from src.lightning.data_pretrain import PretrainDataModule +from src.lightning.lightning_xoftr_pretrain import PL_XoFTR_Pretrain + +loguru_logger = get_rank_zero_only_logger(loguru_logger) + + +def parse_args(): + # init a costum parser which will be added into pl.Trainer parser + # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'data_cfg_path', type=str, help='data config path') + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--exp_name', type=str, default='default_exp_name') + parser.add_argument( + '--batch_size', type=int, default=4, help='batch_size per gpu') + parser.add_argument( + '--num_workers', type=int, default=4) + parser.add_argument( + '--pin_memory', type=lambda x: bool(strtobool(x)), + nargs='?', default=True, help='whether loading data to pinned memory or not') + parser.add_argument( + '--ckpt_path', type=str, default=None, + help='pretrained checkpoint path') + parser.add_argument( + '--disable_ckpt', action='store_true', + help='disable checkpoint saving (useful for debugging).') + parser.add_argument( + '--profiler_name', type=str, default=None, + help='options: [inference, pytorch], or leave it unset') + parser.add_argument( + '--parallel_load_data', action='store_true', + help='load datasets in with multiple processes.') + + parser = pl.Trainer.add_argparse_args(parser) + return parser.parse_args() + + +def main(): + # 
parse arguments + args = parse_args() + rank_zero_only(pprint.pprint)(vars(args)) + + # init default-cfg and merge it with the main- and data-cfg + config = get_cfg_defaults() + config.merge_from_file(args.main_cfg_path) + config.merge_from_file(args.data_cfg_path) + pl.seed_everything(config.TRAINER.SEED) # reproducibility + + # scale lr and warmup-step automatically + args.gpus = _n_gpus = setup_gpus(args.gpus) + config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes + config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size + _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS + config.TRAINER.SCALING = _scaling + config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling + config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling) + + # lightning module + profiler = build_profiler(args.profiler_name) + model = PL_XoFTR_Pretrain(config, pretrained_ckpt=args.ckpt_path, profiler=profiler) + loguru_logger.info(f"XoFTR LightningModule initialized!") + + # lightning data + data_module = PretrainDataModule(args, config) + loguru_logger.info(f"XoFTR DataModule initialized!") + + # TensorBoard Logger + logger = [TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)] + ckpt_dir = Path(logger[0].log_dir) / 'checkpoints' + if config.TRAINER.USE_WANDB: + logger.append(WandbLogger(name=args.exp_name + f"_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}", + project='XoFTR')) + + # Callbacks + # TODO: update ModelCheckpoint to monitor multiple metrics + ckpt_callback = ModelCheckpoint(verbose=True, save_top_k=-1, + save_last=True, + dirpath=str(ckpt_dir), + filename='{epoch}') + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks = [lr_monitor] + if not args.disable_ckpt: + callbacks.append(ckpt_callback) + + # Lightning Trainer + trainer = pl.Trainer.from_argparse_args( + args, + plugins=DDPPlugin(find_unused_parameters=True, + num_nodes=args.num_nodes, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), + gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING, + callbacks=callbacks, + logger=logger, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, + replace_sampler_ddp=False, # use custom sampler + reload_dataloaders_every_epoch=False, # avoid repeated samples! + weights_summary='full', + profiler=profiler) + loguru_logger.info(f"Trainer initialized!") + loguru_logger.info(f"Start training!") + trainer.fit(model, datamodule=data_module) + + +if __name__ == '__main__': + main() diff --git a/imcui/third_party/XoFTR/src/__init__.py b/imcui/third_party/XoFTR/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/XoFTR/src/config/default.py b/imcui/third_party/XoFTR/src/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..02091605a350d9fc8c0dd7510a3937a67e599593 --- /dev/null +++ b/imcui/third_party/XoFTR/src/config/default.py @@ -0,0 +1,203 @@ +from yacs.config import CfgNode as CN + +INFERENCE = False + +_CN = CN() + +############## ↓ XoFTR Pipeline ↓ ############## +_CN.XOFTR = CN() +_CN.XOFTR.RESOLUTION = (8, 2) # options: [(8, 2)] +_CN.XOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.XOFTR.MEDIUM_WINDOW_SIZE = 3 # window_size in fine_level, must be odd + +# 1. XoFTR-backbone (local feature CNN) config +_CN.XOFTR.RESNET = CN() +_CN.XOFTR.RESNET.INITIAL_DIM = 128 +_CN.XOFTR.RESNET.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 + +# 2. 
XoFTR-coarse module config +_CN.XOFTR.COARSE = CN() +_CN.XOFTR.COARSE.INFERENCE = INFERENCE +_CN.XOFTR.COARSE.D_MODEL = 256 +_CN.XOFTR.COARSE.D_FFN = 256 +_CN.XOFTR.COARSE.NHEAD = 8 +_CN.XOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 +_CN.XOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] + +# 3. Coarse-Matching config +_CN.XOFTR.MATCH_COARSE = CN() +_CN.XOFTR.MATCH_COARSE.INFERENCE = INFERENCE +_CN.XOFTR.MATCH_COARSE.D_MODEL = 256 +_CN.XOFTR.MATCH_COARSE.THR = 0.3 +_CN.XOFTR.MATCH_COARSE.BORDER_RM = 2 +_CN.XOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax'] +_CN.XOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.XOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory +_CN.XOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock + +# 4. XoFTR-fine module config +_CN.XOFTR.FINE = CN() +_CN.XOFTR.FINE.DENSER = False # if true, match all features in fine-level windows +_CN.XOFTR.FINE.INFERENCE = INFERENCE +_CN.XOFTR.FINE.DSMAX_TEMPERATURE = 0.1 +_CN.XOFTR.FINE.THR = 0.1 +_CN.XOFTR.FINE.MLP_HIDDEN_DIM_COEF = 2 # coef for mlp hidden dim (hidden_dim = feat_dim * coef) +_CN.XOFTR.FINE.NHEAD_FINE_LEVEL = 8 +_CN.XOFTR.FINE.NHEAD_MEDIUM_LEVEL = 7 + + +# 5. XoFTR Losses + +_CN.XOFTR.LOSS = CN() +_CN.XOFTR.LOSS.FOCAL_ALPHA = 0.25 +_CN.XOFTR.LOSS.FOCAL_GAMMA = 2.0 +_CN.XOFTR.LOSS.POS_WEIGHT = 1.0 +_CN.XOFTR.LOSS.NEG_WEIGHT = 1.0 + +# -- # coarse-level +_CN.XOFTR.LOSS.COARSE_WEIGHT = 0.5 +# -- # fine-level +_CN.XOFTR.LOSS.FINE_WEIGHT = 0.3 +# -- # sub-pixel +_CN.XOFTR.LOSS.SUB_WEIGHT = 1 * 10**4 + +############## Dataset ############## +_CN.DATASET = CN() +# 1. data config +# training and validating +_CN.DATASET.TRAIN_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth'] +_CN.DATASET.TRAIN_DATA_ROOT = None +_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TRAIN_NPZ_ROOT = None +_CN.DATASET.TRAIN_LIST_PATH = None +_CN.DATASET.TRAIN_INTRINSIC_PATH = None +_CN.DATASET.VAL_DATA_SOURCE = None +_CN.DATASET.VAL_DATA_ROOT = None +_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.VAL_NPZ_ROOT = None +_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file +_CN.DATASET.VAL_INTRINSIC_PATH = None +# testing +_CN.DATASET.TEST_DATA_SOURCE = None +_CN.DATASET.TEST_DATA_ROOT = None +_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TEST_NPZ_ROOT = None +_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file +_CN.DATASET.TEST_INTRINSIC_PATH = None + +# 2. dataset config +# general options +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score +_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 +_CN.DATASET.AUGMENTATION_TYPE = "rgb_thermal" # options: [None, 'dark', 'mobile'] + +# MegaDepth options +_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. +_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE +_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 +_CN.DATASET.MGDPT_DF = 8 + +# VisTir options +_CN.DATASET.VISTIR_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. 
+_CN.DATASET.VISTIR_IMG_PAD = False # pad img to square with size = VISTIR_IMG_RESIZE +_CN.DATASET.VISTIR_DF = 8 + +# Pretrain dataset options +_CN.DATASET.PRETRAIN_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. +_CN.DATASET.PRETRAIN_IMG_PAD = True # pad img to square with size = PRETRAIN_IMG_RESIZE +_CN.DATASET.PRETRAIN_DF = 8 +_CN.DATASET.PRETRAIN_FRAME_GAP = 2 # the gap between video frames of Kaist dataset + +############## Trainer ############## +_CN.TRAINER = CN() +_CN.TRAINER.WORLD_SIZE = 1 +_CN.TRAINER.CANONICAL_BS = 64 +_CN.TRAINER.CANONICAL_LR = 6e-3 +_CN.TRAINER.SCALING = None # this will be calculated automatically +_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning + +_CN.TRAINER.USE_WANDB = False # use weight and biases + +# optimizer +_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] +_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime +_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam +_CN.TRAINER.ADAMW_DECAY = 0.1 + +# step-based warm-up +_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] +_CN.TRAINER.WARMUP_RATIO = 0. +_CN.TRAINER.WARMUP_STEP = 4800 + +# learning rate scheduler +_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR] +_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] +_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR +_CN.TRAINER.MSLR_GAMMA = 0.5 +_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing +_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval + +# plotting related +_CN.TRAINER.ENABLE_PLOTTING = True +_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 128 # number of val/test paris for plotting +_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence'] +_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic' + +# geometric metrics and pose solver +_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) +_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] +_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] +_CN.TRAINER.RANSAC_PIXEL_THR = 0.5 +_CN.TRAINER.RANSAC_CONF = 0.99999 +_CN.TRAINER.RANSAC_MAX_ITERS = 10000 +_CN.TRAINER.USE_MAGSACPP = False + +# data sampler for train_dataloader +_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] +# 'scene_balance' config +_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200 +_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not +_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not +_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data +# 'random' config +_CN.TRAINER.RDM_REPLACEMENT = True +_CN.TRAINER.RDM_NUM_SAMPLES = None + +# gradient clipping +_CN.TRAINER.GRADIENT_CLIPPING = 0.5 + +# reproducibility +# This seed affects the data sampling. With the same seed, the data sampling is promised +# to be the same. When resume training from a checkpoint, it's better to use a different +# seed, otherwise the sampled data will be exactly the same as before resuming, which will +# cause less unique data items sampled during the entire training. +# Use of different seed values might affect the final training result, since not all data items +# are used during training on ScanNet. (60M pairs of images sampled during traing from 230M pairs in total.) 
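# For example, when resuming, a different seed can be set after the config files are
# merged, e.g. config.merge_from_list(["TRAINER.SEED", 67]), so the sampler draws a
# different subset of pairs than in the original run.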
+_CN.TRAINER.SEED = 66 + +############## Pretrain ############## +_CN.PRETRAIN = CN() +_CN.PRETRAIN.PATCH_SIZE = 64 # patch sıze for masks +_CN.PRETRAIN.MASK_RATIO = 0.5 +_CN.PRETRAIN.MAE_MARGINS = [0, 0.4, 0, 0] # margins not to be masked (up bottom left right) +_CN.PRETRAIN.VAL_SEED = 42 # rng seed to crate the same masks for validation + +_CN.XOFTR.PRETRAIN_PATCH_SIZE = _CN.PRETRAIN.PATCH_SIZE + +############## Test/Inference ############## +_CN.TEST = CN() +_CN.TEST.IMG0_RESIZE = 640 # resize the longer side +_CN.TEST.IMG1_RESIZE = 640 # resize the longer side +_CN.TEST.DF = 8 +_CN.TEST.PADDING = False # pad img to square with size = IMG0_RESIZE, IMG1_RESIZE +_CN.TEST.COARSE_SCALE = 0.125 + +def get_cfg_defaults(inference=False): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + if inference: + _CN.XOFTR.COARSE.INFERENCE = True + _CN.XOFTR.MATCH_COARSE.INFERENCE = True + _CN.XOFTR.FINE.INFERENCE = True + return _CN.clone() diff --git a/imcui/third_party/XoFTR/src/datasets/megadepth.py b/imcui/third_party/XoFTR/src/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..32eaeb7a554ed28a956c71ca9c0df5e418aebbba --- /dev/null +++ b/imcui/third_party/XoFTR/src/datasets/megadepth.py @@ -0,0 +1,143 @@ +import os.path as osp +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset +from loguru import logger + +from src.utils.dataset import read_megadepth_gray, read_megadepth_depth + +def correct_image_paths(scene_info): + """Changes the path format from undistorted images from D2Net to MegaDepth_v1 format""" + image_paths = scene_info["image_paths"] + for ii in range(len(image_paths)): + if image_paths[ii] is not None: + folds = image_paths[ii].split("/") + path = osp.join("phoenix/S6/zl548/MegaDepth_v1/", folds[1], "dense0/imgs", folds[3] ) + image_paths[ii] = path + scene_info["image_paths"] = image_paths + return scene_info + +class MegaDepthDataset(Dataset): + def __init__(self, + root_dir, + npz_path, + mode='train', + min_overlap_score=0.4, + img_resize=None, + df=None, + img_padding=False, + depth_padding=False, + augment_fn=None, + **kwargs): + """ + Manage one scene(npz_path) of MegaDepth dataset. + + Args: + root_dir (str): megadepth root directory that has `phoenix`. + npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. + mode (str): options are ['train', 'val', 'test'] + min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing. + img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. + This is useful during training with batches and testing with memory intensive algorithms. + df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. + img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. + depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training. + augment_fn (callable, optional): augments images with pre-defined visual effects. 
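+ Example (illustrative; the root path and scene npz name are placeholders):
+ >>> ds = MegaDepthDataset("data/megadepth", "scene_0022.npz", mode="train",
+ ... img_resize=640, df=8, img_padding=True, depth_padding=True)
+ >>> data = ds[0] # dict with 'image0', 'depth0', 'K0', 'T_0to1', ... (see __getitem__ below)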
+ """ + super().__init__() + self.root_dir = root_dir + self.mode = mode + self.scene_id = npz_path.split('.')[0] + + # prepare scene_info and pair_info + if mode == 'test' and min_overlap_score != 0: + logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.") + min_overlap_score = 0 + self.scene_info = np.load(npz_path, allow_pickle=True) + self.scene_info = correct_image_paths(self.scene_info) + self.pair_infos = self.scene_info['pair_infos'].copy() + del self.scene_info['pair_infos'] + self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] + + # parameters for image resizing, padding and depthmap padding + if mode == 'train': + assert img_resize is not None and img_padding and depth_padding + self.img_resize = img_resize + self.df = df + self.img_padding = img_padding + self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. + + # for training XoFTR + # self.augment_fn = augment_fn if mode == 'train' else None + self.augment_fn = augment_fn + self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) + + def __len__(self): + return len(self.pair_infos) + + def __getitem__(self, idx): + (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx] + + # read grayscale image and mask. (1, h, w) and (h, w) + img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) + img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) + + if getattr(self.augment_fn, 'random_switch', False): + im_num = torch.randint(0, 2, (1,)) + augment_fn_0 = lambda x: self.augment_fn(x, image_num=im_num) + augment_fn_1 = lambda x: self.augment_fn(x, image_num=1-im_num) + else: + augment_fn_0 = self.augment_fn + augment_fn_1 = self.augment_fn + image0, mask0, scale0 = read_megadepth_gray( + img_name0, self.img_resize, self.df, self.img_padding, augment_fn=augment_fn_0) + image1, mask1, scale1 = read_megadepth_gray( + img_name1, self.img_resize, self.df, self.img_padding, augment_fn=augment_fn_1) + + # read depth. 
shape: (h, w) + if self.mode in ['train', 'val']: + depth0 = read_megadepth_depth( + osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) + depth1 = read_megadepth_depth( + osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) + else: + depth0 = depth1 = torch.tensor([]) + + # read intrinsics of original size + K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) + K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T0 = self.scene_info['poses'][idx0] + T1 = self.scene_info['poses'][idx1] + T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) + T_1to0 = T_0to1.inverse() + + data = { + 'image0': image0, # (1, h, w) + 'depth0': depth0, # (h, w) + 'image1': image1, + 'depth1': depth1, + 'T_0to1': T_0to1, # (4, 4) + 'T_1to0': T_1to0, + 'K0': K_0, # (3, 3) + 'K1': K_1, + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'MegaDepth', + 'scene_id': self.scene_id, + 'pair_id': idx, + 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), + } + + # for XoFTR training + if mask0 is not None: # img_padding is True + if self.coarse_scale: + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.coarse_scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/XoFTR/src/datasets/pretrain_dataset.py b/imcui/third_party/XoFTR/src/datasets/pretrain_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..17ab2f3e61299ccc13f5926b0716417bf28a5fad --- /dev/null +++ b/imcui/third_party/XoFTR/src/datasets/pretrain_dataset.py @@ -0,0 +1,156 @@ +import os +import glob +import os.path as osp +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset +from loguru import logger +import random +from src.utils.dataset import read_pretrain_gray + +class PretrainDataset(Dataset): + def __init__(self, + root_dir, + mode='train', + img_resize=None, + df=None, + img_padding=False, + frame_gap=2, + **kwargs): + """ + Manage image pairs of KAIST Multispectral Pedestrian Detection Benchmark Dataset. + + Args: + root_dir (str): KAIST Multispectral Pedestrian root directory that has `phoenix`. + mode (str): options are ['train', 'val'] + img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. + This is useful during training with batches and testing with memory intensive algorithms. + df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. + img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. + augment_fn (callable, optional): augments images with pre-defined visual effects. 
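+ Example (illustrative; the KAIST root path is a placeholder):
+ >>> ds = PretrainDataset("data/kaist", mode="train", img_resize=640, df=8,
+ ... img_padding=True, frame_gap=2)
+ >>> data = ds[0] # dict with 'image0'/'image1', their normalized copies, per-image mean/std and scales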
+ """ + super().__init__() + self.root_dir = root_dir + self.mode = mode + + # specify which part of the data is used for trainng and testing + if mode == 'train': + assert img_resize is not None and img_padding + self.start_ratio = 0.0 + self.end_ratio = 0.9 + elif mode == 'val': + assert img_resize is not None and img_padding + self.start_ratio = 0.9 + self.end_ratio = 1.0 + else: + raise NotImplementedError() + + # parameters for image resizing, padding + self.img_resize = img_resize + self.df = df + self.img_padding = img_padding + + # for training XoFTR + self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) + + self.pair_paths = self.generate_kaist_pairs(root_dir, frame_gap=frame_gap, second_frame_range=0) + + def get_kaist_image_paths(self, root_dir): + vis_img_paths = [] + lwir_img_paths = [] + img_num_per_folder = [] + + # Recursively search for folders named "image" + for folder, subfolders, filenames in os.walk(root_dir): + if "visible" in subfolders and "lwir" in subfolders: + vis_img_folder = osp.join(folder, "visible") + lwir_img_folder = osp.join(folder, "lwir") + # Use glob to find image files (you can add more extensions if needed) + vis_imgs_i = glob.glob(osp.join(vis_img_folder, '*.jpg')) + vis_imgs_i.sort() + lwir_imgs_i = glob.glob(osp.join(lwir_img_folder, '*.jpg')) + lwir_imgs_i.sort() + vis_img_paths.append(vis_imgs_i) + lwir_img_paths.append(lwir_imgs_i) + img_num_per_folder.append(len(vis_imgs_i)) + assert len(vis_imgs_i) == len(lwir_imgs_i), f"Image numbers do not match in {folder}, {len(vis_imgs_i)} != {len(lwir_imgs_i)}" + # Add more image file extensions as necessary + return vis_img_paths, lwir_img_paths, img_num_per_folder + + def generate_kaist_pairs(self, root_dir, frame_gap, second_frame_range): + """ generate image pairs (Vis-TIR) from KAIST Pedestrian dataset + Args: + root_dir: root directory for the dataset + frame_gap (int): the frame gap between consecutive images + second_frame_range (int): the range for second image i.e. for the first ind i, second ind j element of [i-10, i+10] + Returns: + pair_paths (list) + """ + vis_img_paths, lwir_img_paths, img_num_per_folder = self.get_kaist_image_paths(root_dir) + pair_paths = [] + for i in range(len(img_num_per_folder)): + num_img = img_num_per_folder[i] + inds_vis = torch.arange(int(self.start_ratio * num_img), + int(self.end_ratio * num_img), + frame_gap, dtype=int) + if second_frame_range > 0: + inds_lwir = inds_vis + torch.randint(-second_frame_range, second_frame_range, (inds_vis.shape[0],)) + inds_lwir[inds_lwirint(self.end_ratio * num_img)-1] = int(self.end_ratio * num_img)-1 + else: + inds_lwir = inds_vis + for j, k in zip(inds_vis, inds_lwir): + img_name0 = os.path.relpath(vis_img_paths[i][j], root_dir) + img_name1 = os.path.relpath(lwir_img_paths[i][k], root_dir) + + if torch.rand(1) > 0.5: + img_name0, img_name1 = img_name1, img_name0 + + pair_paths.append([img_name0, img_name1]) + + random.shuffle(pair_paths) + return pair_paths + + def __len__(self): + return len(self.pair_paths) + + def __getitem__(self, idx): + # read grayscale and normalized image, and mask. 
(1, h, w) and (h, w) + img_name0 = osp.join(self.root_dir, self.pair_paths[idx][0]) + img_name1 = osp.join(self.root_dir, self.pair_paths[idx][1]) + + if self.mode == "train" and torch.rand(1) > 0.5: + img_name0, img_name1 = img_name1, img_name0 + + image0, image0_norm, mask0, scale0, image0_mean, image0_std = read_pretrain_gray( + img_name0, self.img_resize, self.df, self.img_padding, None) + image1, image1_norm, mask1, scale1, image1_mean, image1_std = read_pretrain_gray( + img_name1, self.img_resize, self.df, self.img_padding, None) + + data = { + 'image0': image0, # (1, h, w) + 'image1': image1, + 'image0_norm': image0_norm, + 'image1_norm': image1_norm, + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + "image0_mean": image0_mean, + "image0_std": image0_std, + "image1_mean": image1_mean, + "image1_std": image1_std, + 'dataset_name': 'PreTrain', + 'pair_id': idx, + 'pair_names': (self.pair_paths[idx][0], self.pair_paths[idx][1]), + } + + # for XoFTR training + if mask0 is not None: # img_padding is True + if self.coarse_scale: + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.coarse_scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/XoFTR/src/datasets/sampler.py b/imcui/third_party/XoFTR/src/datasets/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..81b6f435645632a013476f9a665a0861ab7fcb61 --- /dev/null +++ b/imcui/third_party/XoFTR/src/datasets/sampler.py @@ -0,0 +1,77 @@ +import torch +from torch.utils.data import Sampler, ConcatDataset + + +class RandomConcatSampler(Sampler): + """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset + in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. + However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase. + + For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not. + Args: + shuffle (bool): shuffle the random sampled indices across all sub-datsets. + repeat (int): repeatedly use the sampled indices multiple times for training. + [arXiv:1902.05509, arXiv:1901.09335] + NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples) + NOTE: This sampler behaves differently with DistributedSampler. + It assume the dataset is splitted across ranks instead of replicated. + TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs. 
+ ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 + """ + def __init__(self, + data_source: ConcatDataset, + n_samples_per_subset: int, + subset_replacement: bool=True, + shuffle: bool=True, + repeat: int=1, + seed: int=None): + if not isinstance(data_source, ConcatDataset): + raise TypeError("data_source should be torch.utils.data.ConcatDataset") + + self.data_source = data_source + self.n_subset = len(self.data_source.datasets) + self.n_samples_per_subset = n_samples_per_subset + self.n_samples = self.n_subset * self.n_samples_per_subset * repeat + self.subset_replacement = subset_replacement + self.repeat = repeat + self.shuffle = shuffle + self.generator = torch.manual_seed(seed) + assert self.repeat >= 1 + + def __len__(self): + return self.n_samples + + def __iter__(self): + indices = [] + # sample from each sub-dataset + for d_idx in range(self.n_subset): + low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1] + high = self.data_source.cumulative_sizes[d_idx] + if self.subset_replacement: + rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ), + generator=self.generator, dtype=torch.int64) + else: # sample without replacement + len_subset = len(self.data_source.datasets[d_idx]) + rand_tensor = torch.randperm(len_subset, generator=self.generator) + low + if len_subset >= self.n_samples_per_subset: + rand_tensor = rand_tensor[:self.n_samples_per_subset] + else: # padding with replacement + rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ), + generator=self.generator, dtype=torch.int64) + rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) + indices.append(rand_tensor) + indices = torch.cat(indices) + if self.shuffle: # shuffle the sampled dataset (from multiple subsets) + rand_tensor = torch.randperm(len(indices), generator=self.generator) + indices = indices[rand_tensor] + + # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling) + if self.repeat > 1: + repeat_indices = [indices.clone() for _ in range(self.repeat - 1)] + if self.shuffle: + _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)] + repeat_indices = map(_choice, repeat_indices) + indices = torch.cat([indices, *repeat_indices], 0) + + assert indices.shape[0] == self.n_samples + return iter(indices.tolist()) diff --git a/imcui/third_party/XoFTR/src/datasets/scannet.py b/imcui/third_party/XoFTR/src/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..a8cfa8d5a91bf275733b980fbf77641d77b16a9b --- /dev/null +++ b/imcui/third_party/XoFTR/src/datasets/scannet.py @@ -0,0 +1,114 @@ +from os import path as osp +from typing import Dict +from unicodedata import name + +import numpy as np +import torch +import torch.utils as utils +from numpy.linalg import inv +from src.utils.dataset import ( + read_scannet_gray, + read_scannet_depth, + read_scannet_pose, + read_scannet_intrinsic +) + + +class ScanNetDataset(utils.data.Dataset): + def __init__(self, + root_dir, + npz_path, + intrinsic_path, + mode='train', + min_overlap_score=0.4, + augment_fn=None, + pose_dir=None, + **kwargs): + """Manage one scene of ScanNet Dataset. + Args: + root_dir (str): ScanNet root directory that contains scene folders. + npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. + intrinsic_path (str): path to depth-camera intrinsic file. 
+ mode (str): options are ['train', 'val', 'test']. + augment_fn (callable, optional): augments images with pre-defined visual effects. + pose_dir (str): ScanNet root directory that contains all poses. + (we use a separate (optional) pose_dir since we store images and poses separately.) + """ + super().__init__() + self.root_dir = root_dir + self.pose_dir = pose_dir if pose_dir is not None else root_dir + self.mode = mode + + # prepare data_names, intrinsics and extrinsics(T) + with np.load(npz_path) as data: + self.data_names = data['name'] + if 'score' in data.keys() and mode not in ['val' or 'test']: + kept_mask = data['score'] > min_overlap_score + self.data_names = self.data_names[kept_mask] + self.intrinsics = dict(np.load(intrinsic_path)) + + # for training LoFTR + self.augment_fn = augment_fn if mode == 'train' else None + + def __len__(self): + return len(self.data_names) + + def _read_abs_pose(self, scene_name, name): + pth = osp.join(self.pose_dir, + scene_name, + 'pose', f'{name}.txt') + return read_scannet_pose(pth) + + def _compute_rel_pose(self, scene_name, name0, name1): + pose0 = self._read_abs_pose(scene_name, name0) + pose1 = self._read_abs_pose(scene_name, name1) + + return np.matmul(pose1, inv(pose0)) # (4, 4) + + def __getitem__(self, idx): + data_name = self.data_names[idx] + scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name + scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + + # read the grayscale image which will be resized to (1, 480, 640) + img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg') + img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg') + + # TODO: Support augmentation & handle seeds for each worker correctly. + image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None) + # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) + + # read the depthmap which is stored as (480, 640) + if self.mode in ['train', 'val']: + depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png')) + depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png')) + else: + depth0 = depth1 = torch.tensor([]) + + # read the intrinsic of depthmap + K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), + dtype=torch.float32) + T_1to0 = T_0to1.inverse() + + data = { + 'image0': image0, # (1, h, w) + 'depth0': depth0, # (h, w) + 'image1': image1, + 'depth1': depth1, + 'T_0to1': T_0to1, # (4, 4) + 'T_1to0': T_1to0, + 'K0': K_0, # (3, 3) + 'K1': K_1, + 'dataset_name': 'ScanNet', + 'scene_id': scene_name, + 'pair_id': idx, + 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'), + osp.join(scene_name, 'color', f'{stem_name_1}.jpg')) + } + + return data diff --git a/imcui/third_party/XoFTR/src/datasets/vistir.py b/imcui/third_party/XoFTR/src/datasets/vistir.py new file mode 100644 index 0000000000000000000000000000000000000000..6f09e87bac1e7470ab77830d66c632dd727b9a12 --- /dev/null +++ b/imcui/third_party/XoFTR/src/datasets/vistir.py @@ -0,0 +1,109 @@ +import os.path as osp +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset +from loguru 
import logger + +from src.utils.dataset import read_vistir_gray + +class VisTirDataset(Dataset): + def __init__(self, + root_dir, + npz_path, + mode='val', + img_resize=None, + df=None, + img_padding=False, + **kwargs): + """ + Manage one scene(npz_path) of VisTir dataset. + + Args: + root_dir (str): VisTIR root directory. + npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. + mode (str): options are ['val', 'test'] + img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. + df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. + img_padding (bool): If set to 'True', zero-pad the image to squared size. + """ + super().__init__() + self.root_dir = root_dir + self.mode = mode + self.scene_id = npz_path.split('.')[0] + + # prepare scene_info and pair_info + self.scene_info = dict(np.load(npz_path, allow_pickle=True)) + self.pair_infos = self.scene_info['pair_infos'].copy() + del self.scene_info['pair_infos'] + + # parameters for image resizing, padding + self.img_resize = img_resize + self.df = df + self.img_padding = img_padding + + # for training XoFTR + self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) + + + def __len__(self): + return len(self.pair_infos) + + def __getitem__(self, idx): + (idx0, idx1) = self.pair_infos[idx] + + + img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0][0]) + img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1][1]) + + # read intrinsics of original size + K_0 = np.array(self.scene_info['intrinsics'][idx0][0], dtype=float).reshape(3,3) + K_1 = np.array(self.scene_info['intrinsics'][idx1][1], dtype=float).reshape(3,3) + + # read distortion coefficients + dist0 = np.array(self.scene_info['distortion_coefs'][idx0][0], dtype=float) + dist1 = np.array(self.scene_info['distortion_coefs'][idx1][1], dtype=float) + + # read grayscale undistorted image and mask. 
(1, h, w) and (h, w) + image0, mask0, scale0, K_0 = read_vistir_gray( + img_name0, K_0, dist0, self.img_resize, self.df, self.img_padding, augment_fn=None) + image1, mask1, scale1, K_1 = read_vistir_gray( + img_name1, K_1, dist1, self.img_resize, self.df, self.img_padding, augment_fn=None) + + # to tensor + K_0 = torch.tensor(K_0.copy(), dtype=torch.float).reshape(3, 3) + K_1 = torch.tensor(K_1.copy(), dtype=torch.float).reshape(3, 3) + + # read and compute relative poses + T0 = self.scene_info['poses'][idx0] + T1 = self.scene_info['poses'][idx1] + T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) + T_1to0 = T_0to1.inverse() + + data = { + 'image0': image0, # (1, h, w) + 'image1': image1, + 'T_0to1': T_0to1, # (4, 4) + 'T_1to0': T_1to0, + 'K0': K_0, # (3, 3) + 'K1': K_1, + 'dist0': dist0, + 'dist1': dist1, + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'VisTir', + 'scene_id': self.scene_id, + 'pair_id': idx, + 'pair_names': (self.scene_info['image_paths'][idx0][0], self.scene_info['image_paths'][idx1][1]), + } + + # for XoFTR training + if mask0 is not None: # img_padding is True + if self.coarse_scale: + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.coarse_scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/XoFTR/src/lightning/data.py b/imcui/third_party/XoFTR/src/lightning/data.py new file mode 100644 index 0000000000000000000000000000000000000000..b2ad42b5ed304dcd72b82dc88d0314864fbb75c9 --- /dev/null +++ b/imcui/third_party/XoFTR/src/lightning/data.py @@ -0,0 +1,346 @@ +import os +import math +from collections import abc +from loguru import logger +from torch.utils.data.dataset import Dataset +from tqdm import tqdm +from os import path as osp +from pathlib import Path +from joblib import Parallel, delayed + +import pytorch_lightning as pl +from torch import distributed as dist +from torch.utils.data import ( + Dataset, + DataLoader, + ConcatDataset, + DistributedSampler, + RandomSampler, + dataloader +) + +from src.utils.augment import build_augmentor +from src.utils.dataloader import get_local_split +from src.utils.misc import tqdm_joblib +from src.utils import comm +from src.datasets.megadepth import MegaDepthDataset +from src.datasets.vistir import VisTirDataset +from src.datasets.scannet import ScanNetDataset +from src.datasets.sampler import RandomConcatSampler + + +class MultiSceneDataModule(pl.LightningDataModule): + """ + For distributed training, each training process is assgined + only a part of the training scenes to reduce memory overhead. + """ + def __init__(self, args, config): + super().__init__() + + # 1. 
data config + # Train and Val should from the same data source + self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE + self.val_data_source = config.DATASET.VAL_DATA_SOURCE + self.test_data_source = config.DATASET.TEST_DATA_SOURCE + # training and validating + self.train_data_root = config.DATASET.TRAIN_DATA_ROOT + self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional) + self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT + self.train_list_path = config.DATASET.TRAIN_LIST_PATH + self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH + self.val_data_root = config.DATASET.VAL_DATA_ROOT + self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional) + self.val_npz_root = config.DATASET.VAL_NPZ_ROOT + self.val_list_path = config.DATASET.VAL_LIST_PATH + self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH + # testing + self.test_data_root = config.DATASET.TEST_DATA_ROOT + self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional) + self.test_npz_root = config.DATASET.TEST_NPZ_ROOT + self.test_list_path = config.DATASET.TEST_LIST_PATH + self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH + + # 2. dataset config + # general options + self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score + self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN + self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile'] + + # MegaDepth options + self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840 + self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True + self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True + self.mgdpt_df = config.DATASET.MGDPT_DF # 8 + self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0] # 0.125. for training xoftr. + + # VisTir options + self.vistir_img_resize = config.DATASET.VISTIR_IMG_RESIZE + self.vistir_img_pad = config.DATASET.VISTIR_IMG_PAD + self.vistir_df = config.DATASET.VISTIR_DF # 8 + + # 3.loader parameters + self.train_loader_params = { + 'batch_size': args.batch_size, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + self.val_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + self.test_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': args.num_workers, + 'pin_memory': True + } + + # 4. sampler + self.data_sampler = config.TRAINER.DATA_SAMPLER + self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET + self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT + self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE + self.repeat = config.TRAINER.SB_REPEAT + + # (optional) RandomSampler for debugging + + # misc configurations + self.parallel_load_data = getattr(args, 'parallel_load_data', False) + self.seed = config.TRAINER.SEED # 66 + + def setup(self, stage=None): + """ + Setup train / val / test dataset. This method will be called by PL automatically. + Args: + stage (str): 'fit' in training phase, and 'test' in testing phase. 
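+ Note: if torch.distributed is not initialized (e.g. single-process debugging), the code
+ below falls back to world_size=1 and rank=0.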
+ """ + + assert stage in ['fit', 'test'], "stage must be either fit or test" + + try: + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + logger.info(f"[rank:{self.rank}] world_size: {self.world_size}") + except AssertionError as ae: + self.world_size = 1 + self.rank = 0 + logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)") + + if stage == 'fit': + self.train_dataset = self._setup_dataset( + self.train_data_root, + self.train_npz_root, + self.train_list_path, + self.train_intrinsic_path, + mode='train', + min_overlap_score=self.min_overlap_score_train, + pose_dir=self.train_pose_root) + # setup multiple (optional) validation subsets + if isinstance(self.val_list_path, (list, tuple)): + self.val_dataset = [] + if not isinstance(self.val_npz_root, (list, tuple)): + self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))] + for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root): + self.val_dataset.append(self._setup_dataset( + self.val_data_root, + npz_root, + npz_list, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root)) + else: + self.val_dataset = self._setup_dataset( + self.val_data_root, + self.val_npz_root, + self.val_list_path, + self.val_intrinsic_path, + mode='val', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.val_pose_root) + logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!') + else: # stage == 'test + self.test_dataset = self._setup_dataset( + self.test_data_root, + self.test_npz_root, + self.test_list_path, + self.test_intrinsic_path, + mode='test', + min_overlap_score=self.min_overlap_score_test, + pose_dir=self.test_pose_root) + logger.info(f'[rank:{self.rank}]: Test Dataset loaded!') + + def _setup_dataset(self, + data_root, + split_npz_root, + scene_list_path, + intri_path, + mode='train', + min_overlap_score=0., + pose_dir=None): + """ Setup train / val / test set""" + with open(scene_list_path, 'r') as f: + npz_names = [name.split()[0] for name in f.readlines()] + + if mode == 'train': + local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed) + else: + local_npz_names = npz_names + logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.') + + dataset_builder = self._build_concat_dataset_parallel \ + if self.parallel_load_data \ + else self._build_concat_dataset + return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path, + mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir) + + def _build_concat_dataset( + self, + data_root, + npz_names, + npz_dir, + intrinsic_path, + mode, + min_overlap_score=0., + pose_dir=None + ): + datasets = [] + augment_fn = self.augment_fn + if mode == 'train': + data_source = self.train_data_source + elif mode == 'val': + data_source = self.val_data_source + else: + data_source = self.test_data_source + if str(data_source).lower() == 'megadepth': + npz_names = [f'{n}.npz' for n in npz_names] + for npz_name in tqdm(npz_names, + desc=f'[rank:{self.rank}] loading {mode} datasets', + disable=int(self.rank) != 0): + # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time. 
+ npz_path = osp.join(npz_dir, npz_name) + if data_source == 'ScanNet': + datasets.append( + ScanNetDataset(data_root, + npz_path, + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir)) + elif data_source == 'MegaDepth': + datasets.append( + MegaDepthDataset(data_root, + npz_path, + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale)) + elif data_source == 'VisTir': + datasets.append( + VisTirDataset(data_root, + npz_path, + mode=mode, + img_resize=self.vistir_img_resize, + df=self.vistir_df, + img_padding=self.vistir_img_pad, + coarse_scale=self.coarse_scale)) + else: + raise NotImplementedError() + return ConcatDataset(datasets) + + def _build_concat_dataset_parallel( + self, + data_root, + npz_names, + npz_dir, + intrinsic_path, + mode, + min_overlap_score=0., + pose_dir=None, + ): + augment_fn = self.augment_fn + if mode == 'train': + data_source = self.train_data_source + elif mode == 'val': + data_source = self.val_data_source + else: + data_source = self.test_data_source + if str(data_source).lower() == 'megadepth': + npz_names = [f'{n}.npz' for n in npz_names] + with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets', + total=len(npz_names), disable=int(self.rank) != 0)): + if data_source == 'ScanNet': + datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( + delayed(lambda x: _build_dataset( + ScanNetDataset, + data_root, + osp.join(npz_dir, x), + intrinsic_path, + mode=mode, + min_overlap_score=min_overlap_score, + augment_fn=augment_fn, + pose_dir=pose_dir))(name) + for name in npz_names) + elif data_source == 'MegaDepth': + # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers. + raise NotImplementedError() + datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( + delayed(lambda x: _build_dataset( + MegaDepthDataset, + data_root, + osp.join(npz_dir, x), + mode=mode, + min_overlap_score=min_overlap_score, + img_resize=self.mgdpt_img_resize, + df=self.mgdpt_df, + img_padding=self.mgdpt_img_pad, + depth_padding=self.mgdpt_depth_pad, + augment_fn=augment_fn, + coarse_scale=self.coarse_scale))(name) + for name in npz_names) + else: + raise ValueError(f'Unknown dataset: {data_source}') + return ConcatDataset(datasets) + + def train_dataloader(self): + """ Build training dataloader for ScanNet / MegaDepth. """ + assert self.data_sampler in ['scene_balance'] + logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).') + if self.data_sampler == 'scene_balance': + sampler = RandomConcatSampler(self.train_dataset, + self.n_samples_per_subset, + self.subset_replacement, + self.shuffle, self.repeat, self.seed) + else: + sampler = None + dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params) + return dataloader + + def val_dataloader(self): + """ Build validation dataloader for ScanNet / MegaDepth. 
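+ Returns a single DataLoader, or a list of DataLoaders when several validation npz lists are configured.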
""" + logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.') + if not isinstance(self.val_dataset, abc.Sequence): + sampler = DistributedSampler(self.val_dataset, shuffle=False) + return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params) + else: + dataloaders = [] + for dataset in self.val_dataset: + sampler = DistributedSampler(dataset, shuffle=False) + dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params)) + return dataloaders + + def test_dataloader(self, *args, **kwargs): + logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.') + sampler = DistributedSampler(self.test_dataset, shuffle=False) + return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params) + + +def _build_dataset(dataset: Dataset, *args, **kwargs): + return dataset(*args, **kwargs) diff --git a/imcui/third_party/XoFTR/src/lightning/data_pretrain.py b/imcui/third_party/XoFTR/src/lightning/data_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..11eba2cc3f7918563b3e3ed2f371eb72edf422d0 --- /dev/null +++ b/imcui/third_party/XoFTR/src/lightning/data_pretrain.py @@ -0,0 +1,125 @@ +from collections import abc +from loguru import logger + +import pytorch_lightning as pl +from torch import distributed as dist +from torch.utils.data import ( + DataLoader, + ConcatDataset, + DistributedSampler +) + +from src.datasets.pretrain_dataset import PretrainDataset + + +class PretrainDataModule(pl.LightningDataModule): + """ + For distributed training, each training process is assgined + only a part of the training scenes to reduce memory overhead. + """ + def __init__(self, args, config): + super().__init__() + + # 1. data config + # Train and Val should from the same data source + self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE + self.val_data_source = config.DATASET.VAL_DATA_SOURCE + # training and validating + self.train_data_root = config.DATASET.TRAIN_DATA_ROOT + self.val_data_root = config.DATASET.VAL_DATA_ROOT + + # 2. dataset config'] + + # dataset options + self.pretrain_img_resize = config.DATASET.PRETRAIN_IMG_RESIZE # 840 + self.pretrain_img_pad = config.DATASET.PRETRAIN_IMG_PAD # True + self.pretrain_df = config.DATASET.PRETRAIN_DF # 8 + self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0] # 0.125. for training xoftr. + self.frame_gap = config.DATASET.PRETRAIN_FRAME_GAP + + # 3.loader parameters + self.train_loader_params = { + 'batch_size': args.batch_size, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + self.val_loader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': args.num_workers, + 'pin_memory': getattr(args, 'pin_memory', True) + } + + def setup(self, stage=None): + """ + Setup train / val / test dataset. This method will be called by PL automatically. + Args: + stage (str): 'fit' in training phase, and 'test' in testing phase. 
+ """ + + assert stage in ['fit', 'test'], "stage must be either fit or test" + + try: + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + logger.info(f"[rank:{self.rank}] world_size: {self.world_size}") + except AssertionError as ae: + self.world_size = 1 + self.rank = 0 + logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)") + + if stage == 'fit': + self.train_dataset = self._setup_dataset( + self.train_data_root, + mode='train') + # setup multiple (optional) validation subsets + self.val_dataset = [] + self.val_dataset.append(self._setup_dataset( + self.val_data_root, + mode='val')) + logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!') + else: # stage == 'test + raise ValueError(f"only 'fit' implemented") + + def _setup_dataset(self, + data_root, + mode='train'): + """ Setup train / val / test set""" + + dataset_builder = self._build_concat_dataset + return dataset_builder(data_root, mode=mode) + + def _build_concat_dataset( + self, + data_root, + mode + ): + datasets = [] + + datasets.append( + PretrainDataset(data_root, + mode=mode, + img_resize=self.pretrain_img_resize, + df=self.pretrain_df, + img_padding=self.pretrain_img_pad, + coarse_scale=self.coarse_scale, + frame_gap=self.frame_gap)) + + return ConcatDataset(datasets) + + def train_dataloader(self): + """ Build training dataloader for KAIST dataset. """ + sampler = DistributedSampler(self.train_dataset, shuffle=True) + dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params) + return dataloader + + def val_dataloader(self): + """ Build validation dataloader KAIST dataset. """ + if not isinstance(self.val_dataset, abc.Sequence): + return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params) + else: + dataloaders = [] + for dataset in self.val_dataset: + sampler = DistributedSampler(dataset, shuffle=False) + dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params)) + return dataloaders diff --git a/imcui/third_party/XoFTR/src/lightning/lightning_xoftr.py b/imcui/third_party/XoFTR/src/lightning/lightning_xoftr.py new file mode 100644 index 0000000000000000000000000000000000000000..16e6758330a4ac33ea2bbe794d7141756abd61e3 --- /dev/null +++ b/imcui/third_party/XoFTR/src/lightning/lightning_xoftr.py @@ -0,0 +1,334 @@ + +from collections import defaultdict +import pprint +from loguru import logger +from pathlib import Path + +import torch +import numpy as np +import pytorch_lightning as pl +from matplotlib import pyplot as plt +plt.switch_backend('agg') + +from src.xoftr import XoFTR +from src.xoftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine +from src.losses.xoftr_loss import XoFTRLoss +from src.optimizers import build_optimizer, build_scheduler +from src.utils.metrics import ( + compute_symmetrical_epipolar_errors, + compute_pose_errors, + aggregate_metrics +) +from src.utils.plotting import make_matching_figures +from src.utils.comm import gather, all_gather +from src.utils.misc import lower_config, flattenList +from src.utils.profiler import PassThroughProfiler + + +class PL_XoFTR(pl.LightningModule): + def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None): + """ + TODO: + - use the new version of PL logging API. 
+ """ + super().__init__() + # Misc + self.config = config # full config + _config = lower_config(self.config) + self.xoftr_cfg = lower_config(_config['xoftr']) + self.profiler = profiler or PassThroughProfiler() + self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1) + + # Matcher: XoFTR + self.matcher = XoFTR(config=_config['xoftr']) + self.loss = XoFTRLoss(_config) + + # Pretrained weights + if pretrained_ckpt: + state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict'] + self.matcher.load_state_dict(state_dict, strict=False) + logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") + for name, param in self.matcher.named_parameters(): + if name in state_dict.keys(): + print("in ckpt: ", name) + else: + print("out ckpt: ", name) + + # Testing + self.dump_dir = dump_dir + + def configure_optimizers(self): + # FIXME: The scheduler did not work properly when `--resume_from_checkpoint` + optimizer = build_optimizer(self, self.config) + scheduler = build_scheduler(self.config, optimizer) + return [optimizer], [scheduler] + + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + # learning rate warm up + warmup_step = self.config.TRAINER.WARMUP_STEP + if self.trainer.global_step < warmup_step: + if self.config.TRAINER.WARMUP_TYPE == 'linear': + base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR + lr = base_lr + \ + (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \ + abs(self.config.TRAINER.TRUE_LR - base_lr) + for pg in optimizer.param_groups: + pg['lr'] = lr + elif self.config.TRAINER.WARMUP_TYPE == 'constant': + pass + else: + raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}') + + # update params + optimizer.step(closure=optimizer_closure) + optimizer.zero_grad() + + def _trainval_inference(self, batch): + with self.profiler.profile("Compute coarse supervision"): + compute_supervision_coarse(batch, self.config) + + with self.profiler.profile("XoFTR"): + self.matcher(batch) + + with self.profiler.profile("Compute fine supervision"): + compute_supervision_fine(batch, self.config) + + with self.profiler.profile("Compute losses"): + self.loss(batch) + + def _compute_metrics(self, batch): + with self.profiler.profile("Copmute metrics"): + compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match + compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair + + rel_pair_names = list(zip(*batch['pair_names'])) + bs = batch['image0'].size(0) + metrics = { + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)], + 'R_errs': batch['R_errs'], + 't_errs': batch['t_errs'], + 'inliers': batch['inliers']} + if self.config.DATASET.VAL_DATA_SOURCE == "VisTir": + metrics.update({'scene_id': batch['scene_id']}) + ret_dict = {'metrics': metrics} + return ret_dict, rel_pair_names + + def training_step(self, batch, batch_idx): + self._trainval_inference(batch) + + # logging + if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0: + # scalars + for k, v in batch['loss_scalars'].items(): + self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step) + if self.config.TRAINER.USE_WANDB: + self.logger[1].log_metrics({f'train/{k}': v}, self.global_step) 
+ + # figures + if self.config.TRAINER.ENABLE_PLOTTING: + compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match + figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE) + for k, v in figures.items(): + self.logger[0].experiment.add_figure(f'train_match/{k}', v, self.global_step) + + return {'loss': batch['loss']} + + def training_epoch_end(self, outputs): + avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + if self.trainer.global_rank == 0: + self.logger[0].experiment.add_scalar( + 'train/avg_loss_on_epoch', avg_loss, + global_step=self.current_epoch) + if self.config.TRAINER.USE_WANDB: + self.logger[1].log_metrics( + {'train/avg_loss_on_epoch': avg_loss}, + self.current_epoch) + + def validation_step(self, batch, batch_idx): + # no loss calculation for VisTir during val + if self.config.DATASET.VAL_DATA_SOURCE == "VisTir": + with self.profiler.profile("XoFTR"): + self.matcher(batch) + else: + self._trainval_inference(batch) + + ret_dict, _ = self._compute_metrics(batch) + + val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1) + figures = {self.config.TRAINER.PLOT_MODE: []} + if batch_idx % val_plot_interval == 0: + figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE, ret_dict=ret_dict) + if self.config.DATASET.VAL_DATA_SOURCE == "VisTir": + return { + **ret_dict, + 'figures': figures, + } + else: + return { + **ret_dict, + 'loss_scalars': batch['loss_scalars'], + 'figures': figures, + } + + def validation_epoch_end(self, outputs): + # handle multiple validation sets + multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs + multi_val_metrics = defaultdict(list) + + for valset_idx, outputs in enumerate(multi_outputs): + # since pl performs sanity_check at the very begining of the training + cur_epoch = self.trainer.current_epoch + if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check: + cur_epoch = -1 + + if self.config.DATASET.VAL_DATA_SOURCE == "VisTir": + metrics_per_scene = {} + for o in outputs: + if not o['metrics']['scene_id'][0] in metrics_per_scene.keys(): + metrics_per_scene[o['metrics']['scene_id'][0]] = [] + metrics_per_scene[o['metrics']['scene_id'][0]].append(o['metrics']) + + aucs_per_scene = {} + for scene_id in metrics_per_scene.keys(): + # 2. val metrics: dict of list, numpy + _metrics = metrics_per_scene[scene_id] + metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 + val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + aucs_per_scene[scene_id] = val_metrics + + # average the metrics of scenes + # since the number of images in each scene is different + val_metrics_4tb = {} + for thr in [5, 10, 20]: + temp = [] + for scene_id in metrics_per_scene.keys(): + temp.append(aucs_per_scene[scene_id][f'auc@{thr}']) + val_metrics_4tb[f'auc@{thr}'] = float(np.array(temp, dtype=float).mean()) + temp = [] + for scene_id in metrics_per_scene.keys(): + temp.append(aucs_per_scene[scene_id][f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}']) + val_metrics_4tb[f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}'] = float(np.array(temp, dtype=float).mean()) + else: + # 1. loss_scalars: dict of list, on cpu + _loss_scalars = [o['loss_scalars'] for o in outputs] + loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]} + + # 2. 
val metrics: dict of list, numpy + _metrics = [o['metrics'] for o in outputs] + metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 + val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + + for thr in [5, 10, 20]: + multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}']) + + # 3. figures + _figures = [o['figures'] for o in outputs] + figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]} + + # tensorboard records only on rank 0 + if self.trainer.global_rank == 0: + if self.config.DATASET.VAL_DATA_SOURCE != "VisTir": + for k, v in loss_scalars.items(): + mean_v = torch.stack(v).mean() + self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch) + + for k, v in val_metrics_4tb.items(): + self.logger[0].experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch) + if self.config.TRAINER.USE_WANDB: + self.logger[1].log_metrics({f"metrics_{valset_idx}/{k}": v}, cur_epoch) + + for k, v in figures.items(): + if self.trainer.global_rank == 0: + for plot_idx, fig in enumerate(v): + self.logger[0].experiment.add_figure( + f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True) + plt.close('all') + + for thr in [5, 10, 20]: + # log on all ranks for ModelCheckpoint callback to work properly + self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this + + def test_step(self, batch, batch_idx): + with self.profiler.profile("XoFTR"): + self.matcher(batch) + + ret_dict, rel_pair_names = self._compute_metrics(batch) + + with self.profiler.profile("dump_results"): + if self.dump_dir is not None: + # dump results for further analysis + keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf_f', 'epi_errs'} + pair_names = list(zip(*batch['pair_names'])) + bs = batch['image0'].shape[0] + dumps = [] + for b_id in range(bs): + item = {} + mask = batch['m_bids'] == b_id + item['pair_names'] = pair_names[b_id] + item['identifier'] = '#'.join(rel_pair_names[b_id]) + if self.config.DATASET.TEST_DATA_SOURCE == "VisTir": + item['scene_id'] = batch['scene_id'] + item['K0'] = batch['K0'][b_id].cpu().numpy() + item['K1'] = batch['K1'][b_id].cpu().numpy() + item['dist0'] = batch['dist0'][b_id].cpu().numpy() + item['dist1'] = batch['dist1'][b_id].cpu().numpy() + for key in keys_to_save: + item[key] = batch[key][mask].cpu().numpy() + for key in ['R_errs', 't_errs', 'inliers']: + item[key] = batch[key][b_id] + dumps.append(item) + ret_dict['dumps'] = dumps + + return ret_dict + + def test_epoch_end(self, outputs): + + if self.config.DATASET.TEST_DATA_SOURCE == "VisTir": + # metrics: dict of list, numpy + metrics_per_scene = {} + for o in outputs: + if not o['metrics']['scene_id'][0] in metrics_per_scene.keys(): + metrics_per_scene[o['metrics']['scene_id'][0]] = [] + metrics_per_scene[o['metrics']['scene_id'][0]].append(o['metrics']) + + aucs_per_scene = {} + for scene_id in metrics_per_scene.keys(): + # 2. 
val metrics: dict of list, numpy + _metrics = metrics_per_scene[scene_id] + metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0 + val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + aucs_per_scene[scene_id] = val_metrics + + # average the metrics of scenes + # since the number of images in each scene is different + val_metrics_4tb = {} + for thr in [5, 10, 20]: + temp = [] + for scene_id in metrics_per_scene.keys(): + temp.append(aucs_per_scene[scene_id][f'auc@{thr}']) + val_metrics_4tb[f'auc@{thr}'] = np.array(temp, dtype=float).mean() + else: + # metrics: dict of list, numpy + _metrics = [o['metrics'] for o in outputs] + metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} + + # [{key: [{...}, *#bs]}, *#batch] + if self.dump_dir is not None: + Path(self.dump_dir).mkdir(parents=True, exist_ok=True) + _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch] + dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch] + logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}') + + if self.trainer.global_rank == 0: + print(self.profiler.summary()) + val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) + logger.info('\n' + pprint.pformat(val_metrics_4tb)) + if self.dump_dir is not None: + np.save(Path(self.dump_dir) / 'XoFTR_pred_eval', dumps) \ No newline at end of file diff --git a/imcui/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py b/imcui/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc57d4205390cd1439edc2f38e3b219bad3db35 --- /dev/null +++ b/imcui/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py @@ -0,0 +1,171 @@ + +from loguru import logger + +import torch +import pytorch_lightning as pl +from matplotlib import pyplot as plt +plt.switch_backend('agg') + +from src.xoftr import XoFTR_Pretrain +from src.losses.xoftr_loss_pretrain import XoFTRLossPretrain +from src.optimizers import build_optimizer, build_scheduler +from src.utils.plotting import make_mae_figures +from src.utils.comm import all_gather +from src.utils.misc import lower_config, flattenList +from src.utils.profiler import PassThroughProfiler +from src.utils.pretrain_utils import generate_random_masks, get_target + + +class PL_XoFTR_Pretrain(pl.LightningModule): + def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None): + """ + TODO: + - use the new version of PL logging API. 
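+ Note: masking behaviour (patch size, mask ratio, MAE margins, validation seed) is read
+ from config.PRETRAIN below.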
+ """ + super().__init__() + # Misc + self.config = config # full config + + _config = lower_config(self.config) + self.xoftr_cfg = lower_config(_config['xoftr']) + self.profiler = profiler or PassThroughProfiler() + self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1) + + # generator to create the same masks for validation + self.val_seed = self.config.PRETRAIN.VAL_SEED + self.val_generator = torch.Generator(device="cuda").manual_seed(self.val_seed) + self.mae_margins = config.PRETRAIN.MAE_MARGINS + + # Matcher: XoFTR + self.matcher = XoFTR_Pretrain(config=_config['xoftr']) + self.loss = XoFTRLossPretrain(_config) + + # Pretrained weights + if pretrained_ckpt: + state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict'] + self.matcher.load_state_dict(state_dict, strict=False) + logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") + + # Testing + self.dump_dir = dump_dir + + def configure_optimizers(self): + # FIXME: The scheduler did not work properly when `--resume_from_checkpoint` + optimizer = build_optimizer(self, self.config) + scheduler = build_scheduler(self.config, optimizer) + return [optimizer], [scheduler] + + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + # learning rate warm up + warmup_step = self.config.TRAINER.WARMUP_STEP + if self.trainer.global_step < warmup_step: + if self.config.TRAINER.WARMUP_TYPE == 'linear': + base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR + lr = base_lr + \ + (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \ + abs(self.config.TRAINER.TRUE_LR - base_lr) + for pg in optimizer.param_groups: + pg['lr'] = lr + elif self.config.TRAINER.WARMUP_TYPE == 'constant': + pass + else: + raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}') + + # update params + optimizer.step(closure=optimizer_closure) + optimizer.zero_grad() + + def _trainval_inference(self, batch, generator=None): + generate_random_masks(batch, + patch_size=self.config.PRETRAIN.PATCH_SIZE, + mask_ratio=self.config.PRETRAIN.MASK_RATIO, + generator=generator, + margins=self.mae_margins) + + with self.profiler.profile("XoFTR"): + self.matcher(batch) + + with self.profiler.profile("Compute losses"): + # Create target pacthes to reconstruct + get_target(batch) + self.loss(batch) + + def training_step(self, batch, batch_idx): + self._trainval_inference(batch) + + # logging + if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0: + # scalars + for k, v in batch['loss_scalars'].items(): + self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step) + if self.config.TRAINER.USE_WANDB: + self.logger[1].log_metrics({f'train/{k}': v}, self.global_step) + + if self.config.TRAINER.ENABLE_PLOTTING: + figures = make_mae_figures(batch) + for i, figure in enumerate(figures): + self.logger[0].experiment.add_figure( + f'train_mae/node_{self.trainer.global_rank}-device_{self.device.index}-batch_{i}', + figure, self.global_step) + + return {'loss': batch['loss']} + + def training_epoch_end(self, outputs): + avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + if self.trainer.global_rank == 0: + self.logger[0].experiment.add_scalar( + 'train/avg_loss_on_epoch', avg_loss, + global_step=self.current_epoch) + if self.config.TRAINER.USE_WANDB: + self.logger[1].log_metrics( + {'train/avg_loss_on_epoch': avg_loss}, + self.current_epoch) + + 
def validation_step(self, batch, batch_idx): + self._trainval_inference(batch, self.val_generator) + + val_plot_interval = max(self.trainer.num_val_batches[0] // \ + (self.trainer.num_gpus * self.n_vals_plot), 1) + figures = [] + if batch_idx % val_plot_interval == 0: + figures = make_mae_figures(batch) + + return { + 'loss_scalars': batch['loss_scalars'], + 'figures': figures, + } + + def validation_epoch_end(self, outputs): + self.val_generator.manual_seed(self.val_seed) + # handle multiple validation sets + multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs + + for valset_idx, outputs in enumerate(multi_outputs): + # since pl performs sanity_check at the very begining of the training + cur_epoch = self.trainer.current_epoch + if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check: + cur_epoch = -1 + + # 1. loss_scalars: dict of list, on cpu + _loss_scalars = [o['loss_scalars'] for o in outputs] + loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]} + + _figures = [o['figures'] for o in outputs] + figures = [item for sublist in _figures for item in sublist] + + # tensorboard records only on rank 0 + if self.trainer.global_rank == 0: + for k, v in loss_scalars.items(): + mean_v = torch.stack(v).mean() + self.logger[0].experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch) + if self.config.TRAINER.USE_WANDB: + self.logger[1].log_metrics({f'val_{valset_idx}/avg_{k}': mean_v}, cur_epoch) + + for plot_idx, fig in enumerate(figures): + self.logger[0].experiment.add_figure( + f'val_mae_{valset_idx}/pair-{plot_idx}', fig, cur_epoch, close=True) + + plt.close('all') + diff --git a/imcui/third_party/XoFTR/src/losses/xoftr_loss.py b/imcui/third_party/XoFTR/src/losses/xoftr_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ac7ac9c1546a09fbd909ea11e3e597e03d92d56e --- /dev/null +++ b/imcui/third_party/XoFTR/src/losses/xoftr_loss.py @@ -0,0 +1,170 @@ +from loguru import logger + +import torch +import torch.nn as nn +from kornia.geometry.conversions import convert_points_to_homogeneous +from kornia.geometry.epipolar import numeric + +class XoFTRLoss(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config # config under the global namespace + self.loss_config = config['xoftr']['loss'] + self.pos_w = self.loss_config['pos_weight'] + self.neg_w = self.loss_config['neg_weight'] + + + def compute_fine_matching_loss(self, data): + """ Point-wise Focal Loss with 0 / 1 confidence as gt. + Args: + data (dict): { + conf_matrix_fine (torch.Tensor): (N, W_f^2, W_f^2) + conf_matrix_f_gt (torch.Tensor): (N, W_f^2, W_f^2) + } + """ + conf_matrix_fine = data['conf_matrix_fine'] + conf_matrix_f_gt = data['conf_matrix_f_gt'] + pos_mask, neg_mask = conf_matrix_f_gt > 0, conf_matrix_f_gt == 0 + pos_w, neg_w = self.pos_w, self.neg_w + + if not pos_mask.any(): # assign a wrong gt + pos_mask[0, 0, 0] = True + pos_w = 0. + if not neg_mask.any(): + neg_mask[0, 0, 0] = True + neg_w = 0. 
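+        # Clamp predicted confidences away from 0 and 1 for numerical stability, then apply
+        # focal-style re-weighting (focal_alpha / focal_gamma from the loss config) to the
+        # positive and negative entries of the fine-level confidence matrix.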
+ + conf_matrix_fine = torch.clamp(conf_matrix_fine, 1e-6, 1-1e-6) + alpha = self.loss_config['focal_alpha'] + gamma = self.loss_config['focal_gamma'] + + loss_pos = - alpha * torch.pow(1 - conf_matrix_fine[pos_mask], gamma) * (conf_matrix_fine[pos_mask]).log() + # loss_pos *= conf_matrix_f_gt[pos_mask] + loss_neg = - alpha * torch.pow(conf_matrix_fine[neg_mask], gamma) * (1 - conf_matrix_fine[neg_mask]).log() + + return pos_w * loss_pos.mean() + neg_w * loss_neg.mean() + + def _symmetric_epipolar_distance(self, pts0, pts1, E, K0, K1): + """Squared symmetric epipolar distance. + This can be seen as a biased estimation of the reprojection error. + Args: + pts0 (torch.Tensor): [N, 2] + E (torch.Tensor): [3, 3] + """ + pts0 = (pts0 - K0[:, [0, 1], [2, 2]]) / K0[:, [0, 1], [0, 1]] + pts1 = (pts1 - K1[:, [0, 1], [2, 2]]) / K1[:, [0, 1], [0, 1]] + pts0 = convert_points_to_homogeneous(pts0) + pts1 = convert_points_to_homogeneous(pts1) + + Ep0 = (pts0[:,None,:] @ E.transpose(-2,-1)).squeeze(1) # [N, 3] + p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] + Etp1 = (pts1[:,None,:] @ E).squeeze(1) # [N, 3] + + d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2 + 1e-9) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2 + 1e-9)) # N + return d + + def compute_sub_pixel_loss(self, data): + """ symmetric epipolar distance loss. + Args: + data (dict): { + m_bids (torch.Tensor): (N) + T_0to1 (torch.Tensor): (B, 4, 4) + mkpts0_f_train (torch.Tensor): (N, 2) + mkpts1_f_train (torch.Tensor): (N, 2) + } + """ + + Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) + E_mat = Tx @ data['T_0to1'][:, :3, :3] + + m_bids = data['m_bids'] + pts0 = data['mkpts0_f_train'] + pts1 = data['mkpts1_f_train'] + + sym_dist = self._symmetric_epipolar_distance(pts0, pts1, E_mat[m_bids], data['K0'][m_bids], data['K1'][m_bids]) + # filter matches with high epipolar error (only train approximately correct fine-level matches) + loss = sym_dist[sym_dist<1e-4] + if len(loss) == 0: + return torch.zeros(1, device=loss.device, requires_grad=False)[0] + return loss.mean() + + def compute_coarse_loss(self, data, weight=None): + """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt. + Args: + data (dict): { + conf_matrix_0_to_1 (torch.Tensor): (N, HW0, HW1) + conf_matrix_1_to_0 (torch.Tensor): (N, HW0, HW1) + conf_gt (torch.Tensor): (N, HW0, HW1) + } + weight (torch.Tensor): (N, HW0, HW1) + """ + + conf_matrix_0_to_1 = data["conf_matrix_0_to_1"] + conf_matrix_1_to_0 = data["conf_matrix_1_to_0"] + conf_gt = data["conf_matrix_gt"] + + pos_mask = conf_gt == 1 + c_pos_w = self.pos_w + # corner case: no gt coarse-level match at all + if not pos_mask.any(): # assign a wrong gt + pos_mask[0, 0, 0] = True + if weight is not None: + weight[0, 0, 0] = 0. + c_pos_w = 0. + + conf_matrix_0_to_1 = torch.clamp(conf_matrix_0_to_1, 1e-6, 1-1e-6) + conf_matrix_1_to_0 = torch.clamp(conf_matrix_1_to_0, 1e-6, 1-1e-6) + alpha = self.loss_config['focal_alpha'] + gamma = self.loss_config['focal_gamma'] + + loss_pos = - alpha * torch.pow(1 - conf_matrix_0_to_1[pos_mask], gamma) * (conf_matrix_0_to_1[pos_mask]).log() + loss_pos += - alpha * torch.pow(1 - conf_matrix_1_to_0[pos_mask], gamma) * (conf_matrix_1_to_0[pos_mask]).log() + if weight is not None: + loss_pos = loss_pos * weight[pos_mask] + + loss_c = (c_pos_w * loss_pos.mean()) + + return loss_c + + @torch.no_grad() + def compute_c_weight(self, data): + """ compute element-wise weights for computing coarse-level loss. 
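+        Padded regions (indicated by 'mask0'/'mask1') receive zero weight so they do not
+        contribute to the coarse loss; if no masks are present, None is returned.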
""" + if 'mask0' in data: + c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float() + else: + c_weight = None + return c_weight + + def forward(self, data): + """ + Update: + data (dict): update{ + 'loss': [1] the reduced loss across a batch, + 'loss_scalars' (dict): loss scalars for tensorboard_record + } + """ + loss_scalars = {} + # 0. compute element-wise loss weight + c_weight = self.compute_c_weight(data) + + # 1. coarse-level loss + loss_c = self.compute_coarse_loss(data, weight=c_weight) + loss_c *= self.loss_config['coarse_weight'] + loss = loss_c + loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()}) + + # 2. fine-level matching loss for windows + loss_f_match = self.compute_fine_matching_loss(data) + loss_f_match *= self.loss_config['fine_weight'] + loss = loss + loss_f_match + loss_scalars.update({"loss_f": loss_f_match.clone().detach().cpu()}) + + # 3. sub-pixel refinement loss + loss_sub = self.compute_sub_pixel_loss(data) + loss_sub *= self.loss_config['sub_weight'] + loss = loss + loss_sub + loss_scalars.update({"loss_sub": loss_sub.clone().detach().cpu()}) + + + loss_scalars.update({'loss': loss.clone().detach().cpu()}) + data.update({"loss": loss, "loss_scalars": loss_scalars}) diff --git a/imcui/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py b/imcui/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..d00e1781b0bd105414b59697418a33b27d178b4e --- /dev/null +++ b/imcui/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class XoFTRLossPretrain(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config # config under the global namespace + self.W_f = config["xoftr"]['fine_window_size'] + + def forward(self, data): + """ + Update: + data (dict): update{ + 'loss': [1] the reduced loss across a batch, + 'loss_scalars' (dict): loss scalars for tensorboard_record + } + """ + loss_scalars = {} + + pred0, pred1 = data["pred0"], data["pred1"] + target0, target1 = data["target0"], data["target1"] + target0 = target0[[data['b_ids'], data['i_ids']]] + target1 = target1[[data['b_ids'], data['j_ids']]] + + # get correct indices + pred0 = pred0[data["ids_image0"]] + pred1 = pred1[data["ids_image1"]] + target0 = target0[data["ids_image0"]] + target1 = target1[data["ids_image1"]] + + loss0 = (pred0 - target0)**2 + loss1 = (pred1 - target1)**2 + loss = loss0.mean() + loss1.mean() + + loss_scalars.update({'loss': loss.clone().detach().cpu()}) + data.update({"loss": loss, "loss_scalars": loss_scalars}) diff --git a/imcui/third_party/XoFTR/src/optimizers/__init__.py b/imcui/third_party/XoFTR/src/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1db2285352586c250912bdd2c4ae5029620ab5f --- /dev/null +++ b/imcui/third_party/XoFTR/src/optimizers/__init__.py @@ -0,0 +1,42 @@ +import torch +from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR + + +def build_optimizer(model, config): + name = config.TRAINER.OPTIMIZER + lr = config.TRAINER.TRUE_LR + + if name == "adam": + return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY) + elif name == "adamw": + return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY) + else: + raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") + + +def build_scheduler(config, 
optimizer): + """ + Returns: + scheduler (dict):{ + 'scheduler': lr_scheduler, + 'interval': 'step', # or 'epoch' + 'monitor': 'val_f1', (optional) + 'frequency': x, (optional) + } + """ + scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} + name = config.TRAINER.SCHEDULER + + if name == 'MultiStepLR': + scheduler.update( + {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) + elif name == 'CosineAnnealing': + scheduler.update( + {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) + elif name == 'ExponentialLR': + scheduler.update( + {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) + else: + raise NotImplementedError() + + return scheduler diff --git a/imcui/third_party/XoFTR/src/utils/augment.py b/imcui/third_party/XoFTR/src/utils/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3ef976b4ad93f5eb3be89707881d1780445313 --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/augment.py @@ -0,0 +1,113 @@ +import albumentations as A +import numpy as np +import cv2 + +class DarkAug(object): + """ + Extreme dark augmentation aiming at Aachen Day-Night + """ + + def __init__(self): + self.augmentor = A.Compose([ + A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), + A.Blur(p=0.1, blur_limit=(3, 9)), + A.MotionBlur(p=0.2, blur_limit=(3, 25)), + A.RandomGamma(p=0.1, gamma_limit=(15, 65)), + A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) + ], p=0.75) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + + +class MobileAug(object): + """ + Random augmentations aiming at images of mobile/handhold devices. + """ + + def __init__(self): + self.augmentor = A.Compose([ + A.MotionBlur(p=0.25), + A.ColorJitter(p=0.5), + A.RandomRain(p=0.1), # random occlusion + A.RandomSunFlare(p=0.1), + A.JpegCompression(p=0.25), + A.ISONoise(p=0.25) + ], p=1.0) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + +class RGBThermalAug(object): + """ + Pseudo-thermal image augmentation + """ + + def __init__(self): + self.blur = A.Blur(p=0.7, blur_limit=(2, 4)) + self.hsv = A.HueSaturationValue(p=0.9, val_shift_limit=(-30, +30), hue_shift_limit=(-90,+90), sat_shift_limit=(-30,+30)) + + # Switch images to apply augmentation + self.random_switch = True + + # parameters for the cosine transform + self.w_0 = np.pi * 2 / 3 + self.w_r = np.pi / 2 + self.theta_r = np.pi / 2 + + def augment_pseudo_thermal(self, image): + + # HSV augmentation + image = self.hsv(image=image)["image"] + + # Random blur + image = self.blur(image=image)["image"] + + # Convert the image to the gray scale + image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + + # Normalize the image between (-0.5, 0.5) + image = image / 255 - 0.5 # 8 bit color + + # Random phase and freq for the cosine transform + phase = np.pi / 2 + np.random.randn(1) * self.theta_r + w = self.w_0 + np.abs(np.random.randn(1)) * self.w_r + + # Cosine transform + image = np.cos(image * w + phase) + + # Min-max normalization for the transformed image + image = (image - image.min()) / (image.max() - image.min()) * 255 + + # 3 channel gray + image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_GRAY2RGB) + + return image + + def __call__(self, x, image_num): + if image_num==0: + # augmentation for RGB image can be added here + return x + elif image_num==1: + # pseudo-thermal augmentation + return self.augment_pseudo_thermal(x) + else: + raise ValueError(f'Invalid image number: 
{image_num}') + + +def build_augmentor(method=None, **kwargs): + + if method == 'dark': + return DarkAug() + elif method == 'mobile': + return MobileAug() + elif method == "rgb_thermal": + return RGBThermalAug() + elif method is None: + return None + else: + raise ValueError(f'Invalid augmentation method: {method}') + + +if __name__ == '__main__': + augmentor = build_augmentor('FDA') \ No newline at end of file diff --git a/imcui/third_party/XoFTR/src/utils/comm.py b/imcui/third_party/XoFTR/src/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167 --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/comm.py @@ -0,0 +1,265 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +[Copied from detectron2] +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import functools +import logging +import numpy as np +import pickle +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. 
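+    If the default backend is NCCL, a separate gloo group is created so that the pickled
+    payloads used by all_gather / gather can be exchanged as CPU byte tensors.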
+ """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" + local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. 
+ """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group=group) == 1: + return [data] + rank = dist.get_rank(group=group) + + tensor = _serialize_to_tensor(data, group) + size_list, tensor = _pad_to_largest_tensor(tensor, group) + + # receiving Tensor from all ranks + if rank == dst: + max_size = max(size_list) + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.gather(tensor, tensor_list, dst=dst, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + return data_list + else: + dist.gather(tensor, [], dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2 ** 31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + + Returns: + a dict with the same keys as input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/imcui/third_party/XoFTR/src/utils/data_io.py b/imcui/third_party/XoFTR/src/utils/data_io.py new file mode 100644 index 0000000000000000000000000000000000000000..c63628c9c7f9fa64a6c8817c0ed9f514f29aa9cb --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/data_io.py @@ -0,0 +1,144 @@ +import torch +from torch import nn +import numpy as np +import cv2 +# import torchvision.transforms as transforms +import torch.nn.functional as F +from yacs.config import CfgNode as CN + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +def upper_config(dict_cfg): + if not isinstance(dict_cfg, dict): + return dict_cfg + return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} + + +class DataIOWrapper(nn.Module): + """ + Pre-propcess data from different sources + """ + + def __init__(self, model, config, ckpt=None): + super().__init__() + + self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu') + torch.set_grad_enabled(False) + self.model = model + self.config = config + self.img0_size = config['img0_resize'] + self.img1_size = config['img1_resize'] + self.df = config['df'] + self.padding = config['padding'] + self.coarse_scale = config['coarse_scale'] + + if ckpt: + ckpt_dict = torch.load(ckpt) + self.model.load_state_dict(ckpt_dict['state_dict']) + self.model = 
self.model.eval().to(self.device) + + def preprocess_image(self, img, device, resize=None, df=None, padding=None, cam_K=None, dist=None, gray_scale=True): + # xoftr takes grayscale input images + if gray_scale and len(img.shape) == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + h, w = img.shape[:2] + new_K = None + img_undistorted = None + if cam_K is not None and dist is not None: + new_K, roi = cv2.getOptimalNewCameraMatrix(cam_K, dist, (w,h), 0, (w,h)) + img = cv2.undistort(img, cam_K, dist, None, new_K) + img_undistorted = img.copy() + + if resize is not None: + scale = resize / max(h, w) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + else: + w_new, h_new = w, h + + if df is not None: + w_new, h_new = map(lambda x: int(x // df * df), [w_new, h_new]) + + img = cv2.resize(img, (w_new, h_new)) + scale = np.array([w/w_new, h/h_new], dtype=np.float) + if padding: # padding + pad_to = max(h_new, w_new) + img, mask = self.pad_bottom_right(img, pad_to, ret_mask=True) + mask = torch.from_numpy(mask).to(device) + else: + mask = None + # img = transforms.functional.to_tensor(img).unsqueeze(0).to(device) + if len(img.shape) == 2: # grayscale image + img = torch.from_numpy(img)[None][None].cuda().float() / 255.0 + else: # Color image + img = torch.from_numpy(img).permute(2, 0, 1)[None].float() / 255.0 + return img, scale, mask, new_K, img_undistorted + + def from_cv_imgs(self, img0, img1, K0=None, K1=None, dist0=None, dist1=None): + img0_tensor, scale0, mask0, new_K0, img0_undistorted = self.preprocess_image( + img0, self.device, resize=self.img0_size, df=self.df, padding=self.padding, cam_K=K0, dist=dist0) + img1_tensor, scale1, mask1, new_K1, img1_undistorted = self.preprocess_image( + img1, self.device, resize=self.img1_size, df=self.df, padding=self.padding, cam_K=K1, dist=dist1) + mkpts0, mkpts1, mconf = self.match_images(img0_tensor, img1_tensor, mask0, mask1) + mkpts0 = mkpts0 * scale0 + mkpts1 = mkpts1 * scale1 + matches = np.concatenate([mkpts0, mkpts1], axis=1) + data = {'matches':matches, + 'mkpts0':mkpts0, + 'mkpts1':mkpts1, + 'mconf':mconf, + 'img0':img0, + 'img1':img1 + } + if K0 is not None and dist0 is not None: + data.update({'new_K0':new_K0, 'img0_undistorted':img0_undistorted}) + if K1 is not None and dist1 is not None: + data.update({'new_K1':new_K1, 'img1_undistorted':img1_undistorted}) + return data + + def from_paths(self, img0_pth, img1_pth, K0=None, K1=None, dist0=None, dist1=None, read_color=False): + + imread_flag = cv2.IMREAD_COLOR if read_color else cv2.IMREAD_GRAYSCALE + + img0 = cv2.imread(img0_pth, imread_flag) + img1 = cv2.imread(img1_pth, imread_flag) + return self.from_cv_imgs(img0, img1, K0=K0, K1=K1, dist0=dist0, dist1=dist1) + + def match_images(self, image0, image1, mask0, mask1): + batch = {'image0': image0, 'image1': image1} + if mask0 is not None: # img_padding is True + if self.coarse_scale: + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.coarse_scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + batch.update({'mask0': ts_mask_0.unsqueeze(0), 'mask1': ts_mask_1.unsqueeze(0)}) + self.model(batch) + mkpts0 = batch['mkpts0_f'].cpu().numpy() + mkpts1 = batch['mkpts1_f'].cpu().numpy() + mconf = batch['mconf_f'].cpu().numpy() + return mkpts0, mkpts1, mconf + + def pad_bottom_right(self, inp, pad_size, ret_mask=False): + assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + mask = None + if inp.ndim == 2: + 
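+            # 2-D (grayscale) input: zero-pad to a pad_size x pad_size square; the optional
+            # boolean mask marks the valid (un-padded) region.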
padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + if ret_mask: + mask = np.zeros((pad_size, pad_size), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + elif inp.ndim == 3: + padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) + padded[:, :inp.shape[1], :inp.shape[2]] = inp + if ret_mask: + mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) + mask[:, :inp.shape[1], :inp.shape[2]] = True + else: + raise NotImplementedError() + return padded, mask + diff --git a/imcui/third_party/XoFTR/src/utils/dataloader.py b/imcui/third_party/XoFTR/src/utils/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..6da37b880a290c2bb3ebb028d0c8dab592acc5c1 --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/dataloader.py @@ -0,0 +1,23 @@ +import numpy as np + + +# --- PL-DATAMODULE --- + +def get_local_split(items: list, world_size: int, rank: int, seed: int): + """ The local rank only loads a split of the dataset. """ + n_items = len(items) + items_permute = np.random.RandomState(seed).permutation(items) + if n_items % world_size == 0: + padded_items = items_permute + else: + padding = np.random.RandomState(seed).choice( + items, + world_size - (n_items % world_size), + replace=True) + padded_items = np.concatenate([items_permute, padding]) + assert len(padded_items) % world_size == 0, \ + f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' + n_per_rank = len(padded_items) // world_size + local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] + + return local_items diff --git a/imcui/third_party/XoFTR/src/utils/dataset.py b/imcui/third_party/XoFTR/src/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e20de10040a2888fd08279b0c54bd8a4c39665bb --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/dataset.py @@ -0,0 +1,279 @@ +import io +from loguru import logger + +import cv2 +import numpy as np +import h5py +import torch +from numpy.linalg import inv + + +try: + # for internel use only + from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT +except Exception: + MEGADEPTH_CLIENT = SCANNET_CLIENT = None + +# --- DATA IO --- + +def load_array_from_s3( + path, client, cv_type, + use_h5py=False, +): + byte_str = client.Get(path) + try: + if not use_h5py: + raw_array = np.fromstring(byte_str, np.uint8) + data = cv2.imdecode(raw_array, cv_type) + else: + f = io.BytesIO(byte_str) + data = np.array(h5py.File(f, 'r')['/depth']) + except Exception as ex: + print(f"==> Data loading failure: {path}") + raise ex + + assert data is not None + return data + + +def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): + cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ + else cv2.IMREAD_COLOR + if str(path).startswith('s3://'): + image = load_array_from_s3(str(path), client, cv_type) + else: + image = cv2.imread(str(path), cv_type) + + if augment_fn is not None: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = augment_fn(image) + image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + return image # (h, w) + + +def get_resized_wh(w, h, resize=None): + if resize is not None: # resize the longer edge + scale = resize / max(h, w) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + else: + w_new, h_new = w, h + return w_new, h_new + + +def get_divisible_wh(w, h, df=None): + if df is not None: + w_new, h_new = map(lambda x: int(x // df * df), [w, 
h]) + else: + w_new, h_new = w, h + return w_new, h_new + + +def pad_bottom_right(inp, pad_size, ret_mask=False): + assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + mask = None + if inp.ndim == 2: + padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + if ret_mask: + mask = np.zeros((pad_size, pad_size), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + elif inp.ndim == 3: + padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) + padded[:, :inp.shape[1], :inp.shape[2]] = inp + if ret_mask: + mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) + mask[:, :inp.shape[1], :inp.shape[2]] = True + else: + raise NotImplementedError() + return padded, mask + + +# --- MEGADEPTH --- + +def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): + """ + Args: + resize (int, optional): the longer edge of resized images. None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) + + # resize image + w, h = image.shape[1], image.shape[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + + if padding: # padding + pad_to = max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + else: + mask = None + + image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + mask = torch.from_numpy(mask) + + return image, mask, scale + + +def read_megadepth_depth(path, pad_to=None): + if str(path).startswith('s3://'): + depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) + else: + depth = np.array(h5py.File(path, 'r')['depth']) + if pad_to is not None: + depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +# --- ScanNet --- + +def read_scannet_gray(path, resize=(640, 480), augment_fn=None): + """ + Args: + resize (tuple): align image to depthmap, in (w, h). + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read and resize image + image = imread_gray(path, augment_fn) + image = cv2.resize(image, resize) + + # (h, w) -> (1, h, w) and normalized + image = torch.from_numpy(image).float()[None] / 255 + return image + + +def read_scannet_depth(path): + if str(path).startswith('s3://'): + depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) + else: + depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) + depth = depth / 1000 + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +def read_scannet_pose(path): + """ Read ScanNet's Camera2World pose and transform it to World2Camera. + + Returns: + pose_w2c (np.ndarray): (4, 4) + """ + cam2world = np.loadtxt(path, delimiter=' ') + world2cam = inv(cam2world) + return world2cam + + +def read_scannet_intrinsic(path): + """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 
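+    The loaded matrix is sliced with [:-1, :-1], i.e. the last row and column are dropped.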
+ """ + intrinsic = np.loadtxt(path, delimiter=' ') + return intrinsic[:-1, :-1] + + +# --- VisTir --- + +def read_vistir_gray(path, cam_K, dist, resize=None, df=None, padding=False, augment_fn=None): + """ + Args: + cam_K (3, 3): camera matrix + dist (8): distortion coefficients + resize (int, optional): the longer edge of resized images. None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + image = imread_gray(path, augment_fn, client=None) + + h, w = image.shape[:2] + # update camera matrix + new_K, roi = cv2.getOptimalNewCameraMatrix(cam_K, dist, (w,h), 0, (w,h)) + # undistort image + image = cv2.undistort(image, cam_K, dist, None, new_K) + + # resize image + w, h = image.shape[1], image.shape[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + + if padding: # padding + pad_to = max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + mask = torch.from_numpy(mask) + else: + mask = None + + image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + + return image, mask, scale, new_K + +# --- PRETRAIN --- + +def read_pretrain_gray(path, resize=None, df=None, padding=False, augment_fn=None): + """ + Args: + resize (int, optional): the longer edge of resized images. None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) gray scale image + image_norm (torch.tensor): (1, h, w) normalized image + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + image_mean (torch.tensor): (1, 1, 1, 1) + image_std (torch.tensor): (1, 1, 1, 1) + """ + # read image + image = imread_gray(path, augment_fn, client=None) + + # resize image + w, h = image.shape[1], image.shape[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + + image = image.astype(np.float32) / 255 + + image_mean = image.mean() + image_std = image.std() + image_norm = (image - image_mean) / (image_std + 1e-6) + + if padding: # padding + pad_to = max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + image_norm, _ = pad_bottom_right(image_norm, pad_to, ret_mask=False) + mask = torch.from_numpy(mask) + else: + mask = None + + image_mean = torch.as_tensor(image_mean).float()[None,None,None] + image_std = torch.as_tensor(image_std).float()[None,None,None] + + image = torch.from_numpy(image).float()[None] + image_norm = torch.from_numpy(image_norm).float()[None] # (h, w) -> (1, h, w) and normalized + + return image, image_norm, mask, scale, image_mean, image_std + diff --git a/imcui/third_party/XoFTR/src/utils/metrics.py b/imcui/third_party/XoFTR/src/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a708e17a828574686d81d0d591dfa0c7597fa387 --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/metrics.py @@ -0,0 +1,211 @@ +import torch +import cv2 +import numpy as np 
+from collections import OrderedDict +from loguru import logger +from kornia.geometry.epipolar import numeric +from kornia.geometry.conversions import convert_points_to_homogeneous + + +# --- METRICS --- + +def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): + # angle error between 2 vectors + t_gt = T_0to1[:3, 3] + n = np.linalg.norm(t) * np.linalg.norm(t_gt) + t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) + t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity + if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging + t_err = 0 + + # angle error between 2 rotation matrices + R_gt = T_0to1[:3, :3] + cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 + cos = np.clip(cos, -1., 1.) # handle numercial errors + R_err = np.rad2deg(np.abs(np.arccos(cos))) + + return t_err, R_err + + +def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): + """Squared symmetric epipolar distance. + This can be seen as a biased estimation of the reprojection error. + Args: + pts0 (torch.Tensor): [N, 2] + E (torch.Tensor): [3, 3] + """ + pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + pts0 = convert_points_to_homogeneous(pts0) + pts1 = convert_points_to_homogeneous(pts1) + + Ep0 = pts0 @ E.T # [N, 3] + p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] + Etp1 = pts1 @ E # [N, 3] + + d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N + return d + +def symmetric_epipolar_distance_numpy(pts0, pts1, E, K0, K1): + """Squared symmetric epipolar distance. + This can be seen as a biased estimation of the reprojection error. + Args: + pts0 (numpy.array): [N, 2] + E (numpy.array): [3, 3] + """ + pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + pts0 = np.hstack((pts0, np.ones((pts0.shape[0], 1)))) + pts1 = np.hstack((pts1, np.ones((pts1.shape[0], 1)))) + + Ep0 = pts0 @ E.T # [N, 3] + p1Ep0 = np.sum(pts1 * Ep0, -1) # [N,] + Etp1 = pts1 @ E # [N, 3] + + d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N + return d + +def compute_symmetrical_epipolar_errors(data): + """ + Update: + data (dict):{"epi_errs": [M]} + """ + Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) + E_mat = Tx @ data['T_0to1'][:, :3, :3] + + m_bids = data['m_bids'] + pts0 = data['mkpts0_f'] + pts1 = data['mkpts1_f'] + + epi_errs = [] + for bs in range(Tx.size(0)): + mask = m_bids == bs + epi_errs.append( + symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs])) + epi_errs = torch.cat(epi_errs, dim=0) + + data.update({'epi_errs': epi_errs}) + + +def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): + if len(kpts0) < 5: + return None + # normalize keypoints + kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + + # normalize ransac threshold + ransac_thr = thresh / np.mean([K0[0, 0], K0[1, 1], K1[0, 0], K1[1, 1]]) + + # compute pose with cv2 + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC) + if E is None: + print("\nE is None while trying to recover pose.\n") + return None + + # recover pose from E + best_num_inliers = 0 + ret = None + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, 
mask=mask) + if n > best_num_inliers: + ret = (R, t[:, 0], mask.ravel() > 0) + best_num_inliers = n + + return ret + + +def compute_pose_errors(data, config): + """ + Update: + data (dict):{ + "R_errs" List[float]: [N] + "t_errs" List[float]: [N] + "inliers" List[np.ndarray]: [N] + } + """ + pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5 + conf = config.TRAINER.RANSAC_CONF # 0.99999 + data.update({'R_errs': [], 't_errs': [], 'inliers': []}) + + m_bids = data['m_bids'].cpu().numpy() + pts0 = data['mkpts0_f'].cpu().numpy() + pts1 = data['mkpts1_f'].cpu().numpy() + K0 = data['K0'].cpu().numpy() + K1 = data['K1'].cpu().numpy() + T_0to1 = data['T_0to1'].cpu().numpy() + + for bs in range(K0.shape[0]): + mask = m_bids == bs + ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf) + + if ret is None: + data['R_errs'].append(np.inf) + data['t_errs'].append(np.inf) + data['inliers'].append(np.array([]).astype(np.bool)) + else: + R, t, inliers = ret + t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) + data['R_errs'].append(R_err) + data['t_errs'].append(t_err) + data['inliers'].append(inliers) + + +# --- METRIC AGGREGATION --- + +def error_auc(errors, thresholds): + """ + Args: + errors (list): [N,] + thresholds (list) + """ + errors = [0] + sorted(list(errors)) + recall = list(np.linspace(0, 1, len(errors))) + + aucs = [] + thresholds = [5, 10, 20] + for thr in thresholds: + last_index = np.searchsorted(errors, thr) + y = recall[:last_index] + [recall[last_index-1]] + x = errors[:last_index] + [thr] + aucs.append(np.trapz(y, x) / thr) + + return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} + + +def epidist_prec(errors, thresholds, ret_dict=False): + precs = [] + for thr in thresholds: + prec_ = [] + for errs in errors: + correct_mask = errs < thr + prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) + precs.append(np.mean(prec_) if len(prec_) > 0 else 0) + if ret_dict: + return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} + else: + return precs + + +def aggregate_metrics(metrics, epi_err_thr=5e-4): + """ Aggregate metrics for the whole dataset: + (This method should be called once per dataset) + 1. AUC of the pose error (angular) at the threshold [5, 10, 20] + 2. 
Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) + """ + # filter duplicates + unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) + unq_ids = list(unq_ids.values()) + logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') + + # pose auc + angular_thresholds = [5, 10, 20] + pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] + aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) + + # matching precision + dist_thresholds = [epi_err_thr] + precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) + + return {**aucs, **precs} diff --git a/imcui/third_party/XoFTR/src/utils/misc.py b/imcui/third_party/XoFTR/src/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8db04666519753ea2df43903ab6c47ec00a9a1 --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/misc.py @@ -0,0 +1,101 @@ +import os +import contextlib +import joblib +from typing import Union +from loguru import _Logger, logger +from itertools import chain + +import torch +from yacs.config import CfgNode as CN +from pytorch_lightning.utilities import rank_zero_only + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +def upper_config(dict_cfg): + if not isinstance(dict_cfg, dict): + return dict_cfg + return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} + + +def log_on(condition, message, level): + if condition: + assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] + logger.log(level, message) + + +def get_rank_zero_only_logger(logger: _Logger): + if rank_zero_only.rank == 0: + return logger + else: + for _level in logger._core.levels.keys(): + level = _level.lower() + setattr(logger, level, + lambda x: None) + logger._log = lambda x: None + return logger + + +def setup_gpus(gpus: Union[str, int]) -> int: + """ A temporary fix for pytorch-lighting 1.3.x """ + gpus = str(gpus) + gpu_ids = [] + + if ',' not in gpus: + n_gpus = int(gpus) + return n_gpus if n_gpus != -1 else torch.cuda.device_count() + else: + gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] + + # setup environment variables + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + if visible_devices is None: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') + else: + logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') + return len(gpu_ids) + + +def flattenList(x): + return list(chain(*x)) + + +@contextlib.contextmanager +def tqdm_joblib(tqdm_object): + """Context manager to patch joblib to report into tqdm progress bar given as argument + + Usage: + with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: + Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) + + When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) + ret_vals = Parallel(n_jobs=args.world_size)( + delayed(lambda x: _compute_cov_score(pid, *x))(param) + for param in tqdm(combinations(image_ids, 2), + desc=f'Computing cov_score of [{pid}]', + 
total=len(image_ids)*(len(image_ids)-1)/2)) + Src: https://stackoverflow.com/a/58936697 + """ + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() + diff --git a/imcui/third_party/XoFTR/src/utils/plotting.py b/imcui/third_party/XoFTR/src/utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..09c312223394330125b2b5572b8fae9c7a12f599 --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/plotting.py @@ -0,0 +1,227 @@ +import bisect +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +plt.switch_backend('agg') +from einops.einops import rearrange +import torch.nn.functional as F + + +def _compute_conf_thresh(data): + dataset_name = data['dataset_name'][0].lower() + if dataset_name == 'scannet': + thr = 5e-4 + elif dataset_name == 'megadepth': + thr = 1e-4 + elif dataset_name == 'vistir': + thr = 5e-4 + else: + raise ValueError(f'Unknown dataset: {dataset_name}') + return thr + + +# --- VISUALIZATION --- # + +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0, cmap='gray') + axes[1].imshow(img1, cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=1) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + + +def _make_evaluation_figure(data, b_id, alpha='dynamic', ret_dict=None): + b_mask = data['m_bids'] == b_id + conf_thr = _compute_conf_thresh(data) + + img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() + kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() + + # for megadepth, we 
visualize matches on the resized image + if 'scale0' in data: + kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]] + kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]] + + epi_errs = data['epi_errs'][b_mask].cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}'] + if ret_dict is not None: + text += [f"t_err: {ret_dict['metrics']['t_errs'][b_id]:.2f}", + f"R_err: {ret_dict['metrics']['R_errs'][b_id]:.2f}"] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text) + return figure + +def _make_confidence_figure(data, b_id): + # TODO: Implement confidence figure + raise NotImplementedError() + + +def make_matching_figures(data, config, mode='evaluation', ret_dict=None): + """ Make matching figures for a batch. + + Args: + data (Dict): a batch updated by PL_XoFTR. + config (Dict): matcher config + Returns: + figures (Dict[str, List[plt.figure]] + """ + assert mode in ['evaluation', 'confidence'] # 'confidence' + figures = {mode: []} + for b_id in range(data['image0'].size(0)): + if mode == 'evaluation': + fig = _make_evaluation_figure( + data, b_id, + alpha=config.TRAINER.PLOT_MATCHES_ALPHA, ret_dict=ret_dict) + elif mode == 'confidence': + fig = _make_confidence_figure(data, b_id) + else: + raise ValueError(f'Unknown plot mode: {mode}') + figures[mode].append(fig) + return figures + +def make_mae_figures(data): + """ Make mae figures for a batch. + + Args: + data (Dict): a batch updated by PL_XoFTR_Pretrain. 
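+            Expected keys include 'pred0/1', 'target0/1', the masked-patch index tensors
+            ('ids_image0/1', 'b_ids', 'i_ids', 'j_ids') and the (masked) input images.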
+ Returns: + figures (List[plt.figure]) + """ + + scale = data['hw0_i'][0] // data['hw0_f'][0] + W_f = data["W_f"] + + pred0, pred1 = data["pred0"], data["pred1"] + target0, target1 = data["target0"], data["target1"] + + # replace masked regions with predictions + target0[data['b_ids'][data["ids_image0"]], data['i_ids'][data["ids_image0"]]] = pred0[data["ids_image0"]] + target1[data['b_ids'][data["ids_image1"]], data['j_ids'][data["ids_image1"]]] = pred1[data["ids_image1"]] + + # remove excess parts, since the 10x10 windows have overlaping regions + target0 = rearrange(target0, 'n l (h w) (p q c) -> n c (h p) (w q) l', h=W_f, w=W_f, p=scale, q=scale, c=1) + target1 = rearrange(target1, 'n l (h w) (p q c) -> n c (h p) (w q) l', h=W_f, w=W_f, p=scale, q=scale, c=1) + # target0[:,:,-scale:,:] = 0.0 + # target0[:,:,:,-scale:] = 0.0 + # target1[:,:,-scale:,:] = 0.0 + # target1[:,:,:,-scale:] = 0.0 + gap = scale //2 + target0[:,:,-gap:,:] = 0.0 + target0[:,:,:,-gap:] = 0.0 + target1[:,:,-gap:,:] = 0.0 + target1[:,:,:,-gap:] = 0.0 + target0[:,:,:gap,:] = 0.0 + target0[:,:,:,:gap] = 0.0 + target1[:,:,:gap,:] = 0.0 + target1[:,:,:,:gap] = 0.0 + target0 = rearrange(target0, 'n c (h p) (w q) l -> n (c h p w q) l', h=W_f, w=W_f, p=scale, q=scale, c=1) + target1 = rearrange(target1, 'n c (h p) (w q) l -> n (c h p w q) l', h=W_f, w=W_f, p=scale, q=scale, c=1) + + # windows to image + kernel_size = [int(W_f*scale), int(W_f*scale)] + padding = kernel_size[0]//2 -1 if kernel_size[0] % 2 == 0 else kernel_size[0]//2 + stride = data['hw0_i'][0] // data['hw0_c'][0] + target0 = F.fold(target0, output_size=data["image0"].shape[2:], kernel_size=kernel_size, stride=stride, padding=padding) + target1 = F.fold(target1, output_size=data["image1"].shape[2:], kernel_size=kernel_size, stride=stride, padding=padding) + + # add mean and std of original image for visualization + if ("image0_norm" in data) and ("image1_norm" in data): + target0 = target0 * data["image0_std"] + data["image0_mean"] + target1 = target1 * data["image1_std"] + data["image1_mean"] + masked_image0 = data["masked_image0"] * data["image0_std"].to("cpu") + data["image0_mean"].to("cpu") + masked_image1 = data["masked_image1"] * data["image1_std"].to("cpu") + data["image1_mean"].to("cpu") + else: + masked_image0 = data["masked_image0"] + masked_image1 = data["masked_image1"] + + figures = [] + # Create a list of these tensors + image_groups = [[data["image0"], masked_image0, target0], + [data["image1"], masked_image1, target1]] + + # Iterate through the batches + for batch_idx in range(image_groups[0][0].shape[0]): # Assuming batch dimension is the first dimension + fig, axs = plt.subplots(2, 3, figsize=(9, 6)) + for i, image_tensors in enumerate(image_groups): + for j, img_tensor in enumerate(image_tensors): + img = img_tensor[batch_idx, 0, :, :].detach().cpu().numpy() # Get the image data as a NumPy array + axs[i,j].imshow(img, cmap='gray', vmin=0, vmax=1) # Display the image in a subplot with correct colormap + axs[i,j].axis('off') # Turn off axis labels + fig.tight_layout() + figures.append(fig) + return figures + +def dynamic_alpha(n_matches, + milestones=[0, 300, 1000, 2000], + alphas=[1.0, 0.8, 0.4, 0.2]): + if n_matches == 0: + return 1.0 + ranges = list(zip(alphas, alphas[1:] + [None])) + loc = bisect.bisect_right(milestones, n_matches) - 1 + _range = ranges[loc] + if _range[1] is None: + return _range[0] + return _range[1] + (milestones[loc + 1] - n_matches) / ( + milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) + + +def 
error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) diff --git a/imcui/third_party/XoFTR/src/utils/pretrain_utils.py b/imcui/third_party/XoFTR/src/utils/pretrain_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad0296d2c249d6d0adb575422efaf1d01ba69fe --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/pretrain_utils.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange +import torch.nn.functional as F + +def generate_random_masks(batch, patch_size, mask_ratio, generator=None, margins=[0,0,0,0]): + mae_mask0 = _gen_random_mask(batch['image0'], patch_size, mask_ratio, generator, margins=margins) + mae_mask1 = _gen_random_mask(batch['image1'], patch_size, mask_ratio, generator, margins=margins) + batch.update({"mae_mask0" : mae_mask0, "mae_mask1": mae_mask1}) + +def _gen_random_mask(image, patch_size, mask_ratio, generator=None, margins=[0, 0, 0, 0]): + """ Random mask generator + Args: + image (torch.Tensor): [N, C, H, W] + patch_size (int) + mask_ratio (float) + generator (torch.Generator): RNG to create the same random masks for validation + margins [float, float, float, float]: unused part for masking (up bottom left right) + Returns: + mask (torch.Tensor): (N, L) + """ + N = image.shape[0] + l = (image.shape[2] // patch_size) + L = l ** 2 + len_keep = int(L * (1 - mask_ratio * (1 - sum(margins)))) + + margins = [int(margin * l) for margin in margins] + + noise = torch.rand(N, l, l, device=image.device, generator=generator) + if margins[0] > 0 : noise[:,:margins[0],:] = 0 + if margins[1] > 0 : noise[:,-margins[1]:,:] = 0 + if margins[2] > 0 : noise[:,:,:margins[2]] = 0 + if margins[3] > 0 : noise[:,:,-margins[3]:] = 0 + noise = noise.flatten(1) + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # generate the binary mask: 0 is keep 1 is remove + mask = torch.ones([N, L], device=image.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + return mask + +def patchify(data): + """ Split images into small overlapped patches + Args: + data (dict):{ + 'image0_norm' (torch.Tensor): [N, C, H, W] normalized image, + 'image1_norm' (torch.Tensor): [N, C, H, W] normalized image, + Returns: + image0 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows) + image1 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows) + """ + stride = data['hw0_i'][0] // data['hw0_c'][0] + scale = data['hw0_i'][0] // data['hw0_f'][0] + W_f = data["W_f"] + kernel_size = [int(W_f*scale), int(W_f*scale)] + padding = kernel_size[0]//2 -1 if kernel_size[0] % 2 == 0 else kernel_size[0]//2 + + image0 = data["image0_norm"] if "image0_norm" in data else data["image0"] + image1 = data["image1_norm"] if "image1_norm" in data else data["image1"] + + image0 = F.unfold(image0, kernel_size=kernel_size, stride=stride, padding=padding) + image0 = rearrange(image0, 'n (c h p w q) l -> n l h w p q c', h=W_f, w=W_f, p=scale, q=scale) + image0 = image0.flatten(4) + image0 = image0.reshape(*image0.shape[:2], W_f**2, -1) + + image1 = F.unfold(image1, kernel_size=kernel_size, stride=stride, padding=padding) + image1 = rearrange(image1, 'n (c h p w q) l -> n l h w p q c', h=W_f, w=W_f, p=scale, q=scale) + image1 = image1.flatten(4) + image1 
= image1.reshape(*image1.shape[:2], W_f**2, -1) + + return image0, image1 + +def get_target(data): + """Create target patches for mae""" + target0, target1 = patchify(data) + data.update({"target0":target0, "target1":target1}) + + diff --git a/imcui/third_party/XoFTR/src/utils/profiler.py b/imcui/third_party/XoFTR/src/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..6d21ed79fb506ef09c75483355402c48a195aaa9 --- /dev/null +++ b/imcui/third_party/XoFTR/src/utils/profiler.py @@ -0,0 +1,39 @@ +import torch +from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler +from contextlib import contextmanager +from pytorch_lightning.utilities import rank_zero_only + + +class InferenceProfiler(SimpleProfiler): + """ + This profiler records duration of actions with cuda.synchronize() + Use this in test time. + """ + + def __init__(self): + super().__init__() + self.start = rank_zero_only(self.start) + self.stop = rank_zero_only(self.stop) + self.summary = rank_zero_only(self.summary) + + @contextmanager + def profile(self, action_name: str) -> None: + try: + torch.cuda.synchronize() + self.start(action_name) + yield action_name + finally: + torch.cuda.synchronize() + self.stop(action_name) + + +def build_profiler(name): + if name == 'inference': + return InferenceProfiler() + elif name == 'pytorch': + from pytorch_lightning.profiler import PyTorchProfiler + return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) + elif name is None: + return PassThroughProfiler() + else: + raise ValueError(f'Invalid profiler: {name}') diff --git a/imcui/third_party/XoFTR/src/xoftr/__init__.py b/imcui/third_party/XoFTR/src/xoftr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84ad2ecd63cab23f8ae42d49024eb448a6e18a59 --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/__init__.py @@ -0,0 +1,2 @@ +from .xoftr import XoFTR +from .xoftr_pretrain import XoFTR_Pretrain diff --git a/imcui/third_party/XoFTR/src/xoftr/backbone/__init__.py b/imcui/third_party/XoFTR/src/xoftr/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f84fff9207161b2ef7cda5aeaae5e1e1c9180f --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/backbone/__init__.py @@ -0,0 +1 @@ +from .resnet import ResNet_8_2 diff --git a/imcui/third_party/XoFTR/src/xoftr/backbone/resnet.py b/imcui/third_party/XoFTR/src/xoftr/backbone/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..86d6b24fc31ade00e2f434ee4bc35618c45832ed --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/backbone/resnet.py @@ -0,0 +1,95 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.conv2 = conv3x3(planes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + conv1x1(in_planes, planes, stride=stride), + nn.BatchNorm2d(planes) + ) + + def 
forward(self, x): + y = x + y = self.relu(self.bn1(self.conv1(y))) + y = self.bn2(self.conv2(y)) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class ResNet_8_2(nn.Module): + """ + ResNet, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + x3_out = self.layer3_outconv(x3) + + return x3_out, x2, x1 + diff --git a/imcui/third_party/XoFTR/src/xoftr/utils/geometry.py b/imcui/third_party/XoFTR/src/xoftr/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..23c9a318e1838614aa29de8949c9f2255cd73fdf --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/utils/geometry.py @@ -0,0 +1,107 @@ +import torch + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """ Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). 
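The `warp_kpts` routine below follows the standard unproject, rigid-transform, reproject chain: keypoints are lifted to 3D with the sampled depth and the inverse intrinsics, moved by `T_0to1`, and projected back with `K1`. A minimal standalone sketch of that chain for a single view pair (the function name, shapes, and toy inputs are illustrative, not taken from the file):

```python
import torch

def warp_points_once(kpts, depth, T_0to1, K0, K1):
    """Warp pixel keypoints [L, 2] from view 0 to view 1 (illustrative sketch)."""
    d = depth[kpts[:, 1].long(), kpts[:, 0].long()]                                # sampled depth, [L]
    pts_h = torch.cat([kpts, torch.ones_like(kpts[:, :1])], dim=-1) * d[:, None]   # homogeneous * depth, [L, 3]
    pts_cam = K0.inverse() @ pts_h.T                                               # unproject to camera 0, [3, L]
    pts_cam1 = T_0to1[:3, :3] @ pts_cam + T_0to1[:3, 3:]                           # rigid transform into camera 1
    proj = (K1 @ pts_cam1).T                                                       # project, [L, 3]
    return proj[:, :2] / (proj[:, 2:] + 1e-4)                                      # perspective divide

# toy check: identity intrinsics and zero motion leave the keypoints in place
K = torch.eye(3)
T = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=1)                            # [3, 4]
kpts = torch.tensor([[10., 20.], [30., 5.]])
depth = torch.ones(64, 64)
print(warp_points_once(kpts, depth, T, K, K))                                      # ~ the input keypoints
```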
+ + Args: + kpts0 (torch.Tensor): [N, L, 2] - , + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + kpts0_long = kpts0.round().long() + + # Sample depth, get calculable_mask on depth != 0 + kpts0_depth = torch.stack( + [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + ) # (N, L) + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ + (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + w_kpts0_long = w_kpts0.long() + w_kpts0_long[~covisible_mask, :] = 0 + + w_kpts0_depth = torch.stack( + [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + ) # (N, L) + consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 + +@torch.no_grad() +def warp_kpts_fine(kpts0, depth0, depth1, T_0to1, K0, K1, b_ids): + """ Warp kpts0 from I0 to I1 with depth, K and Rt for give batch ids + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). 
+ + Args: + kpts0 (torch.Tensor): [N, L, 2] - , + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + b_ids (torch.Tensor): [M], selected batch ids for fine-level matching + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + kpts0_long = kpts0.round().long() + + # Sample depth, get calculable_mask on depth != 0 + kpts0_depth = torch.stack( + [depth0[b_ids[i], kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + ) # (N, L) + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_cam = K0[b_ids].inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + + # Rigid Transform + w_kpts0_cam = T_0to1[b_ids, :3, :3] @ kpts0_cam + T_0to1[b_ids, :3, [3]][...,None] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1[b_ids] @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ + (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + w_kpts0_long = w_kpts0.long() + w_kpts0_long[~covisible_mask, :] = 0 + + w_kpts0_depth = torch.stack( + [depth1[b_ids[i], w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + ) # (N, L) + consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 diff --git a/imcui/third_party/XoFTR/src/xoftr/utils/position_encoding.py b/imcui/third_party/XoFTR/src/xoftr/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa07546e2f62d73ef2da22b06466adeb01c5444 --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/utils/position_encoding.py @@ -0,0 +1,36 @@ +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, :x.size(2), :x.size(3)] \ No newline at end of file diff --git a/imcui/third_party/XoFTR/src/xoftr/utils/supervision.py b/imcui/third_party/XoFTR/src/xoftr/utils/supervision.py new file mode 100644 index 0000000000000000000000000000000000000000..8924aaec70b5dc40d899f69d61e6b74fba15d7c7 --- /dev/null +++ 
b/imcui/third_party/XoFTR/src/xoftr/utils/supervision.py @@ -0,0 +1,290 @@ +from math import log +from loguru import logger + +import torch +import torch.nn.functional as F +from einops import repeat +from kornia.utils import create_meshgrid +from einops.einops import rearrange +from .geometry import warp_kpts, warp_kpts_fine +from kornia.geometry.epipolar import fundamental_from_projections, normalize_transformation + +############## ↓ Coarse-Level supervision ↓ ############## + + +@torch.no_grad() +def mask_pts_at_padded_regions(grid_pt, mask): + """For megadepth dataset, zero-padding exists in images""" + mask = repeat(mask, 'n h w -> n (h w) c', c=2) + grid_pt[~mask.bool()] = 0 + return grid_pt + + +@torch.no_grad() +def spvs_coarse(data, config): + """ + Update: + data (dict): { + "conf_matrix_gt": [N, hw0, hw1], + 'spv_b_ids': [M] + 'spv_i_ids': [M] + 'spv_j_ids': [M] + 'spv_w_pt0_i': [N, hw0, 2], in original image resolution + 'spv_pt1_i': [N, hw1, 2], in original image resolution + } + + NOTE: + - for scannet dataset, there're 3 kinds of resolution {i, c, f} + - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} + """ + # 1. misc + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + scale = config['XOFTR']['RESOLUTION'][0] + scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale + scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + + # 2. warp grids + # create kpts in meshgrid and resize them to image resolution + grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_i = scale0 * grid_pt0_c + grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_i = scale1 * grid_pt1_c + + # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt + if 'mask0' in data: + grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) + grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) + + # warp kpts bi-directionally and resize them to coarse-level resolution + # (unhandled edge case: points with 0-depth will be warped to the left-up corner) + valid_mask0, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) + valid_mask1, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) + w_pt0_i[~valid_mask0] = 0 + w_pt1_i[~valid_mask1] = 0 + w_pt0_c = w_pt0_i / scale1 + w_pt1_c = w_pt1_i / scale0 + + # 3. nearest neighbor + w_pt0_c_round = w_pt0_c[:, :, :].round().long() + nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 + w_pt1_c_round = w_pt1_c[:, :, :].round().long() + nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0 + + # corner case: out of boundary + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 + nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 + + arange_1 = torch.arange(h0*w0, device=device)[None].repeat(N, 1) + arange_0 = torch.arange(h0*w0, device=device)[None].repeat(N, 1) + arange_1[nearest_index1 == 0] = 0 + arange_0[nearest_index0 == 0] = 0 + arange_b = torch.arange(N, device=device).unsqueeze(1) + + # 4. 
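The step that follows scatters ones into an `[h0*w0, h1*w1]` matrix at the nearest-neighbour indices computed above, with index 0 doubling as a dustbin for invalid warps (its cell is zeroed afterwards). A toy sketch of the same scatter pattern, with made-up sizes and indices:

```python
import torch

N, L0, L1 = 1, 6, 6                                  # toy grid sizes (h0*w0 and h1*w1)
nearest_index1 = torch.tensor([[3, 0, 5, 0, 2, 1]])  # NN in image1 per cell of image0 (0 = invalid warp)
arange_0 = torch.arange(L0).unsqueeze(0)             # source cell index
arange_0 = arange_0.masked_fill(nearest_index1 == 0, 0)
arange_b = torch.arange(N).unsqueeze(1)              # batch index

conf_gt = torch.zeros(N, L0, L1)
conf_gt[arange_b, arange_0, nearest_index1] = 1      # scatter one direction of the matching
conf_gt[:, 0, 0] = 0                                 # drop the dustbin cell
print(conf_gt[0].int())
```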
construct a gt conf_matrix + conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) + conf_matrix_gt[arange_b, arange_1, nearest_index1] = 1 + conf_matrix_gt[arange_b, nearest_index0, arange_0] = 1 + conf_matrix_gt[:, 0, 0] = False + + b_ids, i_ids, j_ids = conf_matrix_gt.nonzero(as_tuple=True) + + data.update({'conf_matrix_gt': conf_matrix_gt}) + + # 5. save coarse matches(gt) for training fine level + if len(b_ids) == 0: + logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}") + # this won't affect fine-level loss calculation + b_ids = torch.tensor([0], device=device) + i_ids = torch.tensor([0], device=device) + j_ids = torch.tensor([0], device=device) + + data.update({ + 'spv_b_ids': b_ids, + 'spv_i_ids': i_ids, + 'spv_j_ids': j_ids + }) + + # 6. save intermediate results (for fast fine-level computation) + data.update({ + 'spv_w_pt0_i': w_pt0_i, + 'spv_pt1_i': grid_pt1_i + }) + +def compute_supervision_coarse(data, config): + assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_coarse(data, config) + else: + raise ValueError(f'Unknown data source: {data_source}') + + +############## ↓ Fine-Level supervision ↓ ############## + +def compute_supervision_fine(data, config): + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_fine(data, config) + else: + raise NotImplementedError + +@torch.no_grad() +def create_2d_gaussian_kernel(kernel_size, sigma, device): + """ + Create a 2D Gaussian kernel. + + Args: + kernel_size (int): Size of the kernel (both width and height). + sigma (float): Standard deviation of the Gaussian distribution. + + Returns: + torch.Tensor: 2D Gaussian kernel. 
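The body that follows builds the kernel separably: a normalized 1-D Gaussian is computed once and the 2-D kernel is its outer product, which keeps the result normalized and symmetric. A quick standalone check of that property (kernel size and sigma are arbitrary here):

```python
import torch

size, sigma = 5, 1.0
x = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
g1d = torch.exp(-x**2 / (2 * sigma**2))
g1d = g1d / g1d.sum()                      # 1-D kernel sums to 1
g2d = torch.outer(g1d, g1d)                # separable 2-D construction

print(g2d.sum())                           # ~1.0, since sum(outer(a, b)) = sum(a) * sum(b)
print(torch.allclose(g2d, g2d.T))          # True: an isotropic Gaussian kernel is symmetric
```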
+ """ + kernel = torch.arange(kernel_size, dtype=torch.float32, device=device) - (kernel_size - 1) / 2 + kernel = torch.exp(-kernel**2 / (2 * sigma**2)) + kernel = kernel / kernel.sum() + + # Outer product to get a 2D kernel + kernel = torch.outer(kernel, kernel) + + return kernel + +@torch.no_grad() +def create_conf_prob(points, h0, w0, h1, w1, kernel_size = 5, sigma=1): + """ + Place a gaussian kernel in sim matrix for warped points + + Args: + data (dict): { + points: (torch.Tensor): (N, L, 2), warped rounded key points + h0, w0, h1, w1: (int), windows sizes + kernel_size: (int), kernel size for the gaussian + sigma: (float), sigma value for gaussian + } + """ + B = points.shape[0] + impulses = torch.zeros(B, h0 * w0, h1, w1, device=points.device) + + # Extract the row and column indices + row_indices = points[:, :, 1] + col_indices = points[:, :, 0] + + # Set the corresponding locations in the target tensor to 1 + impulses[torch.arange(B, device=points.device).view(B, 1, 1), + torch.arange(h0 * w0, device=points.device).view(1, h0 * w0, 1), + row_indices.unsqueeze(-1), col_indices.unsqueeze(-1)] = 1 + # mask 0,0 point + impulses[:,:,0,0] = 0 + + # Create the Gaussian kernel + gaussian_kernel = create_2d_gaussian_kernel(kernel_size, sigma=sigma, device=points.device) + gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) + + # Create distributions at the points + conf_prob = F.conv2d(impulses.view(-1,1,h1,w1), gaussian_kernel, padding=kernel_size//2).view(-1, h0*w0, h1*w1) + + return conf_prob + +@torch.no_grad() +def spvs_fine(data, config): + """ + Args: + data (dict): { + 'b_ids': [M] + 'i_ids': [M] + 'j_ids': [M] + } + + Update: + data (dict): { + conf_matrix_f_gt: [N, W_f^2, W_f^2], in original image resolution + } + + """ + # 1. misc + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + scale = config['XOFTR']['RESOLUTION'][1] + scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale + scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + scale_f_c = config['XOFTR']['RESOLUTION'][0] // config['XOFTR']['RESOLUTION'][1] + W_f = config['XOFTR']['FINE_WINDOW_SIZE'] + # 2. get coarse prediction + b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + + if len(b_ids) == 0: + data.update({"conf_matrix_f_gt": torch.zeros(1,W_f*W_f,W_f*W_f, device=device)}) + return + + # 2. warp grids + # create kpts in meshgrid and resize them to image resolution + grid_pt0_c = create_meshgrid(h0, w0, False, device).repeat(N, 1, 1, 1)#.reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_i = scale0[:,None,...] * grid_pt0_c + grid_pt1_c = create_meshgrid(h1, w1, False, device).repeat(N, 1, 1, 1)#.reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_i = scale1[:,None,...] 
* grid_pt1_c + + # unfold (crop windows) all local windows + stride_f = data['hw0_f'][0] // data['hw0_c'][0] + + grid_pt0_i = rearrange(grid_pt0_i, 'n h w c -> n c h w') + grid_pt0_i = F.unfold(grid_pt0_i, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f//2) + grid_pt0_i = rearrange(grid_pt0_i, 'n (c ww) l -> n l ww c', ww=W_f**2) + grid_pt0_i = grid_pt0_i[b_ids, i_ids] + + grid_pt1_i = rearrange(grid_pt1_i, 'n h w c -> n c h w') + grid_pt1_i = F.unfold(grid_pt1_i, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f//2) + grid_pt1_i = rearrange(grid_pt1_i, 'n (c ww) l -> n l ww c', ww=W_f**2) + grid_pt1_i = grid_pt1_i[b_ids, j_ids] + + # warp kpts bi-directionally and resize them to fine-level resolution + # (no depth consistency check + # (unhandled edge case: points with 0-depth will be warped to the left-up corner) + _, w_pt0_i = warp_kpts_fine(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'], b_ids) + _, w_pt1_i = warp_kpts_fine(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'], b_ids) + w_pt0_f = w_pt0_i / scale1[b_ids] + w_pt1_f = w_pt1_i / scale0[b_ids] + + mkpts0_c_scaled_to_f = torch.stack( + [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], + dim=1) * scale_f_c - W_f//2 + mkpts1_c_scaled_to_f = torch.stack( + [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], + dim=1) * scale_f_c - W_f//2 + + w_pt0_f = w_pt0_f - mkpts1_c_scaled_to_f[:,None,:] + w_pt1_f = w_pt1_f - mkpts0_c_scaled_to_f[:,None,:] + + # 3. check if mutual nearest neighbor + w_pt0_f_round = w_pt0_f[:, :, :].round().long() + w_pt1_f_round = w_pt1_f[:, :, :].round().long() + M = w_pt0_f.shape[0] + + nearest_index1 = w_pt0_f_round[..., 0] + w_pt0_f_round[..., 1] * W_f + nearest_index0 = w_pt1_f_round[..., 0] + w_pt1_f_round[..., 1] * W_f + + # corner case: out of boundary + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + nearest_index1[out_bound_mask(w_pt0_f_round, W_f, W_f)] = 0 + nearest_index0[out_bound_mask(w_pt1_f_round, W_f, W_f)] = 0 + + loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) + correct_0to1 = loop_back == torch.arange(W_f*W_f, device=device)[None].repeat(M, 1) + correct_0to1[:, 0] = False # ignore the top-left corner + + # 4. 
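The mutual check above is a loop-back test: cell i of window 0 maps to its nearest cell j in window 1, and the pair survives only if j maps back to i. A small self-contained illustration of that test with toy indices (not taken from the file):

```python
import torch

nearest_index1 = torch.tensor([2, 0, 1, 3])   # window0 -> window1 nearest neighbours
nearest_index0 = torch.tensor([1, 2, 3, 0])   # window1 -> window0 nearest neighbours

loop_back = nearest_index0[nearest_index1]    # follow 0 -> 1 -> 0
mutual = loop_back == torch.arange(4)         # keep only pairs whose loop returns home
print(mutual)                                 # tensor([False,  True,  True, False])
```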
construct a gt conf_matrix + conf_matrix_f_gt = torch.zeros(M, W_f*W_f, W_f*W_f, device=device) + b_ids, i_ids = torch.where(correct_0to1 != 0) + j_ids = nearest_index1[b_ids, i_ids] + conf_matrix_f_gt[b_ids, i_ids, j_ids] = 1 + + data.update({"conf_matrix_f_gt": conf_matrix_f_gt}) + + diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr.py b/imcui/third_party/XoFTR/src/xoftr/xoftr.py new file mode 100644 index 0000000000000000000000000000000000000000..78bc7c867d82a7161ce38fd46230c04c8d26b60d --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/xoftr.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange +from .backbone import ResNet_8_2 +from .utils.position_encoding import PositionEncodingSine +from .xoftr_module import LocalFeatureTransformer, FineProcess, CoarseMatching, FineSubMatching + +class XoFTR(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + + # Modules + self.backbone = ResNet_8_2(config['resnet']) + self.pos_encoding = PositionEncodingSine(config['coarse']['d_model']) + self.loftr_coarse = LocalFeatureTransformer(config['coarse']) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_process = FineProcess(config) + self.fine_matching= FineSubMatching(config) + + + def forward(self, data): + """ + Update: + data (dict): { + 'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + # 1. Local Feature CNN + data.update({ + 'bs': data['image0'].size(0), + 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] + }) + + eps = 1e-6 + + image0_mean = data['image0'].mean(dim=[2,3], keepdim=True) + image0_std = data['image0'].std(dim=[2,3], keepdim=True) + image0 = (data['image0'] - image0_mean) / (image0_std + eps) + + image1_mean = data['image1'].mean(dim=[2,3], keepdim=True) + image1_std = data['image1'].std(dim=[2,3], keepdim=True) + image1 = (data['image1'] - image1_mean) / (image1_std + eps) + + if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + feats_c, feats_m, feats_f = self.backbone(torch.cat([image0, image1], dim=0)) + (feat_c0, feat_c1) = feats_c.split(data['bs']) + (feat_m0, feat_m1) = feats_m.split(data['bs']) + (feat_f0, feat_f1) = feats_f.split(data['bs']) + else: # handle different input shapes + feat_c0, feat_m0, feat_f0 = self.backbone(image0) + feat_c1, feat_m1, feat_f1 = self.backbone(image1) + + data.update({ + 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], + 'hw0_m': feat_m0.shape[2:], 'hw1_m': feat_m1.shape[2:], + 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] + }) + + # save coarse features for fine matching + feat_c0_pre, feat_c1_pre = feat_c0.clone(), feat_c1.clone() + + # 2. coarse-level loftr module + # add featmap with positional encoding, then flatten it to sequence [N, HW, C] + feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c') + feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c') + + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) + feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) + + # 3. match coarse-level + self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) + + # 4. 
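When both images share a resolution, the forward pass above stacks them along the batch dimension, runs the backbone once, and splits the features afterwards, which saves a backbone launch and keeps BatchNorm statistics shared across views. A minimal sketch of that pattern with a stand-in backbone (the single conv here is a placeholder, not the model's ResNet):

```python
import torch
import torch.nn as nn

backbone = nn.Conv2d(1, 8, kernel_size=3, padding=1)        # placeholder feature extractor
image0 = torch.randn(2, 1, 64, 64)
image1 = torch.randn(2, 1, 64, 64)

if image0.shape[2:] == image1.shape[2:]:
    feats = backbone(torch.cat([image0, image1], dim=0))    # one pass for both views
    feat0, feat1 = feats.split(image0.size(0))              # split back per view
else:
    feat0, feat1 = backbone(image0), backbone(image1)       # different shapes: separate passes

print(feat0.shape, feat1.shape)                             # torch.Size([2, 8, 64, 64]) twice
```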
fine-level matching module + feat_f0_unfold, feat_f1_unfold = self.fine_process(feat_f0, feat_f1, + feat_m0, feat_m1, + feat_c0, feat_c1, + feat_c0_pre, feat_c1_pre, + data) + + # 5. match fine-level and sub-pixel refinement + self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a99d3559e31f93291afbcf65ade17cf3616bbc9 --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py @@ -0,0 +1,4 @@ +from .transformer import LocalFeatureTransformer +from .fine_process import FineProcess +from .coarse_matching import CoarseMatching +from .fine_matching import FineSubMatching diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..34f82abbbc912a8761cf2858183041ed17d7cd70 --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py @@ -0,0 +1,305 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange + +INF = 1e9 + +def mask_border(m, b: int, v): + """ Mask borders with value + Args: + m (torch.Tensor): [N, H0, W0, H1, W1] + b (int) + v (m.dtype) + """ + if b <= 0: + return + + m[:, :b] = v + m[:, :, :b] = v + m[:, :, :, :b] = v + m[:, :, :, :, :b] = v + m[:, -b:] = v + m[:, :, -b:] = v + m[:, :, :, -b:] = v + m[:, :, :, :, -b:] = v + + +def mask_border_with_padding(m, bd, v, p_m0, p_m1): + if bd <= 0: + return + + m[:, :bd] = v + m[:, :, :bd] = v + m[:, :, :, :bd] = v + m[:, :, :, :, :bd] = v + + h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() + h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() + for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): + m[b_idx, h0 - bd:] = v + m[b_idx, :, w0 - bd:] = v + m[b_idx, :, :, h1 - bd:] = v + m[b_idx, :, :, :, w1 - bd:] = v + + +def compute_max_candidates(p_m0, p_m1): + """Compute the max candidates of all pairs within a batch + + Args: + p_m0, p_m1 (torch.Tensor): padded masks + """ + h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] + h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] + max_cand = torch.sum( + torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + return max_cand + + +class CoarseMatching(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # general config + d_model = config['d_model'] + self.thr = config['thr'] + self.inference = config['inference'] + self.border_rm = config['border_rm'] + # -- # for trainig fine-level XoFTR + self.train_coarse_percent = config['train_coarse_percent'] + self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + self.final_proj = nn.Linear(d_model, d_model, bias=True) + + self.temperature = config['dsmax_temperature'] + + def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + data (dict) + mask_c0 (torch.Tensor): [N, L] (optional) + mask_c1 (torch.Tensor): [N, S] (optional) + Update: + 
data (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + NOTE: M' != M during training. + """ + + feat_c0 = self.final_proj(feat_c0) + feat_c1 = self.final_proj(feat_c1) + + # normalize + feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, + [feat_c0, feat_c1]) + + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, + feat_c1) / self.temperature + if mask_c0 is not None: + sim_matrix.masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + if self.inference: + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match_inference(sim_matrix, data)) + else: + conf_matrix_0_to_1 = F.softmax(sim_matrix, 2) + conf_matrix_1_to_0 = F.softmax(sim_matrix, 1) + data.update({'conf_matrix_0_to_1': conf_matrix_0_to_1, + 'conf_matrix_1_to_0': conf_matrix_1_to_0 + }) + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match_training(conf_matrix_0_to_1, conf_matrix_1_to_0, data)) + + @torch.no_grad() + def get_coarse_match_training(self, conf_matrix_0_to_1, conf_matrix_1_to_0, data): + """ + Args: + conf_matrix_0_to_1 (torch.Tensor): [N, L, S] + conf_matrix_1_to_0 (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + _device = conf_matrix_0_to_1.device + + # confidence thresholding + # {(nearest neighbour for 0 to 1) U (nearest neighbour for 1 to 0)} + mask = torch.logical_or((conf_matrix_0_to_1 > self.thr) * (conf_matrix_0_to_1 == conf_matrix_0_to_1.max(dim=2, keepdim=True)[0]), + (conf_matrix_1_to_0 > self.thr) * (conf_matrix_1_to_0 == conf_matrix_1_to_0.max(dim=1, keepdim=True)[0])) + + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + **axes_lengths) + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # find all valid coarse matches + b_ids, i_ids, j_ids = mask.nonzero(as_tuple=True) + + mconf = torch.maximum(conf_matrix_0_to_1[b_ids, i_ids, j_ids], conf_matrix_1_to_0[b_ids, i_ids, j_ids]) + + # random sampling of training samples for fine-level XoFTR + # (optional) pad samples with gt coarse-level matches + if self.training: + # NOTE: + # the sampling is performed across all pairs in a batch without manually balancing + # samples for fine-level increases w.r.t. 
batch_size + if 'mask0' not in data: + num_candidates_max = mask.size(0) * max( + mask.size(1), mask.size(2)) + else: + num_candidates_max = compute_max_candidates( + data['mask0'], data['mask1']) + num_matches_train = int(num_candidates_max * + self.train_coarse_percent) + num_matches_pred = len(b_ids) + assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + + # pred_indices is to select from prediction + if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: + pred_indices = torch.arange(num_matches_pred, device=_device) + else: + pred_indices = torch.randint( + num_matches_pred, + (num_matches_train - self.train_pad_num_gt_min, ), + device=_device) + + # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) + gt_pad_indices = torch.randint( + len(data['spv_b_ids']), + (max(num_matches_train - num_matches_pred, + self.train_pad_num_gt_min), ), + device=_device) + mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + + b_ids, i_ids, j_ids, mconf = map( + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], + dim=0), + *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], + [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + + # these matches are selected patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], torch.div(i_ids, data['hw0_c'][1], rounding_mode='trunc')], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], torch.div(j_ids, data['hw1_c'][1], rounding_mode='trunc')], + dim=1) * scale1 + + # these matches is the current prediction (for visualization) + coarse_matches.update({ + 'gt_mask': mconf == 0, + 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c[mconf != 0], + 'mkpts1_c': mkpts1_c[mconf != 0], + 'mconf': mconf[mconf != 0] + }) + + return coarse_matches + + @torch.no_grad() + def get_coarse_match_inference(self, sim_matrix, data): + """ + Args: + sim_matrix (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + + # softmax for 0 to 1 + conf_matrix_ = F.softmax(sim_matrix, 2) + + # confidence thresholding and nearest neighbour for 0 to 1 + mask = (conf_matrix_ > self.thr) * (conf_matrix_ == conf_matrix_.max(dim=2, keepdim=True)[0]) + + # unlike training, reuse the same conf martix to decrease the vram consumption + # softmax for 0 to 1 + conf_matrix_ = F.softmax(sim_matrix, 1) + + # update mask {(nearest neighbour for 0 to 1) U (nearest neighbour for 1 to 0)} + mask = torch.logical_or(mask, + (conf_matrix_ > self.thr) * (conf_matrix_ == conf_matrix_.max(dim=1, keepdim=True)[0])) + + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + 
**axes_lengths) + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # find all valid coarse matches + b_ids, i_ids, j_ids = mask.nonzero(as_tuple=True) + + # mconf = torch.maximum(conf_matrix_0_to_1[b_ids, i_ids, j_ids], conf_matrix_1_to_0[b_ids, i_ids, j_ids]) + + # these matches are selected patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], torch.div(i_ids, data['hw0_c'][1], rounding_mode='trunc')], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], torch.div(j_ids, data['hw1_c'][1], rounding_mode='trunc')], + dim=1) * scale1 + + # these matches are the current coarse level predictions + coarse_matches.update({ + 'm_bids': b_ids, # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c, + 'mkpts1_c': mkpts1_c, + }) + + return coarse_matches diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..5430f71b0bedd60a613226e781ed8ce2baf7f9f2 --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FineSubMatching(nn.Module): + """Fine-level and Sub-pixel matching""" + + def __init__(self, config): + super().__init__() + self.temperature = config['fine']['dsmax_temperature'] + self.W_f = config['fine_window_size'] + self.denser = config['fine']['denser'] + self.inference = config['fine']['inference'] + dim_f = config['resnet']['block_dims'][0] + self.fine_thr = config['fine']['thr'] + self.fine_proj = nn.Linear(dim_f, dim_f, bias=False) + self.subpixel_mlp = nn.Sequential(nn.Linear(2*dim_f, 2*dim_f, bias=False), + nn.ReLU(), + nn.Linear(2*dim_f, 4, bias=False)) + + def forward(self, feat_f0_unfold, feat_f1_unfold, data): + """ + Args: + feat_f0_unfold (torch.Tensor): [M, WW, C] + feat_f1_unfold (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + + feat_f0 = self.fine_proj(feat_f0_unfold) + feat_f1 = self.fine_proj(feat_f1_unfold) + + M, WW, C = feat_f0.shape + W_f = self.W_f + + # corner case: if no coarse matches found + if M == 0: + assert self.training == False, "M is always >0, when training, see coarse_matching.py" + # logger.warning('No matches found in coarse-level.') + data.update({ + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + 'mconf_f': torch.zeros(0, device=feat_f0_unfold.device), + # 'mkpts0_f_train': data['mkpts0_c'], + # 'mkpts1_f_train': data['mkpts1_c'], + # 'conf_matrix_fine': torch.zeros(1, W_f*W_f, W_f*W_f, device=feat_f0.device) + }) + return + + # normalize + feat_f0, feat_f1 = map(lambda feat: feat / feat.shape[-1]**.5, + [feat_f0, feat_f1]) + sim_matrix = torch.einsum("nlc,nsc->nls", feat_f0, + feat_f1) / self.temperature + + conf_matrix_fine = F.softmax(sim_matrix, 1) * 
F.softmax(sim_matrix, 2) + data.update({'conf_matrix_fine': conf_matrix_fine}) + + # predict fine-level and sub-pixel matches from conf_matrix + data.update(**self.get_fine_sub_match(conf_matrix_fine, feat_f0_unfold, feat_f1_unfold, data)) + + def get_fine_sub_match(self, conf_matrix_fine, feat_f0_unfold, feat_f1_unfold, data): + """ + Args: + conf_matrix_fine (torch.Tensor): [M, WW, WW] + feat_f0_unfold (torch.Tensor): [M, WW, C] + feat_f1_unfold (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'm_bids' (torch.Tensor): [M] + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + + with torch.no_grad(): + W_f = self.W_f + + # 1. confidence thresholding + mask = conf_matrix_fine > self.fine_thr + + if mask.sum() == 0: + mask[0,0,0] = 1 + conf_matrix_fine[0,0,0] = 1 + + if not self.denser: + # match only the highest confidence + mask = mask \ + * (conf_matrix_fine == conf_matrix_fine.amax(dim=[1,2], keepdim=True)) + else: + # 2. mutual nearest, match all features in fine window + mask = mask \ + * (conf_matrix_fine == conf_matrix_fine.max(dim=2, keepdim=True)[0]) \ + * (conf_matrix_fine == conf_matrix_fine.max(dim=1, keepdim=True)[0]) + + # 3. find all valid fine matches + # this only works when at most one `True` in each row + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + mconf = conf_matrix_fine[b_ids, i_ids, j_ids] + + # 4. update with matches in original image resolution + + # indices from coarse matches + b_ids_c, i_ids_c, j_ids_c = data['b_ids'], data['i_ids'], data['j_ids'] + + # scale (coarse level / fine-level) + scale_f_c = data['hw0_f'][0] // data['hw0_c'][0] + + # coarse level matches scaled to fine-level (1/2) + mkpts0_c_scaled_to_f = torch.stack( + [i_ids_c % data['hw0_c'][1], torch.div(i_ids_c, data['hw0_c'][1], rounding_mode='trunc')], + dim=1) * scale_f_c + + mkpts1_c_scaled_to_f = torch.stack( + [j_ids_c % data['hw1_c'][1], torch.div(j_ids_c, data['hw1_c'][1], rounding_mode='trunc')], + dim=1) * scale_f_c + + # updated b_ids after second thresholding + updated_b_ids = b_ids_c[b_ids] + + # scales (image res / fine level) + scale = data['hw0_i'][0] / data['hw0_f'][0] + scale0 = scale * data['scale0'][updated_b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][updated_b_ids] if 'scale1' in data else scale + + # fine-level discrete matches on window coordiantes + mkpts0_f_window = torch.stack( + [i_ids % W_f, torch.div(i_ids, W_f, rounding_mode='trunc')], + dim=1) + + mkpts1_f_window = torch.stack( + [j_ids % W_f, torch.div(j_ids, W_f, rounding_mode='trunc')], + dim=1) + + # sub-pixel refinement + sub_ref = self.subpixel_mlp(torch.cat([feat_f0_unfold[b_ids, i_ids], + feat_f1_unfold[b_ids, j_ids]], dim=-1)) + sub_ref0, sub_ref1 = torch.chunk(sub_ref, 2, dim=-1) + sub_ref0 = torch.tanh(sub_ref0) * 0.5 + sub_ref1 = torch.tanh(sub_ref1) * 0.5 + + # final sub-pixel matches by (coarse-level + fine-level windowed + sub-pixel refinement) + mkpts0_f_train = (mkpts0_f_window + mkpts0_c_scaled_to_f[b_ids] - (W_f//2) + sub_ref0) * scale0 + mkpts1_f_train = (mkpts1_f_window + mkpts1_c_scaled_to_f[b_ids] - (W_f//2) + sub_ref1) * scale1 + mkpts0_f = mkpts0_f_train.clone().detach() + mkpts1_f = mkpts1_f_train.clone().detach() + + # These matches is the current prediction (for visualization) + sub_pixel_matches = { + 'm_bids': b_ids_c[b_ids[mconf != 0]], # mconf == 0 => gt matches + 'mkpts0_f': mkpts0_f[mconf != 0], + 'mkpts1_f': 
mkpts1_f[mconf != 0], + 'mconf_f': mconf[mconf != 0] + } + + # These matches are used for training + if not self.inference: + sub_pixel_matches.update({ + 'mkpts0_f_train': mkpts0_f_train[mconf != 0], + 'mkpts1_f_train': mkpts1_f_train[mconf != 0], + }) + + return sub_pixel_matches diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py new file mode 100644 index 0000000000000000000000000000000000000000..001ab663b1fdd767c407e8089137a142012caa17 --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py @@ -0,0 +1,321 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange + +class Mlp(nn.Module): + """Multi-Layer Perceptron (MLP)""" + + def __init__(self, + in_dim, + hidden_dim=None, + out_dim=None, + act_layer=nn.GELU): + """ + Args: + in_dim: input features dimension + hidden_dim: hidden features dimension + out_dim: output features dimension + act_layer: activation function + """ + super().__init__() + out_dim = out_dim or in_dim + hidden_dim = hidden_dim or in_dim + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.out_dim = out_dim + + def forward(self, x): + x_size = x.size() + x = x.view(-1, x_size[-1]) + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = x.view(*x_size[:-1], self.out_dim) + return x + + +class VanillaAttention(nn.Module): + def __init__(self, + dim, + num_heads=8, + proj_bias=False): + super().__init__() + """ + Args: + dim: feature dimension + num_heads: number of attention head + proj_bias: bool use query, key, value bias + """ + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.softmax_temp = self.head_dim ** -0.5 + self.kv_proj = nn.Linear(dim, dim * 2, bias=proj_bias) + self.q_proj = nn.Linear(dim, dim, bias=proj_bias) + self.merge = nn.Linear(dim, dim) + + def forward(self, x_q, x_kv=None): + """ + Args: + x_q (torch.Tensor): [N, L, C] + x_kv (torch.Tensor): [N, S, C] + """ + if x_kv is None: + x_kv = x_q + bs, _, dim = x_q.shape + bs, _, dim = x_kv.shape + # [N, S, 2, H, D] => [2, N, H, S, D] + kv = self.kv_proj(x_kv).reshape(bs, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + # [N, L, H, D] => [N, H, L, D] + q = self.q_proj(x_q).reshape(bs, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + k, v = kv[0].transpose(-2, -1).contiguous(), kv[1].contiguous() # [N, H, D, S], [N, H, S, D] + attn = (q @ k) * self.softmax_temp # [N, H, L, S] + attn = attn.softmax(dim=-1) + x_q = (attn @ v).transpose(1, 2).reshape(bs, -1, dim) + x_q = self.merge(x_q) + return x_q + + +class CrossBidirectionalAttention(nn.Module): + def __init__(self, dim, num_heads, proj_bias = False): + super().__init__() + """ + Args: + dim: feature dimension + num_heads: number of attention head + proj_bias: bool use query, key, value bias + """ + + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.softmax_temp = self.head_dim ** -0.5 + self.qk_proj = nn.Linear(dim, dim, bias=proj_bias) + self.v_proj = nn.Linear(dim, dim, bias=proj_bias) + self.merge = nn.Linear(dim, dim, bias=proj_bias) + self.temperature = nn.Parameter(torch.tensor([0.0]), requires_grad=True) + # print(self.temperature) + + def map_(self, func, x0, x1): + return func(x0), func(x1) + + def forward(self, x0, x1): + """ + Args: + x0 (torch.Tensor): [N, L, C] + x1 (torch.Tensor): [N, S, C] + """ + bs = x0.size(0) + + qk0, qk1 = 
self.map_(self.qk_proj, x0, x1) + v0, v1 = self.map_(self.v_proj, x0, x1) + qk0, qk1, v0, v1 = map( + lambda t: t.reshape(bs, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous(), + (qk0, qk1, v0, v1)) + + qk0, qk1 = qk0 * self.softmax_temp**0.5, qk1 * self.softmax_temp**0.5 + sim = qk0 @ qk1.transpose(-2,-1).contiguous() + attn01 = F.softmax(sim, dim=-1) + attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) + x0 = attn01 @ v1 + x1 = attn10 @ v0 + x0, x1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), + x0, x1) + x0, x1 = self.map_(self.merge, x0, x1) + + return x0, x1 + + +class SwinPosEmbMLP(nn.Module): + def __init__(self, + dim): + super().__init__() + self.pos_embed = None + self.pos_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(), + nn.Linear(512, dim, bias=False)) + + def forward(self, x): + seq_length = x.shape[1] + if self.pos_embed is None or self.training: + seq_length = int(seq_length**0.5) + coords = torch.arange(0, seq_length, device=x.device, dtype = x.dtype) + grid = torch.stack(torch.meshgrid([coords, coords])).contiguous().unsqueeze(0) + grid -= seq_length // 2 + grid /= (seq_length // 2) + self.pos_embed = self.pos_mlp(grid.flatten(2).transpose(1,2)) + x = x + self.pos_embed + return x + + +class WindowSelfAttention(nn.Module): + def __init__(self, dim, num_heads, mlp_hidden_coef, use_pre_pos_embed=False): + super().__init__() + self.mlp = Mlp(in_dim=dim*2, hidden_dim=dim*mlp_hidden_coef, out_dim=dim, act_layer=nn.GELU) + self.gamma = nn.Parameter(torch.ones(dim)) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.attn = VanillaAttention(dim, num_heads=num_heads) + self.pos_embed = SwinPosEmbMLP(dim) + self.pos_embed_pre = SwinPosEmbMLP(dim) if use_pre_pos_embed else nn.Identity() + + def forward(self, x, x_pre): + ww = x.shape[1] + ww_pre = x_pre.shape[1] + x = self.pos_embed(x) + x_pre = self.pos_embed_pre(x_pre) + x = torch.cat((x, x_pre), dim=1) + x = x + self.gamma*self.norm1(self.mlp(torch.cat([x, self.attn(self.norm2(x))], dim=-1))) + x, x_pre = x.split([ww, ww_pre], dim=1) + return x, x_pre + + +class WindowCrossAttention(nn.Module): + def __init__(self, dim, num_heads, mlp_hidden_coef): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.mlp = Mlp(in_dim=dim*2, hidden_dim=dim*mlp_hidden_coef, out_dim=dim, act_layer=nn.GELU) + self.cross_attn = CrossBidirectionalAttention(dim, num_heads=num_heads, proj_bias=False) + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x0, x1): + m_x0, m_x1 = self.cross_attn(self.norm1(x0), self.norm1(x1)) + x0 = x0 + self.gamma*self.norm2(self.mlp(torch.cat([x0, m_x0], dim=-1))) + x1 = x1 + self.gamma*self.norm2(self.mlp(torch.cat([x1, m_x1], dim=-1))) + return x0, x1 + + +class FineProcess(nn.Module): + def __init__(self, config): + super().__init__() + # Config + block_dims = config['resnet']['block_dims'] + self.block_dims = block_dims + self.W_f = config['fine_window_size'] + self.W_m = config['medium_window_size'] + nhead_f = config["fine"]['nhead_fine_level'] + nhead_m = config["fine"]['nhead_medium_level'] + mlp_hidden_coef = config["fine"]['mlp_hidden_dim_coef'] + + # Networks + self.conv_merge = nn.Sequential(nn.Conv2d(block_dims[2]*2, block_dims[1], kernel_size=1, stride=1, padding=0, bias=False), + nn.Conv2d(block_dims[1], block_dims[1], kernel_size=3, stride=1, padding=1, groups=block_dims[1], bias=False), + nn.BatchNorm2d(block_dims[1]) + ) + self.out_conv_m = nn.Conv2d(block_dims[1], 
block_dims[1], kernel_size=1, stride=1, padding=0, bias=False) + self.out_conv_f = nn.Conv2d(block_dims[0], block_dims[0], kernel_size=1, stride=1, padding=0, bias=False) + self.self_attn_m = WindowSelfAttention(block_dims[1], num_heads=nhead_m, + mlp_hidden_coef=mlp_hidden_coef, use_pre_pos_embed=False) + self.cross_attn_m = WindowCrossAttention(block_dims[1], num_heads=nhead_m, + mlp_hidden_coef=mlp_hidden_coef) + self.self_attn_f = WindowSelfAttention(block_dims[0], num_heads=nhead_f, + mlp_hidden_coef=mlp_hidden_coef, use_pre_pos_embed=True) + self.cross_attn_f = WindowCrossAttention(block_dims[0], num_heads=nhead_f, + mlp_hidden_coef=mlp_hidden_coef) + self.down_proj_m_f = nn.Linear(block_dims[1], block_dims[0], bias=False) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def pre_process(self, feat_f0, feat_f1, feat_m0, feat_m1, feat_c0, feat_c1, feat_c0_pre, feat_c1_pre, data): + W_f = self.W_f + W_m = self.W_m + data.update({'W_f': W_f, + 'W_m': W_m}) + + # merge coarse features before and after loftr layer, and down proj channel dimesions + feat_c0 = rearrange(feat_c0, 'n (h w) c -> n c h w', h =data["hw0_c"][0], w =data["hw0_c"][1]) + feat_c1 = rearrange(feat_c1, 'n (h w) c -> n c h w', h =data["hw1_c"][0], w =data["hw1_c"][1]) + feat_c0 = self.conv_merge(torch.cat([feat_c0, feat_c0_pre], dim=1)) + feat_c1 = self.conv_merge(torch.cat([feat_c1, feat_c1_pre], dim=1)) + feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) 1 c') + feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) 1 c') + + stride_f = data['hw0_f'][0] // data['hw0_c'][0] + stride_m = data['hw0_m'][0] // data['hw0_c'][0] + + if feat_m0.shape[2] == feat_m1.shape[2] and feat_m0.shape[3] == feat_m1.shape[3]: + feat_m = self.out_conv_m(torch.cat([feat_m0, feat_m1], dim=0)) + feat_m0, feat_m1 = torch.chunk(feat_m, 2, dim=0) + feat_f = self.out_conv_f(torch.cat([feat_f0, feat_f1], dim=0)) + feat_f0, feat_f1 = torch.chunk(feat_f, 2, dim=0) + else: + feat_m0 = self.out_conv_m(feat_m0) + feat_m1 = self.out_conv_m(feat_m1) + feat_f0 = self.out_conv_f(feat_f0) + feat_f1 = self.out_conv_f(feat_f1) + + # 1. unfold (crop windows) all local windows + feat_m0_unfold = F.unfold(feat_m0, kernel_size=(W_m, W_m), stride=stride_m, padding=W_m//2) + feat_m0_unfold = rearrange(feat_m0_unfold, 'n (c ww) l -> n l ww c', ww=W_m**2) + feat_m1_unfold = F.unfold(feat_m1, kernel_size=(W_m, W_m), stride=stride_m, padding=W_m//2) + feat_m1_unfold = rearrange(feat_m1_unfold, 'n (c ww) l -> n l ww c', ww=W_m**2) + + feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f//2) + feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W_f**2) + feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f//2) + feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W_f**2) + + # 2. 
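The unfold/rearrange pairs above implement window cropping: `F.unfold` slides a `W_f x W_f` window with the coarse stride and `padding=W_f//2`, producing one flattened local window per coarse cell, and `rearrange` restores the `[N, L, W_f**2, C]` layout. A shape-only sketch of that round trip (all dimensions chosen arbitrarily):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

N, C, H, W = 1, 16, 32, 32      # a toy 1/2-resolution feature map
W_f, stride = 5, 4              # fine window size and coarse-to-fine stride

feat = torch.randn(N, C, H, W)
unfolded = F.unfold(feat, kernel_size=(W_f, W_f), stride=stride, padding=W_f // 2)
print(unfolded.shape)           # [1, C*W_f*W_f, L] = [1, 400, 64], one column per coarse cell

windows = rearrange(unfolded, 'n (c ww) l -> n l ww c', ww=W_f**2)
print(windows.shape)            # [1, 64, 25, 16]: a 5x5 window of C-dim features per coarse cell
```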
select only the predicted matches + feat_c0 = feat_c0[data['b_ids'], data['i_ids']] # [n, ww, cm] + feat_c1 = feat_c1[data['b_ids'], data['j_ids']] + + feat_m0_unfold = feat_m0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cm] + feat_m1_unfold = feat_m1_unfold[data['b_ids'], data['j_ids']] + + feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + + return feat_c0, feat_c1, feat_m0_unfold, feat_m1_unfold, feat_f0_unfold, feat_f1_unfold + + def forward(self, feat_f0, feat_f1, feat_m0, feat_m1, feat_c0, feat_c1, feat_c0_pre, feat_c1_pre, data): + """ + Args: + feat_f0 (torch.Tensor): [N, C, H, W] + feat_f1 (torch.Tensor): [N, C, H, W] + feat_m0 (torch.Tensor): [N, C, H, W] + feat_m1 (torch.Tensor): [N, C, H, W] + feat_c0 (torch.Tensor): [N, L, C] + feat_c1 (torch.Tensor): [N, S, C] + feat_c0_pre (torch.Tensor): [N, C, H, W] + feat_c1_pre (torch.Tensor): [N, C, H, W] + data (dict): with keys ['hw0_c', 'hw1_c', 'hw0_m', 'hw1_m', 'hw0_f', 'hw1_f', 'b_ids', 'j_ids'] + """ + + # TODO: Check for this case + if data['b_ids'].shape[0] == 0: + feat0 = torch.empty(0, self.W_f**2, self.block_dims[0], device=feat_f0.device) + feat1 = torch.empty(0, self.W_f**2, self.block_dims[0], device=feat_f0.device) + return feat0, feat1 + + feat_c0, feat_c1, feat_m0_unfold, feat_m1_unfold, \ + feat_f0_unfold, feat_f1_unfold = self.pre_process(feat_f0, feat_f1, feat_m0, feat_m1, + feat_c0, feat_c1, feat_c0_pre, feat_c1_pre, data) + + # self attention (c + m) + feat_m_unfold, _ = self.self_attn_m(torch.cat([feat_m0_unfold, feat_m1_unfold], dim=0), + torch.cat([feat_c0, feat_c1], dim=0)) + feat_m0_unfold, feat_m1_unfold = torch.chunk(feat_m_unfold, 2, dim=0) + + # cross attention (m0 <-> m1) + feat_m0_unfold, feat_m1_unfold = self.cross_attn_m(feat_m0_unfold, feat_m1_unfold) + + # down proj m + feat_m_unfold = self.down_proj_m_f(torch.cat([feat_m0_unfold, feat_m1_unfold], dim=0)) + feat_m0_unfold, feat_m1_unfold = torch.chunk(feat_m_unfold, 2, dim=0) + + # self attention (m + f) + feat_f_unfold, _ = self.self_attn_f(torch.cat([feat_f0_unfold, feat_f1_unfold], dim=0), + torch.cat([feat_m0_unfold, feat_m1_unfold], dim=0)) + feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_f_unfold, 2, dim=0) + + # cross attention (f0 <-> f1) + feat_f0_unfold, feat_f1_unfold = self.cross_attn_f(feat_f0_unfold, feat_f1_unfold) + + return feat_f0_unfold, feat_f1_unfold + diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..61b1b8573e6454b6d340c20381ad5f945d479791 --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, 
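`LinearAttention` below uses the kernel trick from "Transformers are RNNs": with the positive feature map phi(x) = elu(x) + 1, softmax attention is approximated by phi(Q) (phi(K)^T V), and re-associating the product makes the cost linear in sequence length. A shape-level sketch of that re-association for a single head without masking (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

N, L, S, D = 2, 100, 120, 32
phi = lambda x: F.elu(x) + 1                                  # positive feature map

Q, K, V = torch.randn(N, L, D), torch.randn(N, S, D), torch.randn(N, S, D)
Qf, Kf = phi(Q), phi(K)

# quadratic form: materialize the [L, S] attention matrix, then take the weighted sum
attn = Qf @ Kf.transpose(1, 2)
out_quadratic = (attn / attn.sum(-1, keepdim=True)) @ V

# linear form: accumulate a [D, D] summary of keys/values, never build the [L, S] matrix
KV = Kf.transpose(1, 2) @ V                                   # [N, D, D]
Z = 1 / (Qf @ Kf.sum(dim=1, keepdim=True).transpose(1, 2))    # [N, L, 1] normalizer
out_linear = (Qf @ KV) * Z

print(torch.allclose(out_quadratic, out_linear, atol=1e-4))   # True (up to float error)
```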
L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) + + # Compute the attention and the weighted average + softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() \ No newline at end of file diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3eef7521a2c33e26fc66ca6797e842a7a146c21f --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py @@ -0,0 +1,101 @@ +import copy +import torch +import torch.nn as nn +from .linear_attention import LinearAttention, FullAttention + + +class LoFTREncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + attention='linear'): + super(LoFTREncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = 
self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" + + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + return feat0, feat1 \ No newline at end of file diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_pretrain.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..2a4934eeb2d856d72010840507dde5b2bd1d3a94 --- /dev/null +++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_pretrain.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange +from .backbone import ResNet_8_2 +from .utils.position_encoding import PositionEncodingSine +from .xoftr_module import LocalFeatureTransformer, FineProcess + + +class XoFTR_Pretrain(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + self.patch_size = config["pretrain_patch_size"] + + # Modules + self.backbone = ResNet_8_2(config['resnet']) + self.pos_encoding = PositionEncodingSine(config['coarse']['d_model']) + self.loftr_coarse = LocalFeatureTransformer(config['coarse']) + self.fine_process = FineProcess(config) + self.mask_token_f = nn.Parameter(torch.zeros(1, config['resnet']["block_dims"][0], 1, 1)) + self.mask_token_m = nn.Parameter(torch.zeros(1, config['resnet']["block_dims"][1], 1, 1)) + self.mask_token_c = nn.Parameter(torch.zeros(1, config['resnet']["block_dims"][2], 1, 1)) + self.out_proj = nn.Linear(config['resnet']["block_dims"][0], 4) + + torch.nn.init.normal_(self.mask_token_f, std=.02) + torch.nn.init.normal_(self.mask_token_m, std=.02) + torch.nn.init.normal_(self.mask_token_c, std=.02) + + def upsample_mae_mask(self, mae_mask, scale): + assert len(mae_mask.shape) == 2 + p = int(mae_mask.shape[1] ** .5) + return mae_mask.reshape(-1, p, p).repeat_interleave(scale, axis=1).repeat_interleave(scale, axis=2) + + def upsample_mask(self, mask, scale): + return mask.repeat_interleave(scale, axis=1).repeat_interleave(scale, axis=2) + + def mask_layer(self, feat, mae_mask, mae_mask_scale, mask=None, mask_scale=None, 
mask_token=None):
+        """ Mask the feature map and replace with trainable input tokens if available
+        Args:
+            feat (torch.Tensor): [N, C, H, W]
+            mae_mask (torch.Tensor): (N, L) mask for masked image modeling
+            mae_mask_scale (int): the scale of layer to mae mask
+            mask (torch.Tensor): mask for padded input image
+            mask_scale (int): the scale of layer to mask (mask is created on coarse scale)
+            mask_token (torch.Tensor): [1, C, 1, 1] learnable mae mask token
+        Returns:
+            feat (torch.Tensor): [N, C, H, W]
+        """
+        mae_mask = self.upsample_mae_mask(mae_mask, mae_mask_scale)
+        mae_mask = mae_mask.unsqueeze(1).type_as(feat)
+        if mask is not None:
+            mask = self.upsample_mask(mask, mask_scale)
+            mask = mask.unsqueeze(1).type_as(feat)
+            mae_mask = mask * mae_mask
+        feat = feat * (1. - mae_mask)
+        if mask_token is not None:
+            mask_token = mask_token.repeat(feat.shape[0], 1, feat.shape[2], feat.shape[3])
+            feat += mask_token * mae_mask
+        return feat
+
+
+    def forward(self, data):
+        """
+        Update:
+            data (dict): {
+                'image0': (torch.Tensor): (N, 1, H, W)
+                'image1': (torch.Tensor): (N, 1, H, W)
+                'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
+                'mask1'(optional) : (torch.Tensor): (N, H, W)
+            }
+        """
+        # 1. Local Feature CNN
+        data.update({
+            'bs': data['image0'].size(0),
+            'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+        })
+
+        image0 = data["image0_norm"] if "image0_norm" in data else data["image0"]
+        image1 = data["image1_norm"] if "image1_norm" in data else data["image1"]
+
+        mask0 = mask1 = None  # mask for padded images
+        if 'mask0' in data:
+            mask0, mask1 = data['mask0'], data['mask1']
+
+        # mask input images
+        image0 = self.mask_layer(image0,
+                                 data["mae_mask0"],
+                                 mae_mask_scale=self.patch_size,
+                                 mask=mask0,
+                                 mask_scale=8)
+        image1 = self.mask_layer(image1,
+                                 data["mae_mask1"],
+                                 mae_mask_scale=self.patch_size,
+                                 mask=mask1,
+                                 mask_scale=8)
+        data.update({"masked_image0":image0.clone().detach().cpu(),
+                     "masked_image1":image1.clone().detach().cpu()})
+
+        if data['hw0_i'] == data['hw1_i']:  # faster & better BN convergence
+            feats_c, feats_m, feats_f = self.backbone(torch.cat([image0, image1], dim=0))
+            (feat_c0, feat_c1) = feats_c.split(data['bs'])
+            (feat_m0, feat_m1) = feats_m.split(data['bs'])
+            (feat_f0, feat_f1) = feats_f.split(data['bs'])
+        else:  # handle different input shapes
+            feat_c0, feat_m0, feat_f0 = self.backbone(image0)
+            feat_c1, feat_m1, feat_f1 = self.backbone(image1)
+
+        # mask output layers of backbone and replace with trainable token
+        feat_c0 = self.mask_layer(feat_c0,
+                                  data["mae_mask0"],
+                                  mae_mask_scale=self.patch_size // 8,
+                                  mask=mask0,
+                                  mask_scale=1,
+                                  mask_token=self.mask_token_c)
+        feat_c1 = self.mask_layer(feat_c1,
+                                  data["mae_mask1"],
+                                  mae_mask_scale=self.patch_size // 8,
+                                  mask=mask1,
+                                  mask_scale=1,
+                                  mask_token=self.mask_token_c)
+        feat_m0 = self.mask_layer(feat_m0,
+                                  data["mae_mask0"],
+                                  mae_mask_scale=self.patch_size // 4,
+                                  mask=mask0,
+                                  mask_scale=2,
+                                  mask_token=self.mask_token_m)
+        feat_m1 = self.mask_layer(feat_m1,
+                                  data["mae_mask1"],
+                                  mae_mask_scale=self.patch_size // 4,
+                                  mask=mask1,
+                                  mask_scale=2,
+                                  mask_token=self.mask_token_m)
+        feat_f0 = self.mask_layer(feat_f0,
+                                  data["mae_mask0"],
+                                  mae_mask_scale=self.patch_size // 2,
+                                  mask=mask0,
+                                  mask_scale=4,
+                                  mask_token=self.mask_token_f)
+        feat_f1 = self.mask_layer(feat_f1,
+                                  data["mae_mask1"],
+                                  mae_mask_scale=self.patch_size // 2,
+                                  mask=mask1,
+                                  mask_scale=4,
+                                  mask_token=self.mask_token_f)
+        data.update({
+            'hw0_c': feat_c0.shape[2:], 
'hw1_c': feat_c1.shape[2:], + 'hw0_m': feat_m0.shape[2:], 'hw1_m': feat_m1.shape[2:], + 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] + }) + + # save coarse features for fine matching module + feat_c0_pre, feat_c1_pre = feat_c0.clone(), feat_c1.clone() + + # 2. Coarse-level loftr module + # add featmap with positional encoding, then flatten it to sequence [N, HW, C] + feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c') + feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c') + + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) + feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) + + # 3. Fine-level maching module as decoder + # generate window locations from mae mask to reconstruct + mae_mask_c0 = self.upsample_mae_mask( data["mae_mask0"], + self.patch_size // 8) + if mask0 is not None: + mae_mask_c0 = mae_mask_c0 * mask0.type_as(mae_mask_c0) + + mae_mask_c1 = self.upsample_mae_mask( data["mae_mask1"], + self.patch_size // 8) + if mask1 is not None: + mae_mask_c1 = mae_mask_c1 * mask1.type_as(mae_mask_c1) + + mae_mask_c = torch.logical_or(mae_mask_c0, mae_mask_c1) + + b_ids, i_ids = mae_mask_c.flatten(1).nonzero(as_tuple=True) + j_ids = i_ids + + # b_ids, i_ids and j_ids are masked location for both images + # ids_image0 and ids_image1 determines which indeces belogs to which image + ids_image0 = mae_mask_c0.flatten(1)[b_ids, i_ids] + ids_image1 = mae_mask_c1.flatten(1)[b_ids, j_ids] + + data.update({'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids, + 'ids_image0': ids_image0==1, 'ids_image1': ids_image1==1}) + + + # fine level matching module + feat_f0_unfold, feat_f1_unfold = self.fine_process( feat_f0, feat_f1, + feat_m0, feat_m1, + feat_c0, feat_c1, + feat_c0_pre, feat_c1_pre, + data) + + # output projection 5x5 window to 10x10 window + pred0 = self.out_proj(feat_f0_unfold) + pred1 = self.out_proj(feat_f1_unfold) + + data.update({"pred0":pred0, "pred1": pred1}) + + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) diff --git a/imcui/third_party/XoFTR/test.py b/imcui/third_party/XoFTR/test.py new file mode 100644 index 0000000000000000000000000000000000000000..133a146d92f5d3ebc59fc1fbad6551aab4b3f5b4 --- /dev/null +++ b/imcui/third_party/XoFTR/test.py @@ -0,0 +1,68 @@ +import pytorch_lightning as pl +import argparse +import pprint +from loguru import logger as loguru_logger + +from src.config.default import get_cfg_defaults +from src.utils.profiler import build_profiler + +from src.lightning.data import MultiSceneDataModule +from src.lightning.lightning_xoftr import PL_XoFTR + + +def parse_args(): + # init a costum parser which will be added into pl.Trainer parser + # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'data_cfg_path', type=str, help='data config path') + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint') + parser.add_argument( + '--dump_dir', type=str, default=None, help="if set, the matching results will 
be dump to dump_dir") + parser.add_argument( + '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset') + parser.add_argument( + '--batch_size', type=int, default=1, help='batch_size per gpu') + parser.add_argument( + '--num_workers', type=int, default=2) + parser.add_argument( + '--thr', type=float, default=None, help='modify the coarse-level matching threshold.') + + parser = pl.Trainer.add_argparse_args(parser) + return parser.parse_args() + + +if __name__ == '__main__': + # parse arguments + args = parse_args() + pprint.pprint(vars(args)) + + # init default-cfg and merge it with the main- and data-cfg + config = get_cfg_defaults() + config.merge_from_file(args.main_cfg_path) + config.merge_from_file(args.data_cfg_path) + pl.seed_everything(config.TRAINER.SEED) # reproducibility + + # tune when testing + if args.thr is not None: + config.XoFTR.MATCH_COARSE.THR = args.thr + + loguru_logger.info(f"Args and config initialized!") + + # lightning module + profiler = build_profiler(args.profiler_name) + model = PL_XoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir) + loguru_logger.info(f"XoFTR-lightning initialized!") + + # lightning data + data_module = MultiSceneDataModule(args, config) + loguru_logger.info(f"DataModule initialized!") + + # lightning trainer + trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False) + + loguru_logger.info(f"Start testing!") + trainer.test(model, datamodule=data_module, verbose=False) diff --git a/imcui/third_party/XoFTR/test_relative_pose.py b/imcui/third_party/XoFTR/test_relative_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb8c7ce58d46623ff9e06b784ac846355b537b9 --- /dev/null +++ b/imcui/third_party/XoFTR/test_relative_pose.py @@ -0,0 +1,330 @@ +from collections import defaultdict, OrderedDict +import os +import os.path as osp +import numpy as np +from tqdm import tqdm +import argparse +import cv2 +from pathlib import Path +import warnings +import json +import time + +from src.utils.metrics import estimate_pose, relative_pose_error, error_auc, symmetric_epipolar_distance_numpy +from src.utils.plotting import dynamic_alpha, error_colormap, make_matching_figure + + +# Loading functions for methods +#################################################################### +def load_xoftr(args): + from src.xoftr import XoFTR + from src.config.default import get_cfg_defaults + from src.utils.data_io import DataIOWrapper, lower_config + config = get_cfg_defaults(inference=True) + config = lower_config(config) + config["xoftr"]["match_coarse"]["thr"] = args.match_threshold + config["xoftr"]["fine"]["thr"] = args.fine_threshold + ckpt = args.ckpt + matcher = XoFTR(config=config["xoftr"]) + matcher = DataIOWrapper(matcher, config=config["test"], ckpt=ckpt) + return matcher.from_paths + +#################################################################### + +def load_vis_tir_pairs_npz(npz_root, npz_list): + """Load information for scene and image pairs from npz files. 
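+    The result maps each npz file name to a list of pair dicts holding the image
+    paths ('im0', 'im1'), intrinsics ('K0', 'K1'), distortion coefficients
+    ('dist0', 'dist1') and the relative pose 'T_0to1'.
+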
+ Args: + npz_root: Directory path for npz files + npz_list: File containing the names of the npz files to be used + """ + with open(npz_list, 'r') as f: + npz_names = [name.split()[0] for name in f.readlines()] + print(f"Parse {len(npz_names)} npz from {npz_list}.") + + total_pairs = 0 + scene_pairs = {} + + for name in npz_names: + print(f"Loading {name}") + scene_info = np.load(f"{npz_root}/{name}", allow_pickle=True) + pairs = [] + + # Collect pairs + for pair_info in scene_info['pair_infos']: + total_pairs += 1 + (id0, id1) = pair_info + im0 = scene_info['image_paths'][id0][0] + im1 = scene_info['image_paths'][id1][1] + K0 = scene_info['intrinsics'][id0][0].astype(np.float32) + K1 = scene_info['intrinsics'][id1][1].astype(np.float32) + + dist0 = np.array(scene_info['distortion_coefs'][id0][0], dtype=float) + dist1 = np.array(scene_info['distortion_coefs'][id1][1], dtype=float) + # Compute relative pose + T0 = scene_info['poses'][id0] + T1 = scene_info['poses'][id1] + T_0to1 = np.matmul(T1, np.linalg.inv(T0)) + pairs.append({'im0':im0, 'im1':im1, 'dist0':dist0, 'dist1':dist1, + 'K0':K0, 'K1':K1, 'T_0to1':T_0to1}) + scene_pairs[name] = pairs + + print(f"Loaded {total_pairs} pairs.") + return scene_pairs + + + +def save_matching_figure(path, img0, img1, mkpts0, mkpts1, inlier_mask, T_0to1, K0, K1, t_err=None, R_err=None, name=None, conf_thr = 5e-4): + """ Make and save matching figures + """ + Tx = np.cross(np.eye(3), T_0to1[:3, 3]) + E_mat = Tx @ T_0to1[:3, :3] + mkpts0_inliers = mkpts0[inlier_mask] + mkpts1_inliers = mkpts1[inlier_mask] + if inlier_mask is not None and len(inlier_mask) != 0: + epi_errs = symmetric_epipolar_distance_numpy(mkpts0_inliers, mkpts1_inliers, E_mat, K0, K1) + + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + + # matching info + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + text_precision =[ + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(mkpts0_inliers)}'] + else: + text_precision =[ + f'No inliers after ransac'] + + if name is not None: + text=[name] + else: + text = [] + + if t_err is not None and R_err is not None: + error_text = [f"err_t: {t_err:.2f} °", f"err_R: {R_err:.2f} °"] + text +=error_text + + text += text_precision + + # make the figure + figure = make_matching_figure(img0, img1, mkpts0_inliers, mkpts1_inliers, + color, text=text, path=path, dpi=150) + + +def aggregiate_scenes(scene_pose_auc, thresholds): + """Averages the auc results for cloudy_cloud and cloudy_sunny scenes + """ + temp_pose_auc = {} + for npz_name in scene_pose_auc.keys(): + scene_name = npz_name.split("_scene")[0] + temp_pose_auc[scene_name] = [np.zeros(len(thresholds), dtype=np.float32), 0] # [sum, total_number] + for npz_name in scene_pose_auc.keys(): + scene_name = npz_name.split("_scene")[0] + temp_pose_auc[scene_name][0] += scene_pose_auc[npz_name] + temp_pose_auc[scene_name][1] += 1 + + agg_pose_auc = {} + for scene_name in temp_pose_auc.keys(): + agg_pose_auc[scene_name] = temp_pose_auc[scene_name][0] / temp_pose_auc[scene_name][1] + + return agg_pose_auc + +def eval_relapose( + matcher, + data_root, + scene_pairs, + ransac_thres, + thresholds, + save_figs, + figures_dir=None, + method=None, + print_out=False, + debug=False, +): + scene_pose_auc = {} + for scene_name in scene_pairs.keys(): + scene_dir = osp.join(figures_dir, scene_name.split(".")[0]) + if save_figs and not osp.exists(scene_dir): + 
os.makedirs(scene_dir) + + pairs = scene_pairs[scene_name] + statis = defaultdict(list) + np.set_printoptions(precision=2) + + # Eval on pairs + print(f"\nStart evaluation on VisTir \n") + for i, pair in tqdm(enumerate(pairs), smoothing=.1, total=len(pairs)): + if debug and i > 10: + break + + T_0to1 = pair['T_0to1'] + im0 = str(data_root / pair['im0']) + im1 = str(data_root / pair['im1']) + match_res = matcher(im0, im1, pair['K0'], pair['K1'], pair['dist0'], pair['dist1']) + matches = match_res['matches'] + new_K0 = match_res['new_K0'] + new_K1 = match_res['new_K1'] + mkpts0 = match_res['mkpts0'] + mkpts1 = match_res['mkpts1'] + + # Calculate pose errors + ret = estimate_pose( + mkpts0, mkpts1, new_K0, new_K1, thresh=ransac_thres + ) + + if ret is None: + R, t, inliers = None, None, None + t_err, R_err = np.inf, np.inf + statis['failed'].append(i) + statis['R_errs'].append(R_err) + statis['t_errs'].append(t_err) + statis['inliers'].append(np.array([]).astype(np.bool_)) + else: + R, t, inliers = ret + t_err, R_err = relative_pose_error(T_0to1, R, t) + statis['R_errs'].append(R_err) + statis['t_errs'].append(t_err) + statis['inliers'].append(inliers.sum() / len(mkpts0)) + if print_out: + print(f"#M={len(matches)} R={R_err:.3f}, t={t_err:.3f}") + + if save_figs: + img0_name = f"{'vis' if 'visible' in pair['im0'] else 'tir'}_{osp.basename(pair['im0']).split('.')[0]}" + img1_name = f"{'vis' if 'visible' in pair['im1'] else 'tir'}_{osp.basename(pair['im1']).split('.')[0]}" + fig_path = osp.join(scene_dir, f"{img0_name}_{img1_name}.jpg") + save_matching_figure(path=fig_path, + img0=match_res['img0_undistorted'] if 'img0_undistorted' in match_res.keys() else match_res['img0'], + img1=match_res['img1_undistorted'] if 'img1_undistorted' in match_res.keys() else match_res['img1'], + mkpts0=mkpts0, + mkpts1=mkpts1, + inlier_mask=inliers, + T_0to1=T_0to1, + K0=new_K0, + K1=new_K1, + t_err=t_err, + R_err=R_err, + name=method + ) + + print(f"Scene: {scene_name} Total samples: {len(pairs)} Failed:{len(statis['failed'])}. 
\n") + pose_errors = np.max(np.stack([statis['R_errs'], statis['t_errs']]), axis=0) + pose_auc = error_auc(pose_errors, thresholds) # (auc@5, auc@10, auc@20) + scene_pose_auc[scene_name] = 100 * np.array([pose_auc[f'auc@{t}'] for t in thresholds]) + print(f"{scene_name} {pose_auc}") + agg_pose_auc = aggregiate_scenes(scene_pose_auc, thresholds) + return scene_pose_auc, agg_pose_auc + +def test_relative_pose_vistir( + data_root_dir, + method="xoftr", + exp_name = "VisTIR", + ransac_thres=1.5, + print_out=False, + save_dir=None, + save_figs=False, + debug=False, + args=None + +): + if not osp.exists(osp.join(save_dir, method)): + os.makedirs(osp.join(save_dir, method)) + + counter = 0 + path = osp.join(save_dir, method, f"{exp_name}"+"_{}") + while osp.exists(path.format(counter)): + counter += 1 + exp_dir = path.format(counter) + os.mkdir(exp_dir) + results_file = osp.join(exp_dir, "results.json") + figures_dir = osp.join(exp_dir, "match_figures") + if save_figs: + os.mkdir(figures_dir) + + # Init paths + npz_root = data_root_dir / 'index/scene_info_test/' + npz_list = data_root_dir / 'index/val_test_list/test_list.txt' + data_root = data_root_dir + + # Load pairs + scene_pairs = load_vis_tir_pairs_npz(npz_root, npz_list) + + # Load method + matcher = eval(f"load_{method}")(args) + + thresholds=[5, 10, 20] + # Eval + scene_pose_auc, agg_pose_auc = eval_relapose( + matcher, + data_root, + scene_pairs, + ransac_thres=ransac_thres, + thresholds=thresholds, + save_figs=save_figs, + figures_dir=figures_dir, + method=method, + print_out=print_out, + debug=debug, + ) + + # Create result dict + results = OrderedDict({"method": method, + "exp_name": exp_name, + "ransac_thres": ransac_thres, + "auc_thresholds": thresholds}) + results.update({key:value for key, value in vars(args).items() if key not in results}) + results.update({key:value.tolist() for key, value in agg_pose_auc.items()}) + results.update({key:value.tolist() for key, value in scene_pose_auc.items()}) + + print(f"Results: {json.dumps(results, indent=4)}") + + # Save to json file + with open(results_file, 'w') as outfile: + json.dump(results, outfile, indent=4) + + print(f"Results saved to {results_file}") + +if __name__ == '__main__': + + def add_common_arguments(parser): + parser.add_argument('--gpu', '-gpu', type=str, default='0') + parser.add_argument('--exp_name', type=str, default="VisTIR") + parser.add_argument('--data_root_dir', type=str, default="./data/METU_VisTIR/") + parser.add_argument('--save_dir', type=str, default="./results_relative_pose") + parser.add_argument('--ransac_thres', type=float, default=1.5) + parser.add_argument('--print_out', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--save_figs', action='store_true') + + def add_xoftr_arguments(subparsers): + subcommand = subparsers.add_parser('xoftr') + subcommand.add_argument('--match_threshold', type=float, default=0.3) + subcommand.add_argument('--fine_threshold', type=float, default=0.1) + subcommand.add_argument('--ckpt', type=str, default="./weights/weights_xoftr_640.ckpt") + add_common_arguments(subcommand) + + parser = argparse.ArgumentParser(description='Benchmark Relative Pose') + add_common_arguments(parser) + + # Create subparsers for top-level commands + subparsers = parser.add_subparsers(dest="method") + add_xoftr_arguments(subparsers) + + args = parser.parse_args() + + os.environ['CUDA_VISIBLE_DEVICES'] = "0" + tt = time.time() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + 
test_relative_pose_vistir( + Path(args.data_root_dir), + args.method, + args.exp_name, + ransac_thres=args.ransac_thres, + print_out=args.print_out, + save_dir = args.save_dir, + save_figs = args.save_figs, + debug=args.debug, + args=args + ) + print(f"Elapsed time: {time.time() - tt}") diff --git a/imcui/third_party/XoFTR/train.py b/imcui/third_party/XoFTR/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4fdb26d3857f276691dce65be6ad817e36db7414 --- /dev/null +++ b/imcui/third_party/XoFTR/train.py @@ -0,0 +1,126 @@ +import math +import argparse +import pprint +from distutils.util import strtobool +from pathlib import Path +from loguru import logger as loguru_logger +from datetime import datetime + +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.plugins import DDPPlugin + +from src.config.default import get_cfg_defaults +from src.utils.misc import get_rank_zero_only_logger, setup_gpus +from src.utils.profiler import build_profiler +from src.lightning.data import MultiSceneDataModule +from src.lightning.lightning_xoftr import PL_XoFTR + +loguru_logger = get_rank_zero_only_logger(loguru_logger) + + +def parse_args(): + # init a costum parser which will be added into pl.Trainer parser + # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'data_cfg_path', type=str, help='data config path') + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--exp_name', type=str, default='default_exp_name') + parser.add_argument( + '--batch_size', type=int, default=4, help='batch_size per gpu') + parser.add_argument( + '--num_workers', type=int, default=4) + parser.add_argument( + '--pin_memory', type=lambda x: bool(strtobool(x)), + nargs='?', default=True, help='whether loading data to pinned memory or not') + parser.add_argument( + '--ckpt_path', type=str, default=None, + help='pretrained checkpoint path, helpful for using a pre-trained coarse-only XoFTR') + parser.add_argument( + '--disable_ckpt', action='store_true', + help='disable checkpoint saving (useful for debugging).') + parser.add_argument( + '--profiler_name', type=str, default=None, + help='options: [inference, pytorch], or leave it unset') + parser.add_argument( + '--parallel_load_data', action='store_true', + help='load datasets in with multiple processes.') + + parser = pl.Trainer.add_argparse_args(parser) + return parser.parse_args() + + +def main(): + # parse arguments + args = parse_args() + rank_zero_only(pprint.pprint)(vars(args)) + + # init default-cfg and merge it with the main- and data-cfg + config = get_cfg_defaults() + config.merge_from_file(args.main_cfg_path) + config.merge_from_file(args.data_cfg_path) + pl.seed_everything(config.TRAINER.SEED) # reproducibility + + # scale lr and warmup-step automatically + args.gpus = _n_gpus = setup_gpus(args.gpus) + config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes + config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size + _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS + config.TRAINER.SCALING = _scaling + config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling + 
config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling) + + + # lightning module + profiler = build_profiler(args.profiler_name) + model = PL_XoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler) + loguru_logger.info(f"XoFTR LightningModule initialized!") + + # lightning data + data_module = MultiSceneDataModule(args, config) + loguru_logger.info(f"XoFTR DataModule initialized!") + + # TensorBoard Logger + logger = [TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)] + ckpt_dir = Path(logger[0].log_dir) / 'checkpoints' + if config.TRAINER.USE_WANDB: + logger.append(WandbLogger(name=args.exp_name + f"_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}", + project='XoFTR')) + + # Callbacks + # TODO: update ModelCheckpoint to monitor multiple metrics + ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=-1, mode='max', + save_last=True, + dirpath=str(ckpt_dir), + filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}') + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks = [lr_monitor] + if not args.disable_ckpt: + callbacks.append(ckpt_callback) + + # Lightning Trainer + trainer = pl.Trainer.from_argparse_args( + args, + plugins=DDPPlugin(find_unused_parameters=True, + num_nodes=args.num_nodes, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), + gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING, + callbacks=callbacks, + logger=logger, + sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, + replace_sampler_ddp=False, # use custom sampler + reload_dataloaders_every_epoch=False, # avoid repeated samples! + weights_summary='full', + profiler=profiler) + loguru_logger.info(f"Trainer initialized!") + loguru_logger.info(f"Start training!") + trainer.fit(model, datamodule=data_module) + + +if __name__ == '__main__': + main() diff --git a/imcui/third_party/d2net/extract_features.py b/imcui/third_party/d2net/extract_features.py new file mode 100644 index 0000000000000000000000000000000000000000..628463a7d042a90b5cadea8a317237cde86f5ae4 --- /dev/null +++ b/imcui/third_party/d2net/extract_features.py @@ -0,0 +1,156 @@ +import argparse + +import numpy as np + +import imageio + +import torch + +from tqdm import tqdm + +import scipy +import scipy.io +import scipy.misc + +from lib.model_test import D2Net +from lib.utils import preprocess_image +from lib.pyramid import process_multiscale + +# CUDA +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if use_cuda else "cpu") + +# Argument parsing +parser = argparse.ArgumentParser(description='Feature extraction script') + +parser.add_argument( + '--image_list_file', type=str, required=True, + help='path to a file containing a list of images to process' +) + +parser.add_argument( + '--preprocessing', type=str, default='caffe', + help='image preprocessing (caffe or torch)' +) +parser.add_argument( + '--model_file', type=str, default='models/d2_tf.pth', + help='path to the full model' +) + +parser.add_argument( + '--max_edge', type=int, default=1600, + help='maximum image size at network input' +) +parser.add_argument( + '--max_sum_edges', type=int, default=2800, + help='maximum sum of image sizes at network input' +) + +parser.add_argument( + '--output_extension', type=str, default='.d2-net', + help='extension for the output' +) +parser.add_argument( + '--output_type', type=str, default='npz', + help='output file type (npz or mat)' +) + +parser.add_argument( + '--multiscale', dest='multiscale', action='store_true', + help='extract multiscale 
features' +) +parser.set_defaults(multiscale=False) + +parser.add_argument( + '--no-relu', dest='use_relu', action='store_false', + help='remove ReLU after the dense feature extraction module' +) +parser.set_defaults(use_relu=True) + +args = parser.parse_args() + +print(args) + +# Creating CNN model +model = D2Net( + model_file=args.model_file, + use_relu=args.use_relu, + use_cuda=use_cuda +) + +# Process the file +with open(args.image_list_file, 'r') as f: + lines = f.readlines() +for line in tqdm(lines, total=len(lines)): + path = line.strip() + + image = imageio.imread(path) + if len(image.shape) == 2: + image = image[:, :, np.newaxis] + image = np.repeat(image, 3, -1) + + # TODO: switch to PIL.Image due to deprecation of scipy.misc.imresize. + resized_image = image + if max(resized_image.shape) > args.max_edge: + resized_image = scipy.misc.imresize( + resized_image, + args.max_edge / max(resized_image.shape) + ).astype('float') + if sum(resized_image.shape[: 2]) > args.max_sum_edges: + resized_image = scipy.misc.imresize( + resized_image, + args.max_sum_edges / sum(resized_image.shape[: 2]) + ).astype('float') + + fact_i = image.shape[0] / resized_image.shape[0] + fact_j = image.shape[1] / resized_image.shape[1] + + input_image = preprocess_image( + resized_image, + preprocessing=args.preprocessing + ) + with torch.no_grad(): + if args.multiscale: + keypoints, scores, descriptors = process_multiscale( + torch.tensor( + input_image[np.newaxis, :, :, :].astype(np.float32), + device=device + ), + model + ) + else: + keypoints, scores, descriptors = process_multiscale( + torch.tensor( + input_image[np.newaxis, :, :, :].astype(np.float32), + device=device + ), + model, + scales=[1] + ) + + # Input image coordinates + keypoints[:, 0] *= fact_i + keypoints[:, 1] *= fact_j + # i, j -> u, v + keypoints = keypoints[:, [1, 0, 2]] + + if args.output_type == 'npz': + with open(path + args.output_extension, 'wb') as output_file: + np.savez( + output_file, + keypoints=keypoints, + scores=scores, + descriptors=descriptors + ) + elif args.output_type == 'mat': + with open(path + args.output_extension, 'wb') as output_file: + scipy.io.savemat( + output_file, + { + 'keypoints': keypoints, + 'scores': scores, + 'descriptors': descriptors + } + ) + else: + raise ValueError('Unknown output type.') diff --git a/imcui/third_party/d2net/extract_kapture.py b/imcui/third_party/d2net/extract_kapture.py new file mode 100644 index 0000000000000000000000000000000000000000..23198b978229c699dbe24cd3bc0400d62bcab030 --- /dev/null +++ b/imcui/third_party/d2net/extract_kapture.py @@ -0,0 +1,248 @@ +import argparse +import numpy as np +from PIL import Image +import torch +import math +from tqdm import tqdm +from os import path + +# Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure From Motion) and more generally sensor-acquired data +# it can be installed with +# pip install kapture +# for more information check out https://github.com/naver/kapture +import kapture +from kapture.io.records import get_image_fullpath +from kapture.io.csv import kapture_from_dir, get_all_tar_handlers +from kapture.io.csv import get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file +from kapture.io.features import get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file +from kapture.io.features import get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file + +from lib.model_test import D2Net +from lib.utils import preprocess_image +from lib.pyramid import 
process_multiscale + +# import imageio + +# CUDA +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if use_cuda else "cpu") + +# Argument parsing +parser = argparse.ArgumentParser(description='Feature extraction script') + +parser.add_argument( + '--kapture-root', type=str, required=True, + help='path to kapture root directory' +) + +parser.add_argument( + '--preprocessing', type=str, default='caffe', + help='image preprocessing (caffe or torch)' +) +parser.add_argument( + '--model_file', type=str, default='models/d2_tf.pth', + help='path to the full model' +) +parser.add_argument( + '--keypoints-type', type=str, default=None, + help='keypoint type_name, default is filename of model' +) +parser.add_argument( + '--descriptors-type', type=str, default=None, + help='descriptors type_name, default is filename of model' +) + +parser.add_argument( + '--max_edge', type=int, default=1600, + help='maximum image size at network input' +) +parser.add_argument( + '--max_sum_edges', type=int, default=2800, + help='maximum sum of image sizes at network input' +) + +parser.add_argument( + '--multiscale', dest='multiscale', action='store_true', + help='extract multiscale features' +) +parser.set_defaults(multiscale=False) + +parser.add_argument( + '--no-relu', dest='use_relu', action='store_false', + help='remove ReLU after the dense feature extraction module' +) +parser.set_defaults(use_relu=True) + +parser.add_argument("--max-keypoints", type=int, default=float("+inf"), + help='max number of keypoints save to disk') + +args = parser.parse_args() + +print(args) +with get_all_tar_handlers(args.kapture_root, + mode={kapture.Keypoints: 'a', + kapture.Descriptors: 'a', + kapture.GlobalFeatures: 'r', + kapture.Matches: 'r'}) as tar_handlers: + kdata = kapture_from_dir(args.kapture_root, + skip_list=[kapture.GlobalFeatures, + kapture.Matches, + kapture.Points3d, + kapture.Observations], + tar_handlers=tar_handlers) + if kdata.keypoints is None: + kdata.keypoints = {} + if kdata.descriptors is None: + kdata.descriptors = {} + + assert kdata.records_camera is not None + image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)] + if args.keypoints_type is None: + args.keypoints_type = path.splitext(path.basename(args.model_file))[0] + print(f'keypoints_type set to {args.keypoints_type}') + if args.descriptors_type is None: + args.descriptors_type = path.splitext(path.basename(args.model_file))[0] + print(f'descriptors_type set to {args.descriptors_type}') + if args.keypoints_type in kdata.keypoints and args.descriptors_type in kdata.descriptors: + image_list = [name + for name in image_list + if name not in kdata.keypoints[args.keypoints_type] or + name not in kdata.descriptors[args.descriptors_type]] + + if len(image_list) == 0: + print('All features were already extracted') + exit(0) + else: + print(f'Extracting d2net features for {len(image_list)} images') + + # Creating CNN model + model = D2Net( + model_file=args.model_file, + use_relu=args.use_relu, + use_cuda=use_cuda + ) + + if args.keypoints_type not in kdata.keypoints: + keypoints_dtype = None + keypoints_dsize = None + else: + keypoints_dtype = kdata.keypoints[args.keypoints_type].dtype + keypoints_dsize = kdata.keypoints[args.keypoints_type].dsize + if args.descriptors_type not in kdata.descriptors: + descriptors_dtype = None + descriptors_dsize = None + else: + descriptors_dtype = kdata.descriptors[args.descriptors_type].dtype + descriptors_dsize = kdata.descriptors[args.descriptors_type].dsize + + # 
Process the files + for image_name in tqdm(image_list, total=len(image_list)): + img_path = get_image_fullpath(args.kapture_root, image_name) + image = Image.open(img_path).convert('RGB') + + width, height = image.size + + resized_image = image + resized_width = width + resized_height = height + + max_edge = args.max_edge + max_sum_edges = args.max_sum_edges + if max(resized_width, resized_height) > max_edge: + scale_multiplier = max_edge / max(resized_width, resized_height) + resized_width = math.floor(resized_width * scale_multiplier) + resized_height = math.floor(resized_height * scale_multiplier) + resized_image = image.resize((resized_width, resized_height)) + if resized_width + resized_height > max_sum_edges: + scale_multiplier = max_sum_edges / (resized_width + resized_height) + resized_width = math.floor(resized_width * scale_multiplier) + resized_height = math.floor(resized_height * scale_multiplier) + resized_image = image.resize((resized_width, resized_height)) + + fact_i = width / resized_width + fact_j = height / resized_height + + resized_image = np.array(resized_image).astype('float') + + input_image = preprocess_image( + resized_image, + preprocessing=args.preprocessing + ) + + with torch.no_grad(): + if args.multiscale: + keypoints, scores, descriptors = process_multiscale( + torch.tensor( + input_image[np.newaxis, :, :, :].astype(np.float32), + device=device + ), + model + ) + else: + keypoints, scores, descriptors = process_multiscale( + torch.tensor( + input_image[np.newaxis, :, :, :].astype(np.float32), + device=device + ), + model, + scales=[1] + ) + + # Input image coordinates + keypoints[:, 0] *= fact_i + keypoints[:, 1] *= fact_j + # i, j -> u, v + keypoints = keypoints[:, [1, 0, 2]] + + if args.max_keypoints != float("+inf"): + # keep the last (the highest) indexes + idx_keep = scores.argsort()[-min(len(keypoints), args.max_keypoints):] + keypoints = keypoints[idx_keep] + descriptors = descriptors[idx_keep] + + if keypoints_dtype is None or descriptors_dtype is None: + keypoints_dtype = keypoints.dtype + descriptors_dtype = descriptors.dtype + + keypoints_dsize = keypoints.shape[1] + descriptors_dsize = descriptors.shape[1] + + kdata.keypoints[args.keypoints_type] = kapture.Keypoints('d2net', keypoints_dtype, keypoints_dsize) + kdata.descriptors[args.descriptors_type] = kapture.Descriptors('d2net', descriptors_dtype, + descriptors_dsize, + args.keypoints_type, 'L2') + + keypoints_config_absolute_path = get_feature_csv_fullpath(kapture.Keypoints, + args.keypoints_type, + args.kapture_root) + descriptors_config_absolute_path = get_feature_csv_fullpath(kapture.Descriptors, + args.descriptors_type, + args.kapture_root) + + keypoints_to_file(keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type]) + descriptors_to_file(descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type]) + else: + assert kdata.keypoints[args.keypoints_type].dtype == keypoints.dtype + assert kdata.descriptors[args.descriptors_type].dtype == descriptors.dtype + assert kdata.keypoints[args.keypoints_type].dsize == keypoints.shape[1] + assert kdata.descriptors[args.descriptors_type].dsize == descriptors.shape[1] + assert kdata.descriptors[args.descriptors_type].keypoints_type == args.keypoints_type + assert kdata.descriptors[args.descriptors_type].metric_type == 'L2' + + keypoints_fullpath = get_keypoints_fullpath(args.keypoints_type, args.kapture_root, + image_name, tar_handlers) + print(f"Saving {keypoints.shape[0]} keypoints to {keypoints_fullpath}") + 
image_keypoints_to_file(keypoints_fullpath, keypoints) + kdata.keypoints[args.keypoints_type].add(image_name) + + descriptors_fullpath = get_descriptors_fullpath(args.descriptors_type, args.kapture_root, + image_name, tar_handlers) + print(f"Saving {descriptors.shape[0]} descriptors to {descriptors_fullpath}") + image_descriptors_to_file(descriptors_fullpath, descriptors) + kdata.descriptors[args.descriptors_type].add(image_name) + + if not keypoints_check_dir(kdata.keypoints[args.keypoints_type], args.keypoints_type, + args.kapture_root, tar_handlers) or \ + not descriptors_check_dir(kdata.descriptors[args.descriptors_type], args.descriptors_type, + args.kapture_root, tar_handlers): + print('local feature extraction ended successfully but not all files were saved') diff --git a/imcui/third_party/d2net/megadepth_utils/preprocess_scene.py b/imcui/third_party/d2net/megadepth_utils/preprocess_scene.py new file mode 100644 index 0000000000000000000000000000000000000000..fc68a403795e7cddce88dfcb74b38d19ab09e133 --- /dev/null +++ b/imcui/third_party/d2net/megadepth_utils/preprocess_scene.py @@ -0,0 +1,242 @@ +import argparse + +import imagesize + +import numpy as np + +import os + +parser = argparse.ArgumentParser(description='MegaDepth preprocessing script') + +parser.add_argument( + '--base_path', type=str, required=True, + help='path to MegaDepth' +) +parser.add_argument( + '--scene_id', type=str, required=True, + help='scene ID' +) + +parser.add_argument( + '--output_path', type=str, required=True, + help='path to the output directory' +) + +args = parser.parse_args() + +base_path = args.base_path +# Remove the trailing / if need be. +if base_path[-1] in ['/', '\\']: + base_path = base_path[: - 1] +scene_id = args.scene_id + +base_depth_path = os.path.join( + base_path, 'phoenix/S6/zl548/MegaDepth_v1' +) +base_undistorted_sfm_path = os.path.join( + base_path, 'Undistorted_SfM' +) + +undistorted_sparse_path = os.path.join( + base_undistorted_sfm_path, scene_id, 'sparse-txt' +) +if not os.path.exists(undistorted_sparse_path): + exit() + +depths_path = os.path.join( + base_depth_path, scene_id, 'dense0', 'depths' +) +if not os.path.exists(depths_path): + exit() + +images_path = os.path.join( + base_undistorted_sfm_path, scene_id, 'images' +) +if not os.path.exists(images_path): + exit() + +# Process cameras.txt +with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f: + raw = f.readlines()[3 :] # skip the header + +camera_intrinsics = {} +for camera in raw: + camera = camera.split(' ') + camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]] + +# Process points3D.txt +with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f: + raw = f.readlines()[3 :] # skip the header + +points3D = {} +for point3D in raw: + point3D = point3D.split(' ') + points3D[int(point3D[0])] = np.array([ + float(point3D[1]), float(point3D[2]), float(point3D[3]) + ]) + +# Process images.txt +with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f: + raw = f.readlines()[4 :] # skip the header + +image_id_to_idx = {} +image_names = [] +raw_pose = [] +camera = [] +points3D_id_to_2D = [] +n_points3D = [] +for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])): + image = image.split(' ') + points = points.split(' ') + + image_id_to_idx[int(image[0])] = idx + + image_name = image[-1].strip('\n') + image_names.append(image_name) + + raw_pose.append([float(elem) for elem in image[1 : -2]]) + camera.append(int(image[-2])) + current_points3D_id_to_2D 
= {} + for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]): + if int(point3D_id) == -1: + continue + current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)] + points3D_id_to_2D.append(current_points3D_id_to_2D) + n_points3D.append(len(current_points3D_id_to_2D)) +n_images = len(image_names) + +# Image and depthmaps paths +image_paths = [] +depth_paths = [] +for image_name in image_names: + image_path = os.path.join(images_path, image_name) + + # Path to the depth file + depth_path = os.path.join( + depths_path, '%s.h5' % os.path.splitext(image_name)[0] + ) + + if os.path.exists(depth_path): + # Check if depth map or background / foreground mask + file_size = os.stat(depth_path).st_size + # Rough estimate - 75KB might work as well + if file_size < 100 * 1024: + depth_paths.append(None) + image_paths.append(None) + else: + depth_paths.append(depth_path[len(base_path) + 1 :]) + image_paths.append(image_path[len(base_path) + 1 :]) + else: + depth_paths.append(None) + image_paths.append(None) + +# Camera configuration +intrinsics = [] +poses = [] +principal_axis = [] +points3D_id_to_ndepth = [] +for idx, image_name in enumerate(image_names): + if image_paths[idx] is None: + intrinsics.append(None) + poses.append(None) + principal_axis.append([0, 0, 0]) + points3D_id_to_ndepth.append({}) + continue + image_intrinsics = camera_intrinsics[camera[idx]] + K = np.zeros([3, 3]) + K[0, 0] = image_intrinsics[2] + K[0, 2] = image_intrinsics[4] + K[1, 1] = image_intrinsics[3] + K[1, 2] = image_intrinsics[5] + K[2, 2] = 1 + intrinsics.append(K) + + image_pose = raw_pose[idx] + qvec = image_pose[: 4] + qvec = qvec / np.linalg.norm(qvec) + w, x, y, z = qvec + R = np.array([ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ] + ]) + principal_axis.append(R[2, :]) + t = image_pose[4 : 7] + # World-to-Camera pose + current_pose = np.zeros([4, 4]) + current_pose[: 3, : 3] = R + current_pose[: 3, 3] = t + current_pose[3, 3] = 1 + # Camera-to-World pose + # pose = np.zeros([4, 4]) + # pose[: 3, : 3] = np.transpose(R) + # pose[: 3, 3] = -np.matmul(np.transpose(R), t) + # pose[3, 3] = 1 + poses.append(current_pose) + + current_points3D_id_to_ndepth = {} + for point3D_id in points3D_id_to_2D[idx].keys(): + p3d = points3D[point3D_id] + current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1])) + points3D_id_to_ndepth.append(current_points3D_id_to_ndepth) +principal_axis = np.array(principal_axis) +angles = np.rad2deg(np.arccos( + np.clip( + np.dot(principal_axis, np.transpose(principal_axis)), + -1, 1 + ) +)) + +# Compute overlap score +overlap_matrix = np.full([n_images, n_images], -1.) +scale_ratio_matrix = np.full([n_images, n_images], -1.) 
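+# For each image pair, the overlap score below is the fraction of 3D points seen
+# in one image that are also tracked in the other, and the scale ratio is the
+# smallest per-point normalized-depth ratio (taking the larger of nd1/nd2 and
+# nd2/nd1 for each shared point).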
+for idx1 in range(n_images): + if image_paths[idx1] is None or depth_paths[idx1] is None: + continue + for idx2 in range(idx1 + 1, n_images): + if image_paths[idx2] is None or depth_paths[idx2] is None: + continue + matches = ( + points3D_id_to_2D[idx1].keys() & + points3D_id_to_2D[idx2].keys() + ) + min_num_points3D = min( + len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2]) + ) + overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1]) # min_num_points3D + overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2]) # min_num_points3D + if len(matches) == 0: + continue + points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1] + points3D_id_to_ndepth2 = points3D_id_to_ndepth[idx2] + nd1 = np.array([points3D_id_to_ndepth1[match] for match in matches]) + nd2 = np.array([points3D_id_to_ndepth2[match] for match in matches]) + min_scale_ratio = np.min(np.maximum(nd1 / nd2, nd2 / nd1)) + scale_ratio_matrix[idx1, idx2] = min_scale_ratio + scale_ratio_matrix[idx2, idx1] = min_scale_ratio + +np.savez( + os.path.join(args.output_path, '%s.npz' % scene_id), + image_paths=image_paths, + depth_paths=depth_paths, + intrinsics=intrinsics, + poses=poses, + overlap_matrix=overlap_matrix, + scale_ratio_matrix=scale_ratio_matrix, + angles=angles, + n_points3D=n_points3D, + points3D_id_to_2D=points3D_id_to_2D, + points3D_id_to_ndepth=points3D_id_to_ndepth +) diff --git a/imcui/third_party/d2net/megadepth_utils/undistort_reconstructions.py b/imcui/third_party/d2net/megadepth_utils/undistort_reconstructions.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b99a72f81206e6fbefae9daa9aa683c8754051 --- /dev/null +++ b/imcui/third_party/d2net/megadepth_utils/undistort_reconstructions.py @@ -0,0 +1,69 @@ +import argparse + +import imagesize + +import os + +import subprocess + +parser = argparse.ArgumentParser(description='MegaDepth Undistortion') + +parser.add_argument( + '--colmap_path', type=str, required=True, + help='path to colmap executable' +) +parser.add_argument( + '--base_path', type=str, required=True, + help='path to MegaDepth' +) + +args = parser.parse_args() + +sfm_path = os.path.join( + args.base_path, 'MegaDepth_v1_SfM' +) +base_depth_path = os.path.join( + args.base_path, 'phoenix/S6/zl548/MegaDepth_v1' +) +output_path = os.path.join( + args.base_path, 'Undistorted_SfM' +) + +os.mkdir(output_path) + +for scene_name in os.listdir(base_depth_path): + current_output_path = os.path.join(output_path, scene_name) + os.mkdir(current_output_path) + + image_path = os.path.join( + base_depth_path, scene_name, 'dense0', 'imgs' + ) + if not os.path.exists(image_path): + continue + + # Find the maximum image size in scene. + max_image_size = 0 + for image_name in os.listdir(image_path): + max_image_size = max( + max_image_size, + max(imagesize.get(os.path.join(image_path, image_name))) + ) + + # Undistort the images and update the reconstruction. + subprocess.call([ + os.path.join(args.colmap_path, 'colmap'), 'image_undistorter', + '--image_path', os.path.join(sfm_path, scene_name, 'images'), + '--input_path', os.path.join(sfm_path, scene_name, 'sparse', 'manhattan', '0'), + '--output_path', current_output_path, + '--max_image_size', str(max_image_size) + ]) + + # Transform the reconstruction to raw text format. 
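+    # The exported cameras.txt, images.txt and points3D.txt in 'sparse-txt' are
+    # the files later parsed by megadepth_utils/preprocess_scene.py.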
+ sparse_txt_path = os.path.join(current_output_path, 'sparse-txt') + os.mkdir(sparse_txt_path) + subprocess.call([ + os.path.join(args.colmap_path, 'colmap'), 'model_converter', + '--input_path', os.path.join(current_output_path, 'sparse'), + '--output_path', sparse_txt_path, + '--output_type', 'TXT' + ]) \ No newline at end of file diff --git a/imcui/third_party/d2net/train.py b/imcui/third_party/d2net/train.py new file mode 100644 index 0000000000000000000000000000000000000000..5817f1712bda0779175fb18437d1f8c263f29f3b --- /dev/null +++ b/imcui/third_party/d2net/train.py @@ -0,0 +1,279 @@ +import argparse + +import numpy as np + +import os + +import shutil + +import torch +import torch.optim as optim + +from torch.utils.data import DataLoader + +from tqdm import tqdm + +import warnings + +from lib.dataset import MegaDepthDataset +from lib.exceptions import NoGradientError +from lib.loss import loss_function +from lib.model import D2Net + + +# CUDA +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if use_cuda else "cpu") + +# Seed +torch.manual_seed(1) +if use_cuda: + torch.cuda.manual_seed(1) +np.random.seed(1) + +# Argument parsing +parser = argparse.ArgumentParser(description='Training script') + +parser.add_argument( + '--dataset_path', type=str, required=True, + help='path to the dataset' +) +parser.add_argument( + '--scene_info_path', type=str, required=True, + help='path to the processed scenes' +) + +parser.add_argument( + '--preprocessing', type=str, default='caffe', + help='image preprocessing (caffe or torch)' +) +parser.add_argument( + '--model_file', type=str, default='models/d2_ots.pth', + help='path to the full model' +) + +parser.add_argument( + '--num_epochs', type=int, default=10, + help='number of training epochs' +) +parser.add_argument( + '--lr', type=float, default=1e-3, + help='initial learning rate' +) +parser.add_argument( + '--batch_size', type=int, default=1, + help='batch size' +) +parser.add_argument( + '--num_workers', type=int, default=4, + help='number of workers for data loading' +) + +parser.add_argument( + '--use_validation', dest='use_validation', action='store_true', + help='use the validation split' +) +parser.set_defaults(use_validation=False) + +parser.add_argument( + '--log_interval', type=int, default=250, + help='loss logging interval' +) + +parser.add_argument( + '--log_file', type=str, default='log.txt', + help='loss logging file' +) + +parser.add_argument( + '--plot', dest='plot', action='store_true', + help='plot training pairs' +) +parser.set_defaults(plot=False) + +parser.add_argument( + '--checkpoint_directory', type=str, default='checkpoints', + help='directory for training checkpoints' +) +parser.add_argument( + '--checkpoint_prefix', type=str, default='d2', + help='prefix for training checkpoints' +) + +args = parser.parse_args() + +print(args) + +# Create the folders for plotting if need be +if args.plot: + plot_path = 'train_vis' + if os.path.isdir(plot_path): + print('[Warning] Plotting directory already exists.') + else: + os.mkdir(plot_path) + +# Creating CNN model +model = D2Net( + model_file=args.model_file, + use_cuda=use_cuda +) + +# Optimizer +optimizer = optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr +) + +# Dataset +if args.use_validation: + validation_dataset = MegaDepthDataset( + scene_list_path='megadepth_utils/valid_scenes.txt', + scene_info_path=args.scene_info_path, + base_path=args.dataset_path, + train=False, + preprocessing=args.preprocessing, + pairs_per_scene=25 
+ ) + validation_dataloader = DataLoader( + validation_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers + ) + +training_dataset = MegaDepthDataset( + scene_list_path='megadepth_utils/train_scenes.txt', + scene_info_path=args.scene_info_path, + base_path=args.dataset_path, + preprocessing=args.preprocessing +) +training_dataloader = DataLoader( + training_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers +) + + +# Define epoch function +def process_epoch( + epoch_idx, + model, loss_function, optimizer, dataloader, device, + log_file, args, train=True +): + epoch_losses = [] + + torch.set_grad_enabled(train) + + progress_bar = tqdm(enumerate(dataloader), total=len(dataloader)) + for batch_idx, batch in progress_bar: + if train: + optimizer.zero_grad() + + batch['train'] = train + batch['epoch_idx'] = epoch_idx + batch['batch_idx'] = batch_idx + batch['batch_size'] = args.batch_size + batch['preprocessing'] = args.preprocessing + batch['log_interval'] = args.log_interval + + try: + loss = loss_function(model, batch, device, plot=args.plot) + except NoGradientError: + continue + + current_loss = loss.data.cpu().numpy()[0] + epoch_losses.append(current_loss) + + progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses))) + + if batch_idx % args.log_interval == 0: + log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % ( + 'train' if train else 'valid', + epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses) + )) + + if train: + loss.backward() + optimizer.step() + + log_file.write('[%s] epoch %d - avg_loss: %f\n' % ( + 'train' if train else 'valid', + epoch_idx, + np.mean(epoch_losses) + )) + log_file.flush() + + return np.mean(epoch_losses) + + +# Create the checkpoint directory +if os.path.isdir(args.checkpoint_directory): + print('[Warning] Checkpoint directory already exists.') +else: + os.mkdir(args.checkpoint_directory) + + +# Open the log file for writing +if os.path.exists(args.log_file): + print('[Warning] Log file already exists.') +log_file = open(args.log_file, 'a+') + +# Initialize the history +train_loss_history = [] +validation_loss_history = [] +if args.use_validation: + validation_dataset.build_dataset() + min_validation_loss = process_epoch( + 0, + model, loss_function, optimizer, validation_dataloader, device, + log_file, args, + train=False + ) + +# Start the training +for epoch_idx in range(1, args.num_epochs + 1): + # Process epoch + training_dataset.build_dataset() + train_loss_history.append( + process_epoch( + epoch_idx, + model, loss_function, optimizer, training_dataloader, device, + log_file, args + ) + ) + + if args.use_validation: + validation_loss_history.append( + process_epoch( + epoch_idx, + model, loss_function, optimizer, validation_dataloader, device, + log_file, args, + train=False + ) + ) + + # Save the current checkpoint + checkpoint_path = os.path.join( + args.checkpoint_directory, + '%s.%02d.pth' % (args.checkpoint_prefix, epoch_idx) + ) + checkpoint = { + 'args': args, + 'epoch_idx': epoch_idx, + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'train_loss_history': train_loss_history, + 'validation_loss_history': validation_loss_history + } + torch.save(checkpoint, checkpoint_path) + if ( + args.use_validation and + validation_loss_history[-1] < min_validation_loss + ): + min_validation_loss = validation_loss_history[-1] + best_checkpoint_path = os.path.join( + args.checkpoint_directory, + '%s.best.pth' % args.checkpoint_prefix + ) + shutil.copy(checkpoint_path, 
best_checkpoint_path) + +# Close the log file +log_file.close() diff --git a/imcui/third_party/dust3r/croco/datasets/__init__.py b/imcui/third_party/dust3r/croco/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py b/imcui/third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py new file mode 100644 index 0000000000000000000000000000000000000000..eb66a0474ce44b54c44c08887cbafdb045b11ff3 --- /dev/null +++ b/imcui/third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py @@ -0,0 +1,159 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Extracting crops for pre-training +# -------------------------------------------------------- + +import os +import argparse +from tqdm import tqdm +from PIL import Image +import functools +from multiprocessing import Pool +import math + + +def arg_parser(): + parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list') + + parser.add_argument('--crops', type=str, required=True, help='crop file') + parser.add_argument('--root-dir', type=str, required=True, help='root directory') + parser.add_argument('--output-dir', type=str, required=True, help='output directory') + parser.add_argument('--imsize', type=int, default=256, help='size of the crops') + parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads') + parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories') + parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir') + return parser + + +def main(args): + listing_path = os.path.join(args.output_dir, 'listing.txt') + + print(f'Loading list of crops ... 
({args.nthread} threads)') + crops, num_crops_to_generate = load_crop_file(args.crops) + + print(f'Preparing jobs ({len(crops)} candidate image pairs)...') + num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels) + num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels)) + + jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir) + del crops + + os.makedirs(args.output_dir, exist_ok=True) + mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map + call = functools.partial(save_image_crops, args) + + print(f"Generating cropped images to {args.output_dir} ...") + with open(listing_path, 'w') as listing: + listing.write('# pair_path\n') + for results in tqdm(mmap(call, jobs), total=len(jobs)): + for path in results: + listing.write(f'{path}\n') + print('Finished writing listing to', listing_path) + + +def load_crop_file(path): + data = open(path).read().splitlines() + pairs = [] + num_crops_to_generate = 0 + for line in tqdm(data): + if line.startswith('#'): + continue + line = line.split(', ') + if len(line) < 8: + img1, img2, rotation = line + pairs.append((img1, img2, int(rotation), [])) + else: + l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line) + rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2) + pairs[-1][-1].append((rect1, rect2)) + num_crops_to_generate += 1 + return pairs, num_crops_to_generate + + +def prepare_jobs(pairs, num_levels, num_pairs_in_dir): + jobs = [] + powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))] + + def get_path(idx): + idx_array = [] + d = idx + for level in range(num_levels - 1): + idx_array.append(idx // powers[level]) + idx = idx % powers[level] + idx_array.append(d) + return '/'.join(map(lambda x: hex(x)[2:], idx_array)) + + idx = 0 + for pair_data in tqdm(pairs): + img1, img2, rotation, crops = pair_data + if -60 <= rotation and rotation <= 60: + rotation = 0 # most likely not a true rotation + paths = [get_path(idx + k) for k in range(len(crops))] + idx += len(crops) + jobs.append(((img1, img2), rotation, crops, paths)) + return jobs + + +def load_image(path): + try: + return Image.open(path).convert('RGB') + except Exception as e: + print('skipping', path, e) + raise OSError() + + +def save_image_crops(args, data): + # load images + img_pair, rot, crops, paths = data + try: + img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair] + except OSError as e: + return [] + + def area(sz): + return sz[0] * sz[1] + + tgt_size = (args.imsize, args.imsize) + + def prepare_crop(img, rect, rot=0): + # actual crop + img = img.crop(rect) + + # resize to desired size + interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC + img = img.resize(tgt_size, resample=interp) + + # rotate the image + rot90 = (round(rot/90) % 4) * 90 + if rot90 == 90: + img = img.transpose(Image.Transpose.ROTATE_90) + elif rot90 == 180: + img = img.transpose(Image.Transpose.ROTATE_180) + elif rot90 == 270: + img = img.transpose(Image.Transpose.ROTATE_270) + return img + + results = [] + for (rect1, rect2), path in zip(crops, paths): + crop1 = prepare_crop(img1, rect1) + crop2 = prepare_crop(img2, rect2, rot) + + fullpath1 = os.path.join(args.output_dir, path+'_1.jpg') + fullpath2 = os.path.join(args.output_dir, path+'_2.jpg') + os.makedirs(os.path.dirname(fullpath1), exist_ok=True) + + assert not os.path.isfile(fullpath1), fullpath1 + assert not os.path.isfile(fullpath2), fullpath2 + crop1.save(fullpath1) + 
crop2.save(fullpath2) + results.append(path) + + return results + + +if __name__ == '__main__': + args = arg_parser().parse_args() + main(args) + diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/__init__.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe0d399084359495250dc8184671ff498adfbf2 --- /dev/null +++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py @@ -0,0 +1,92 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Script to generate image pairs for a given scene reproducing poses provided in a metadata file. +""" +import os +from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator +from datasets.habitat_sim.paths import SCENES_DATASET +import argparse +import quaternion +import PIL.Image +import cv2 +import json +from tqdm import tqdm + +def generate_multiview_images_from_metadata(metadata_filename, + output_dir, + overload_params = dict(), + scene_datasets_paths=None, + exist_ok=False): + """ + Generate images from a metadata file for reproducibility purposes. + """ + # Reorder paths by decreasing label length, to avoid collisions when testing if a string by such label + if scene_datasets_paths is not None: + scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True)) + + with open(metadata_filename, 'r') as f: + input_metadata = json.load(f) + metadata = dict() + for key, value in input_metadata.items(): + # Optionally replace some paths + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + if scene_datasets_paths is not None: + for dataset_label, dataset_path in scene_datasets_paths.items(): + if value.startswith(dataset_label): + value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label))) + break + metadata[key] = value + + # Overload some parameters + for key, value in overload_params.items(): + metadata[key] = value + + generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))]) + generate_depth = metadata["generate_depth"] + + os.makedirs(output_dir, exist_ok=exist_ok) + + generator = MultiviewHabitatSimGenerator(**generation_entries) + + # Generate views + for idx_label, data in tqdm(metadata['multiviews'].items()): + positions = data["positions"] + orientations = data["orientations"] + n = len(positions) + for oidx in range(n): + observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx])) + observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1 + # Color image saved using PIL + img = PIL.Image.fromarray(observation['color'][:,:,:3]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg") + img.save(filename) + if generate_depth: + # Depth image as EXR file + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr") + cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + # Camera parameters + camera_params = 
dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json") + with open(filename, "w") as f: + json.dump(camera_params, f) + # Save metadata + with open(os.path.join(output_dir, "metadata.json"), "w") as f: + json.dump(metadata, f) + + generator.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_filename", required=True) + parser.add_argument("--output_dir", required=True) + args = parser.parse_args() + + generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename, + output_dir=args.output_dir, + scene_datasets_paths=SCENES_DATASET, + overload_params=dict(), + exist_ok=True) + + \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py new file mode 100644 index 0000000000000000000000000000000000000000..962ef849d8c31397b8622df4f2d9140175d78873 --- /dev/null +++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py @@ -0,0 +1,27 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Script generating commandlines to generate image pairs from metadata files. +""" +import os +import glob +from tqdm import tqdm +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", required=True) + parser.add_argument("--output_dir", required=True) + parser.add_argument("--prefix", default="", help="Commanline prefix, useful e.g. to setup environment.") + args = parser.parse_args() + + input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True) + + for metadata_filename in tqdm(input_metadata_filenames): + output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir)) + # Do not process the scene if the metadata file already exists + if os.path.exists(os.path.join(output_dir, "metadata.json")): + continue + commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}" + print(commandline) diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py new file mode 100644 index 0000000000000000000000000000000000000000..421d49a1696474415940493296b3f2d982398850 --- /dev/null +++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py @@ -0,0 +1,177 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import os +from tqdm import tqdm +import argparse +import PIL.Image +import numpy as np +import json +from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError +from datasets.habitat_sim.paths import list_scenes_available +import cv2 +import quaternion +import shutil + +def generate_multiview_images_for_scene(scene_dataset_config_file, + scene, + navmesh, + output_dir, + views_count, + size, + exist_ok=False, + generate_depth=False, + **kwargs): + """ + Generate tuples of overlapping views for a given scene. 
+ generate_depth: generate depth images and camera parameters. + """ + if os.path.exists(output_dir) and not exist_ok: + print(f"Scene {scene}: data already generated. Ignoring generation.") + return + try: + print(f"Scene {scene}: {size} multiview acquisitions to generate...") + os.makedirs(output_dir, exist_ok=exist_ok) + + metadata_filename = os.path.join(output_dir, "metadata.json") + + metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count=views_count, + size=size, + generate_depth=generate_depth, + **kwargs) + metadata_template["multiviews"] = dict() + + if os.path.exists(metadata_filename): + print("Metadata file already exists:", metadata_filename) + print("Loading already generated metadata file...") + with open(metadata_filename, "r") as f: + metadata = json.load(f) + + for key in metadata_template.keys(): + if key != "multiviews": + assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}." + else: + print("No temporary file found. Starting generation from scratch...") + metadata = metadata_template + + starting_id = len(metadata["multiviews"]) + print(f"Starting generation from index {starting_id}/{size}...") + if starting_id >= size: + print("Generation already done.") + return + + generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count = views_count, + size = size, + **kwargs) + + for idx in tqdm(range(starting_id, size)): + # Generate / re-generate the observations + try: + data = generator[idx] + observations = data["observations"] + positions = data["positions"] + orientations = data["orientations"] + + idx_label = f"{idx:08}" + for oidx, observation in enumerate(observations): + observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1 + # Color image saved using PIL + img = PIL.Image.fromarray(observation['color'][:,:,:3]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg") + img.save(filename) + if generate_depth: + # Depth image as EXR file + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr") + cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + # Camera parameters + camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json") + with open(filename, "w") as f: + json.dump(camera_params, f) + metadata["multiviews"][idx_label] = {"positions": positions.tolist(), + "orientations": orientations.tolist(), + "covisibility_ratios": data["covisibility_ratios"].tolist(), + "valid_fractions": data["valid_fractions"].tolist(), + "pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()} + except RecursionError: + print("Recursion error: unable to sample observations for this scene. 
We will stop there.") + break + + # Regularly save a temporary metadata file, in case we need to restart the generation + if idx % 10 == 0: + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + # Save metadata + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + generator.close() + except NoNaviguableSpaceError: + pass + +def create_commandline(scene_data, generate_depth, exist_ok=False): + """ + Create a commandline string to generate a scene. + """ + def my_formatting(val): + if val is None or val == "": + return '""' + else: + return val + commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)} + --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)} + --navmesh {my_formatting(scene_data.navmesh)} + --output_dir {my_formatting(scene_data.output_dir)} + --generate_depth {int(generate_depth)} + --exist_ok {int(exist_ok)} + """ + commandline = " ".join(commandline.split()) + return commandline + +if __name__ == "__main__": + os.umask(2) + + parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for scenes available: + > python datasets/habitat_sim/generate_multiview_habitat_images.py --list_commands + """) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--list_commands", action='store_true', help="list commandlines to run if true") + parser.add_argument("--scene", type=str, default="") + parser.add_argument("--scene_dataset_config_file", type=str, default="") + parser.add_argument("--navmesh", type=str, default="") + + parser.add_argument("--generate_depth", type=int, default=1) + parser.add_argument("--exist_ok", type=int, default=0) + + kwargs = dict(resolution=(256,256), hfov=60, views_count = 2, size=1000) + + args = parser.parse_args() + generate_depth=bool(args.generate_depth) + exist_ok = bool(args.exist_ok) + + if args.list_commands: + # Listing scenes available... + scenes_data = list_scenes_available(base_output_dir=args.output_dir) + + for scene_data in scenes_data: + print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok)) + else: + if args.scene == "" or args.output_dir == "": + print("Missing scene or output dir argument!") + print(parser.format_help()) + else: + generate_multiview_images_for_scene(scene=args.scene, + scene_dataset_config_file = args.scene_dataset_config_file, + navmesh = args.navmesh, + output_dir = args.output_dir, + exist_ok=exist_ok, + generate_depth=generate_depth, + **kwargs) \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..91e5f923b836a645caf5d8e4aacc425047e3c144 --- /dev/null +++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py @@ -0,0 +1,390 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
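For readers wiring up the Habitat-Sim generation scripts above, the intended flow is: run `generate_multiview_images.py --list_commands` to print one command line per available scene, then execute those commands. The driver below is only an illustrative sketch of that flow; the output directory and the choice to run just two commands are placeholders, not part of the original scripts, and it assumes it is launched from the `croco/` root with habitat-sim installed.

```python
# Hypothetical driver for the --list_commands workflow above.
import subprocess

listing = subprocess.run(
    ["python", "datasets/habitat_sim/generate_multiview_images.py",
     "--output_dir", "./data/habitat_release/", "--list_commands"],
    capture_output=True, text=True, check=True,
)
commands = [line for line in listing.stdout.splitlines() if line.strip()]
print(f"{len(commands)} scene-generation commands listed")

# Smoke-test the pipeline on the first two scenes only.
for cmd in commands[:2]:
    subprocess.run(cmd, shell=True, check=True)
```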
+ +import os +import numpy as np +import quaternion +import habitat_sim +import json +from sklearn.neighbors import NearestNeighbors +import cv2 + +# OpenCV to habitat camera convention transformation +R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0) +R_HABITAT2OPENCV = R_OPENCV2HABITAT.T +DEG2RAD = np.pi / 180 + +def compute_camera_intrinsics(height, width, hfov): + f = width/2 / np.tan(hfov/2 * np.pi/180) + cu, cv = width/2, height/2 + return f, cu, cv + +def compute_camera_pose_opencv_convention(camera_position, camera_orientation): + R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT + t_cam2world = np.asarray(camera_position) + return R_cam2world, t_cam2world + +def compute_pointmap(depthmap, hfov): + """ Compute a HxWx3 pointmap in camera frame from a HxW depth map.""" + height, width = depthmap.shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + # Cast depth map to point + z_cam = depthmap + u, v = np.meshgrid(range(width), range(height)) + x_cam = (u - cu) / f * z_cam + y_cam = (v - cv) / f * z_cam + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1) + return X_cam + +def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation): + """Return a 3D point cloud corresponding to valid pixels of the depth map""" + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation) + + X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov) + valid_mask = (X_cam[:,:,2] != 0.0) + + X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()] + X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3) + return X_world + +def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False): + """ + Compute 'overlapping' metrics based on a distance threshold between two point clouds. + """ + nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2) + distances, indices = nbrs.kneighbors(pointcloud1) + intersection1 = np.count_nonzero(distances.flatten() < distance_threshold) + + data = {"intersection1": intersection1, + "size1": len(pointcloud1)} + if compute_symmetric: + nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1) + distances, indices = nbrs.kneighbors(pointcloud2) + intersection2 = np.count_nonzero(distances.flatten() < distance_threshold) + data["intersection2"] = intersection2 + data["size2"] = len(pointcloud2) + + return data + +def _append_camera_parameters(observation, hfov, camera_location, camera_rotation): + """ + Add camera parameters to the observation dictionnary produced by Habitat-Sim + In-place modifications. + """ + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation) + height, width = observation['depth'].shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + K = np.asarray([[f, 0, cu], + [0, f, cv], + [0, 0, 1.0]]) + observation["camera_intrinsics"] = K + observation["t_cam2world"] = t_cam2world + observation["R_cam2world"] = R_cam2world + +def look_at(eye, center, up, return_cam2world=True): + """ + Return camera pose looking at a given center point. + Analogous of gluLookAt function, using OpenCV camera convention. 
+ """ + z = center - eye + z /= np.linalg.norm(z, axis=-1, keepdims=True) + y = -up + y = y - np.sum(y * z, axis=-1, keepdims=True) * z + y /= np.linalg.norm(y, axis=-1, keepdims=True) + x = np.cross(y, z, axis=-1) + + if return_cam2world: + R = np.stack((x, y, z), axis=-1) + t = eye + else: + # World to camera transformation + # Transposed matrix + R = np.stack((x, y, z), axis=-2) + t = - np.einsum('...ij, ...j', R, eye) + return R, t + +def look_at_for_habitat(eye, center, up, return_cam2world=True): + R, t = look_at(eye, center, up) + orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T) + return orientation, t + +def generate_orientation_noise(pan_range, tilt_range, roll_range): + return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP) + * quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT) + * quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT)) + + +class NoNaviguableSpaceError(RuntimeError): + def __init__(self, *args): + super().__init__(*args) + +class MultiviewHabitatSimGenerator: + def __init__(self, + scene, + navmesh, + scene_dataset_config_file, + resolution = (240, 320), + views_count=2, + hfov = 60, + gpu_id = 0, + size = 10000, + minimum_covisibility = 0.5, + transform = None): + self.scene = scene + self.navmesh = navmesh + self.scene_dataset_config_file = scene_dataset_config_file + self.resolution = resolution + self.views_count = views_count + assert(self.views_count >= 1) + self.hfov = hfov + self.gpu_id = gpu_id + self.size = size + self.transform = transform + + # Noise added to camera orientation + self.pan_range = (-3, 3) + self.tilt_range = (-10, 10) + self.roll_range = (-5, 5) + + # Height range to sample cameras + self.height_range = (1.2, 1.8) + + # Random steps between the camera views + self.random_steps_count = 5 + self.random_step_variance = 2.0 + + # Minimum fraction of the scene which should be valid (well defined depth) + self.minimum_valid_fraction = 0.7 + + # Distance threshold to see to select pairs + self.distance_threshold = 0.05 + # Minimum IoU of a view point cloud with respect to the reference view to be kept. + self.minimum_covisibility = minimum_covisibility + + # Maximum number of retries. 
+ self.max_attempts_count = 100 + + self.seed = None + self._lazy_initialization() + + def _lazy_initialization(self): + # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly + if self.seed == None: + # Re-seed numpy generator + np.random.seed() + self.seed = np.random.randint(2**32-1) + sim_cfg = habitat_sim.SimulatorConfiguration() + sim_cfg.scene_id = self.scene + if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "": + sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file + sim_cfg.random_seed = self.seed + sim_cfg.load_semantic_mesh = False + sim_cfg.gpu_device_id = self.gpu_id + + depth_sensor_spec = habitat_sim.CameraSensorSpec() + depth_sensor_spec.uuid = "depth" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.resolution + depth_sensor_spec.hfov = self.hfov + depth_sensor_spec.position = [0.0, 0.0, 0] + depth_sensor_spec.orientation + + rgb_sensor_spec = habitat_sim.CameraSensorSpec() + rgb_sensor_spec.uuid = "color" + rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR + rgb_sensor_spec.resolution = self.resolution + rgb_sensor_spec.hfov = self.hfov + rgb_sensor_spec.position = [0.0, 0.0, 0] + agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec]) + + cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg]) + self.sim = habitat_sim.Simulator(cfg) + if self.navmesh is not None and self.navmesh != "": + # Use pre-computed navmesh when available (usually better than those generated automatically) + self.sim.pathfinder.load_nav_mesh(self.navmesh) + + if not self.sim.pathfinder.is_loaded: + # Try to compute a navmesh + navmesh_settings = habitat_sim.NavMeshSettings() + navmesh_settings.set_defaults() + self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True) + + # Ensure that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})") + + self.agent = self.sim.initialize_agent(agent_id=0) + + def close(self): + self.sim.close() + + def __del__(self): + self.sim.close() + + def __len__(self): + return self.size + + def sample_random_viewpoint(self): + """ Sample a random viewpoint using the navmesh """ + nav_point = self.sim.pathfinder.get_random_navigable_point() + + # Sample a random viewpoint height + viewpoint_height = np.random.uniform(*self.height_range) + viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP + viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range) + return viewpoint_position, viewpoint_orientation, nav_point + + def sample_other_random_viewpoint(self, observed_point, nav_point): + """ Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point.""" + other_nav_point = nav_point + + walk_directions = self.random_step_variance * np.asarray([1,0,1]) + for i in range(self.random_steps_count): + temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3)) + # Snapping may return nan when it fails + if not np.isnan(temp[0]): + other_nav_point = temp + + other_viewpoint_height = np.random.uniform(*self.height_range) + other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP + + # Set viewing 
direction towards the central point + rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True) + rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range) + return position, rotation, other_nav_point + + def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud): + """ Check if a viewpoint is valid and overlaps significantly with a reference one. """ + # Observation + pixels_count = self.resolution[0] * self.resolution[1] + valid_fraction = len(other_pointcloud) / pixels_count + assert valid_fraction <= 1.0 and valid_fraction >= 0.0 + overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True) + covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count) + is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility) + return is_valid, valid_fraction, covisibility + + def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation): + """ Check if a viewpoint is valid and overlaps significantly with a reference one. """ + # Observation + other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation) + return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud) + + def render_viewpoint(self, viewpoint_position, viewpoint_orientation): + agent_state = habitat_sim.AgentState() + agent_state.position = viewpoint_position + agent_state.rotation = viewpoint_orientation + self.agent.set_state(agent_state) + viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0) + _append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation) + return viewpoint_observations + + def __getitem__(self, useless_idx): + ref_position, ref_orientation, nav_point = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + # Extract point cloud + ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov, + camera_position=ref_position, camera_rotation=ref_orientation) + + pixels_count = self.resolution[0] * self.resolution[1] + ref_valid_fraction = len(ref_pointcloud) / pixels_count + assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0 + if ref_valid_fraction < self.minimum_valid_fraction: + # This should produce a recursion error at some point when something is very wrong. 
+ return self[0] + # Pick an reference observed point in the point cloud + observed_point = np.mean(ref_pointcloud, axis=0) + + # Add the first image as reference + viewpoints_observations = [ref_observations] + viewpoints_covisibility = [ref_valid_fraction] + viewpoints_positions = [ref_position] + viewpoints_orientations = [quaternion.as_float_array(ref_orientation)] + viewpoints_clouds = [ref_pointcloud] + viewpoints_valid_fractions = [ref_valid_fraction] + + for _ in range(self.views_count - 1): + # Generate an other viewpoint using some dummy random walk + successful_sampling = False + for sampling_attempt in range(self.max_attempts_count): + position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point) + # Observation + other_viewpoint_observations = self.render_viewpoint(position, rotation) + other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation) + + is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud) + if is_valid: + successful_sampling = True + break + if not successful_sampling: + print("WARNING: Maximum number of attempts reached.") + # Dirty hack, try using a novel original viewpoint + return self[0] + viewpoints_observations.append(other_viewpoint_observations) + viewpoints_covisibility.append(covisibility) + viewpoints_positions.append(position) + viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding. + viewpoints_clouds.append(other_pointcloud) + viewpoints_valid_fractions.append(valid_fraction) + + # Estimate relations between all pairs of images + pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations))) + for i in range(len(viewpoints_observations)): + pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i] + for j in range(i+1, len(viewpoints_observations)): + overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True) + pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count + pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count + + # IoU is relative to the image 0 + data = {"observations": viewpoints_observations, + "positions": np.asarray(viewpoints_positions), + "orientations": np.asarray(viewpoints_orientations), + "covisibility_ratios": np.asarray(viewpoints_covisibility), + "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float), + "pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float), + } + + if self.transform is not None: + data = self.transform(data) + return data + + def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False): + """ + Return a list of images corresponding to a spiral trajectory from a random starting point. + Useful to generate nice visualisations. 
+ Use an even number of half turns to get a nice "C1-continuous" loop effect + """ + ref_position, ref_orientation, navpoint = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov, + camera_position=ref_position, camera_rotation=ref_orientation) + pixels_count = self.resolution[0] * self.resolution[1] + if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction: + # Dirty hack: ensure that the valid part of the image is significant + return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation) + + # Pick an observed point in the point cloud + observed_point = np.mean(ref_pointcloud, axis=0) + ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation) + + images = [] + is_valid = [] + # Spiral trajectory, use_constant orientation + for i, alpha in enumerate(np.linspace(0, 1, images_count)): + r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius + theta = alpha * half_turns * np.pi + x = r * np.cos(theta) + y = r * np.sin(theta) + z = 0.0 + position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten() + if use_constant_orientation: + orientation = ref_orientation + else: + # trajectory looking at a mean point in front of the ref observation + orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP) + observations = self.render_viewpoint(position, orientation) + images.append(observations['color'][...,:3]) + _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation) + is_valid.append(_is_valid) + return images, np.all(is_valid) \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py new file mode 100644 index 0000000000000000000000000000000000000000..10672a01f7dd615d3b4df37781f7f6f97e753ba6 --- /dev/null +++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py @@ -0,0 +1,69 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +""" +Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere. +""" +import os +import glob +from tqdm import tqdm +import shutil +import json +from datasets.habitat_sim.paths import * +import argparse +import collections + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input_dir") + parser.add_argument("output_dir") + args = parser.parse_args() + + input_dirname = args.input_dir + output_dirname = args.output_dir + + input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True) + + images_count = collections.defaultdict(lambda : 0) + + os.makedirs(output_dirname) + for input_filename in tqdm(input_metadata_filenames): + # Ignore empty files + with open(input_filename, "r") as f: + original_metadata = json.load(f) + if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0: + print("No views in", input_filename) + continue + + relpath = os.path.relpath(input_filename, input_dirname) + print(relpath) + + # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability. 
+ # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern. + scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True)) + metadata = dict() + for key, value in original_metadata.items(): + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + known_path = False + for dataset, dataset_path in scenes_dataset_paths.items(): + if value.startswith(dataset_path): + value = os.path.join(dataset, os.path.relpath(value, dataset_path)) + known_path = True + break + if not known_path: + raise KeyError("Unknown path:" + value) + metadata[key] = value + + # Compile some general statistics while packing data + scene_split = metadata["scene"].split("/") + upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0] + images_count[upper_level] += len(metadata["multiviews"]) + + output_filename = os.path.join(output_dirname, relpath) + os.makedirs(os.path.dirname(output_filename), exist_ok=True) + with open(output_filename, "w") as f: + json.dump(metadata, f) + + # Print statistics + print("Images count:") + for upper_level, count in images_count.items(): + print(f"- {upper_level}: {count}") \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/paths.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/paths.py new file mode 100644 index 0000000000000000000000000000000000000000..4d63b5fa29c274ddfeae084734a35ba66d7edee8 --- /dev/null +++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/paths.py @@ -0,0 +1,129 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Paths to Habitat-Sim scenes +""" + +import os +import json +import collections +from tqdm import tqdm + + +# Hardcoded path to the different scene datasets +SCENES_DATASET = { + "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/", + "gibson": "./data/habitat-sim-data/scene_datasets/gibson/", + "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/", + "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/", + "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/", + "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/", + "scannet": "./data/habitat-sim/scene_datasets/scannet/" +} + +SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"]) + +def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]): + scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json") + scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"] + navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"] + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx]) + # Add scene + data = SceneData(scene_dataset_config_file=scene_dataset_config_file, + scene = scenes[idx] + ".scene_instance.json", + navmesh = os.path.join(base_path, navmeshes[idx]), + output_dir = output_dir) + scenes_data.append(data) + return scenes_data + +def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]): + scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json") + scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i 
in range(5)] for j in range(21)], []) + navmeshes = ""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"] + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx]) + data = SceneData(scene_dataset_config_file=scene_dataset_config_file, + scene = scenes[idx], + navmesh = "", + output_dir = output_dir) + scenes_data.append(data) + return scenes_data + +def list_replica_scenes(base_output_dir, base_path): + scenes_data = [] + for scene_id in os.listdir(base_path): + scene = os.path.join(base_path, scene_id, "mesh.ply") + navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it + scene_dataset_config_file = "" + output_dir = os.path.join(base_output_dir, scene_id) + # Add scene only if it does not exist already, or if exist_ok + data = SceneData(scene_dataset_config_file = scene_dataset_config_file, + scene = scene, + navmesh = navmesh, + output_dir = output_dir) + scenes_data.append(data) + return scenes_data + + +def list_scenes(base_output_dir, base_path): + """ + Generic method iterating through a base_path folder to find scenes. + """ + scenes_data = [] + for root, dirs, files in os.walk(base_path, followlinks=True): + folder_scenes_data = [] + for file in files: + name, ext = os.path.splitext(file) + if ext == ".glb": + scene = os.path.join(root, name + ".glb") + navmesh = os.path.join(root, name + ".navmesh") + if not os.path.exists(navmesh): + navmesh = "" + relpath = os.path.relpath(root, base_path) + output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name)) + data = SceneData(scene_dataset_config_file="", + scene = scene, + navmesh = navmesh, + output_dir = output_dir) + folder_scenes_data.append(data) + + # Specific check for HM3D: + # When two meshesxxxx.basis.glb and xxxx.glb are present, use the 'basis' version. 
+ basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")] + if len(basis_scenes) != 0: + folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)] + + scenes_data.extend(folder_scenes_data) + return scenes_data + +def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET): + scenes_data = [] + + # HM3D + for split in ("minival", "train", "val", "examples"): + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"), + base_path=f"{scenes_dataset_paths['hm3d']}/{split}") + + # Gibson + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"), + base_path=scenes_dataset_paths["gibson"]) + + # Habitat test scenes (just a few) + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"), + base_path=scenes_dataset_paths["habitat-test-scenes"]) + + # ReplicaCAD (baked lightning) + scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir) + + # ScanNet + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"), + base_path=scenes_dataset_paths["scannet"]) + + # Replica + list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"), + base_path=scenes_dataset_paths["replica"]) + return scenes_data diff --git a/imcui/third_party/dust3r/croco/datasets/pairs_dataset.py b/imcui/third_party/dust3r/croco/datasets/pairs_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9f107526b34e154d9013a9a7a0bde3d5ff6f581c --- /dev/null +++ b/imcui/third_party/dust3r/croco/datasets/pairs_dataset.py @@ -0,0 +1,109 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
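The path helpers in `croco/datasets/habitat_sim/paths.py` above return `SceneData` tuples that the generation scripts consume. A small inspection sketch, assuming the hard-coded `SCENES_DATASET` locations exist on disk (the output directory below is a placeholder):

```python
# Illustrative inspection of the scene listing; all paths are assumptions.
from datasets.habitat_sim.paths import list_scenes_available

scenes = list_scenes_available(base_output_dir="./data/habitat_release/")
print(f"{len(scenes)} scenes available")
for data in scenes[:3]:
    print(data.scene, "| navmesh:", data.navmesh or "<none>", "| ->", data.output_dir)
```

Note that, as written above, `list_scenes_available` calls `list_replica_scenes` without appending its result to `scenes_data`, so Replica scenes do not appear in the returned list.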
+ +import os +from torch.utils.data import Dataset +from PIL import Image + +from datasets.transforms import get_pair_transforms + +def load_image(impath): + return Image.open(impath) + +def load_pairs_from_cache_file(fname, root=''): + assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, 'r') as fid: + lines = fid.read().strip().splitlines() + pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines] + return pairs + +def load_pairs_from_list_file(fname, root=''): + assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, 'r') as fid: + lines = fid.read().strip().splitlines() + pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')] + return pairs + + +def write_cache_file(fname, pairs, root=''): + if len(root)>0: + if not root.endswith('/'): root+='/' + assert os.path.isdir(root) + s = '' + for im1, im2 in pairs: + if len(root)>0: + assert im1.startswith(root), im1 + assert im2.startswith(root), im2 + s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):]) + with open(fname, 'w') as fid: + fid.write(s[:-1]) + +def parse_and_cache_all_pairs(dname, data_dir='./data/'): + if dname=='habitat_release': + dirname = os.path.join(data_dir, 'habitat_release') + assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname + cache_file = os.path.join(dirname, 'pairs.txt') + assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file + + print('Parsing pairs for dataset: '+dname) + pairs = [] + for root, dirs, files in os.walk(dirname): + if 'val' in root: continue + dirs.sort() + pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')] + print('Found {:,} pairs'.format(len(pairs))) + print('Writing cache to: '+cache_file) + write_cache_file(cache_file, pairs, root=dirname) + + else: + raise NotImplementedError('Unknown dataset: '+dname) + +def dnames_to_image_pairs(dnames, data_dir='./data/'): + """ + dnames: list of datasets with image pairs, separated by + + """ + all_pairs = [] + for dname in dnames.split('+'): + if dname=='habitat_release': + dirname = os.path.join(data_dir, 'habitat_release') + assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname + cache_file = os.path.join(dirname, 'pairs.txt') + assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file + pairs = load_pairs_from_cache_file(cache_file, root=dirname) + elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']: + dirname = os.path.join(data_dir, dname+'_crops') + assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname) + list_file = os.path.join(dirname, 'listing.txt') + assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. 
{:s}".format(dname, list_file) + pairs = load_pairs_from_list_file(list_file, root=dirname) + print(' {:s}: {:,} pairs'.format(dname, len(pairs))) + all_pairs += pairs + if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs))) + return all_pairs + + +class PairsDataset(Dataset): + + def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'): + super().__init__() + self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir) + self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize) + + def __len__(self): + return len(self.image_pairs) + + def __getitem__(self, index): + im1path, im2path = self.image_pairs[index] + im1 = load_image(im1path) + im2 = load_image(im2path) + if self.transforms is not None: im1, im2 = self.transforms(im1, im2) + return im1, im2 + + +if __name__=="__main__": + import argparse + parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset") + parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored") + parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset") + args = parser.parse_args() + parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir) diff --git a/imcui/third_party/dust3r/croco/datasets/transforms.py b/imcui/third_party/dust3r/croco/datasets/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..216bac61f8254fd50e7f269ee80301f250a2d11e --- /dev/null +++ b/imcui/third_party/dust3r/croco/datasets/transforms.py @@ -0,0 +1,95 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch +import torchvision.transforms +import torchvision.transforms.functional as F + +# "Pair": apply a transform on a pair +# "Both": apply the exact same transform to both images + +class ComposePair(torchvision.transforms.Compose): + def __call__(self, img1, img2): + for t in self.transforms: + img1, img2 = t(img1, img2) + return img1, img2 + +class NormalizeBoth(torchvision.transforms.Normalize): + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + +class ToTensorBoth(torchvision.transforms.ToTensor): + def __call__(self, img1, img2): + img1 = super().__call__(img1) + img2 = super().__call__(img2) + return img1, img2 + +class RandomCropPair(torchvision.transforms.RandomCrop): + # the crop will be intentionally different for the two images with this class + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + +class ColorJitterPair(torchvision.transforms.ColorJitter): + # can be symmetric (same for both images) or assymetric (different jitter params for each image) depending on assymetric_prob + def __init__(self, assymetric_prob, **kwargs): + super().__init__(**kwargs) + self.assymetric_prob = assymetric_prob + def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor): + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return img 
+ + def forward(self, img1, img2): + + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) + if torch.rand(1) < self.assymetric_prob: # assymetric: + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) + return img1, img2 + +def get_pair_transforms(transform_str, totensor=True, normalize=True): + # transform_str is eg crop224+color + trfs = [] + for s in transform_str.split('+'): + if s.startswith('crop'): + size = int(s[len('crop'):]) + trfs.append(RandomCropPair(size)) + elif s=='acolor': + trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0)) + elif s=='': # if transform_str was "" + pass + else: + raise NotImplementedError('Unknown augmentation: '+s) + + if totensor: + trfs.append( ToTensorBoth() ) + if normalize: + trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ) + + if len(trfs)==0: + return None + elif len(trfs)==1: + return trfs + else: + return ComposePair(trfs) + + + + + diff --git a/imcui/third_party/dust3r/croco/demo.py b/imcui/third_party/dust3r/croco/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..91b80ccc5c98c18e20d1ce782511aa824ef28f77 --- /dev/null +++ b/imcui/third_party/dust3r/croco/demo.py @@ -0,0 +1,55 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
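The pair transforms in `croco/datasets/transforms.py` above are assembled from a `+`-separated string such as `crop224+acolor`. A minimal usage sketch, assuming a pair of RGB images of at least 224x224 pixels (the file names are placeholders):

```python
# Build a crop + asymmetric colour-jitter pipeline and apply it to a pair.
from PIL import Image
from datasets.transforms import get_pair_transforms

trfs = get_pair_transforms("crop224+acolor", totensor=True, normalize=True)
im1 = Image.open("pair_1.jpg").convert("RGB")   # placeholder paths
im2 = Image.open("pair_2.jpg").convert("RGB")
t1, t2 = trfs(im1, im2)   # two normalized float tensors of shape (3, 224, 224)
```

One caveat visible in the code above: when the assembled list contains exactly one transform, `get_pair_transforms` returns the list itself rather than a callable, so callers are expected to request at least two transforms or handle that case themselves.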
+ +import torch +from models.croco import CroCoNet +from PIL import Image +import torchvision.transforms +from torchvision.transforms import ToTensor, Normalize, Compose + +def main(): + device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu') + + # load 224x224 images and transform them to tensor + imagenet_mean = [0.485, 0.456, 0.406] + imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True) + imagenet_std = [0.229, 0.224, 0.225] + imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True) + trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)]) + image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0) + image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0) + + # load model + ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu') + model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device) + model.eval() + msg = model.load_state_dict(ckpt['model'], strict=True) + + # forward + with torch.inference_mode(): + out, mask, target = model(image1, image2) + + # the output is normalized, thus use the mean/std of the actual image to go back to RGB space + patchified = model.patchify(image1) + mean = patchified.mean(dim=-1, keepdim=True) + var = patchified.var(dim=-1, keepdim=True) + decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean) + # undo imagenet normalization, prepare masked image + decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor + input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor + ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor + image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None]) + masked_input_image = ((1 - image_masks) * input_image) + + # make visualization + visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4 + B, C, H, W = visualization.shape + visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W) + visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1)) + fname = "demo_output.png" + visualization.save(fname) + print('Visualization save in '+fname) + + +if __name__=="__main__": + main() diff --git a/imcui/third_party/dust3r/croco/models/blocks.py b/imcui/third_party/dust3r/croco/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..18133524f0ae265b0bd8d062d7c9eeaa63858a9b --- /dev/null +++ b/imcui/third_party/dust3r/croco/models/blocks.py @@ -0,0 +1,241 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
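`demo.py` above reconstructs an RGB image from the decoder output in two steps: it first undoes the per-patch normalization of the prediction using the mean and variance of the corresponding patches of the input image, and only then undoes the ImageNet normalization. A minimal round-trip sketch of the per-patch step (function names are illustrative, not part of the original code):

```python
# Per-patch normalization and its inverse, mirroring the demo's
# `out * (var + 1.e-6)**.5 + mean` reconstruction.
import torch

def normalize_patches(patches, eps=1e-6):
    # patches: (B, num_patches, patch_dim)
    mean = patches.mean(dim=-1, keepdim=True)
    var = patches.var(dim=-1, keepdim=True)
    return (patches - mean) / (var + eps) ** 0.5, mean, var

def denormalize_patches(pred, mean, var, eps=1e-6):
    return pred * (var + eps) ** 0.5 + mean

x = torch.randn(2, 196, 768)                  # e.g. 224x224 image, 16x16 patches
norm, mean, var = normalize_patches(x)
assert torch.allclose(denormalize_patches(norm, mean, var), x, atol=1e-5)
```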
+ + +# -------------------------------------------------------- +# Main encoder/decoder blocks +# -------------------------------------------------------- +# References: +# timm +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py + + +import torch +import torch.nn as nn + +from itertools import repeat +import collections.abc + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + return parse +to_2tuple = _ntuple(2) + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + +class Attention(nn.Module): + + def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x, xpos): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3) + q, k, v = [qkv[:,:,i] for i in range(3)] + # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = 
attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, xpos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class CrossAttention(nn.Module): + + def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.projq = nn.Linear(dim, dim, bias=qkv_bias) + self.projk = nn.Linear(dim, dim, bias=qkv_bias) + self.projv = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + + def forward(self, query, key, value, qpos, kpos): + B, Nq, C = query.shape + Nk = key.shape[1] + Nv = value.shape[1] + + q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) + k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) + v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class DecoderBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() + + def forward(self, x, y, xpos, ypos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + y_ = self.norm_y(y) + x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) + x = x + self.drop_path(self.mlp(self.norm3(x))) + return x, y + + +# patch embedding +class PositionGetter(object): + """ return positions of patches """ + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if not (h,w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone() + return pos + +class PatchEmbed(nn.Module): + """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + self.position_getter = PositionGetter() + + def forward(self, x): + B, C, H, W = x.shape + torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + def _init_weights(self): + w = self.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + diff --git a/imcui/third_party/dust3r/croco/models/criterion.py b/imcui/third_party/dust3r/croco/models/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..11696c40865344490f23796ea45e8fbd5e654731 --- /dev/null +++ b/imcui/third_party/dust3r/croco/models/criterion.py @@ -0,0 +1,37 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
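
The criterion defined below computes a mean-squared error restricted to the masked patches, optionally normalizing each target patch by its own mean and variance. A minimal functional sketch of the same computation (names are illustrative, not part of the diff):

```python
import torch

def masked_mse(pred, mask, target, norm_pix_loss=True, eps=1e-6):
    # pred/target: B x N x D patch tensors; mask: B x N with 1 for masked patches.
    if norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + eps) ** 0.5
    loss = ((pred - target) ** 2).mean(dim=-1)   # B x N, mean error per patch
    return (loss * mask).sum() / mask.sum()      # average over masked patches only
```
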
+# +# -------------------------------------------------------- +# Criterion to train CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# -------------------------------------------------------- + +import torch + +class MaskedMSE(torch.nn.Module): + + def __init__(self, norm_pix_loss=False, masked=True): + """ + norm_pix_loss: normalize each patch by their pixel mean and variance + masked: compute loss over the masked patches only + """ + super().__init__() + self.norm_pix_loss = norm_pix_loss + self.masked = masked + + def forward(self, pred, mask, target): + + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + if self.masked: + loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches + else: + loss = loss.mean() # mean loss + return loss diff --git a/imcui/third_party/dust3r/croco/models/croco.py b/imcui/third_party/dust3r/croco/models/croco.py new file mode 100644 index 0000000000000000000000000000000000000000..14c68634152d75555b4c35c25af268394c5821fe --- /dev/null +++ b/imcui/third_party/dust3r/croco/models/croco.py @@ -0,0 +1,249 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# CroCo model during pretraining +# -------------------------------------------------------- + + + +import torch +import torch.nn as nn +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 +from functools import partial + +from models.blocks import Block, DecoderBlock, PatchEmbed +from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D +from models.masking import RandomMask + + +class CroCoNet(nn.Module): + + def __init__(self, + img_size=224, # input image size + patch_size=16, # patch_size + mask_ratio=0.9, # ratios of masked tokens + enc_embed_dim=768, # encoder feature dimension + enc_depth=12, # encoder depth + enc_num_heads=12, # encoder number of heads in the transformer block + dec_embed_dim=512, # decoder feature dimension + dec_depth=8, # decoder depth + dec_num_heads=16, # decoder number of heads in the transformer block + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder + pos_embed='cosine', # positional embedding (either cosine or RoPE100) + ): + + super(CroCoNet, self).__init__() + + # patch embeddings (with initialization done as in MAE) + self._set_patch_embed(img_size, patch_size, enc_embed_dim) + + # mask generations + self._set_mask_generator(self.patch_embed.num_patches, mask_ratio) + + self.pos_embed = pos_embed + if pos_embed=='cosine': + # positional embedding of the encoder + enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0) + self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float()) + # positional embedding of the decoder + dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0) + self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float()) + # pos embedding in each block + self.rope = None # nothing for cosine + elif pos_embed.startswith('RoPE'): # eg RoPE100 
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE + self.dec_pos_embed = None # nothing to add in the decoder with RoPE + if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions") + freq = float(pos_embed[len('RoPE'):]) + self.rope = RoPE2D(freq=freq) + else: + raise NotImplementedError('Unknown pos_embed '+pos_embed) + + # transformer for the encoder + self.enc_depth = enc_depth + self.enc_embed_dim = enc_embed_dim + self.enc_blocks = nn.ModuleList([ + Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope) + for i in range(enc_depth)]) + self.enc_norm = norm_layer(enc_embed_dim) + + # masked tokens + self._set_mask_token(dec_embed_dim) + + # decoder + self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec) + + # prediction head + self._set_prediction_head(dec_embed_dim, patch_size) + + # initializer weights + self.initialize_weights() + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim) + + def _set_mask_generator(self, num_patches, mask_ratio): + self.mask_generator = RandomMask(num_patches, mask_ratio) + + def _set_mask_token(self, dec_embed_dim): + self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim)) + + def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec): + self.dec_depth = dec_depth + self.dec_embed_dim = dec_embed_dim + # transfer from encoder to decoder + self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True) + # transformer for the decoder + self.dec_blocks = nn.ModuleList([ + DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope) + for i in range(dec_depth)]) + # final norm layer + self.dec_norm = norm_layer(dec_embed_dim) + + def _set_prediction_head(self, dec_embed_dim, patch_size): + self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True) + + + def initialize_weights(self): + # patch embed + self.patch_embed._init_weights() + # mask tokens + if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02) + # linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _encode_image(self, image, do_mask=False, return_all_blocks=False): + """ + image has B x 3 x img_size x img_size + do_mask: whether to perform masking or not + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + """ + # embed the image into patches (x has size B x Npatches x C) + # and get position if each return patch (pos has size B x Npatches x 2) + x, pos = self.patch_embed(image) + # add positional embedding without cls token + if self.enc_pos_embed is not None: + x = x + self.enc_pos_embed[None,...] 
+ # apply masking + B,N,C = x.size() + if do_mask: + masks = self.mask_generator(x) + x = x[~masks].view(B, -1, C) + posvis = pos[~masks].view(B, -1, 2) + else: + B,N,C = x.size() + masks = torch.zeros((B,N), dtype=bool) + posvis = pos + # now apply the transformer encoder and normalization + if return_all_blocks: + out = [] + for blk in self.enc_blocks: + x = blk(x, posvis) + out.append(x) + out[-1] = self.enc_norm(out[-1]) + return out, pos, masks + else: + for blk in self.enc_blocks: + x = blk(x, posvis) + x = self.enc_norm(x) + return x, pos, masks + + def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False): + """ + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + + masks1 can be None => assume image1 fully visible + """ + # encoder to decoder layer + visf1 = self.decoder_embed(feat1) + f2 = self.decoder_embed(feat2) + # append masked tokens to the sequence + B,Nenc,C = visf1.size() + if masks1 is None: # downstreams + f1_ = visf1 + else: # pretraining + Ntotal = masks1.size(1) + f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype) + f1_[~masks1] = visf1.view(B * Nenc, C) + # add positional embedding + if self.dec_pos_embed is not None: + f1_ = f1_ + self.dec_pos_embed + f2 = f2 + self.dec_pos_embed + # apply Transformer blocks + out = f1_ + out2 = f2 + if return_all_blocks: + _out, out = out, [] + for blk in self.dec_blocks: + _out, out2 = blk(_out, out2, pos1, pos2) + out.append(_out) + out[-1] = self.dec_norm(out[-1]) + else: + for blk in self.dec_blocks: + out, out2 = blk(out, out2, pos1, pos2) + out = self.dec_norm(out) + return out + + def patchify(self, imgs): + """ + imgs: (B, 3, H, W) + x: (B, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + + return x + + def unpatchify(self, x, channels=3): + """ + x: (N, L, patch_size**2 *channels) + imgs: (N, 3, H, W) + """ + patch_size = self.patch_embed.patch_size[0] + h = w = int(x.shape[1]**.5) + assert h * w == x.shape[1] + x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size)) + return imgs + + def forward(self, img1, img2): + """ + img1: tensor of size B x 3 x img_size x img_size + img2: tensor of size B x 3 x img_size x img_size + + out will be B x N x (3*patch_size*patch_size) + masks are also returned as B x N just in case + """ + # encoder of the masked first image + feat1, pos1, mask1 = self._encode_image(img1, do_mask=True) + # encoder of the second image + feat2, pos2, _ = self._encode_image(img2, do_mask=False) + # decoder + decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2) + # prediction head + out = self.prediction_head(decfeat) + # get target + target = self.patchify(img1) + return out, mask1, target diff --git a/imcui/third_party/dust3r/croco/models/croco_downstream.py b/imcui/third_party/dust3r/croco/models/croco_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..159dfff4d2c1461bc235e21441b57ce1e2088f76 --- /dev/null +++ b/imcui/third_party/dust3r/croco/models/croco_downstream.py @@ -0,0 +1,122 @@ +# Copyright (C) 2022-present Naver 
Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# CroCo model for downstream tasks +# -------------------------------------------------------- + +import torch + +from .croco import CroCoNet + + +def croco_args_from_ckpt(ckpt): + if 'croco_kwargs' in ckpt: # CroCo v2 released models + return ckpt['croco_kwargs'] + elif 'args' in ckpt and hasattr(ckpt['args'], 'model'): # pretrained using the official code release + s = ckpt['args'].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)" + assert s.startswith('CroCoNet(') + return eval('dict'+s[len('CroCoNet'):]) # transform it into the string of a dictionary and evaluate it + else: # CroCo v1 released models + return dict() + +class CroCoDownstreamMonocularEncoder(CroCoNet): + + def __init__(self, + head, + **kwargs): + """ Build network for monocular downstream task, only using the encoder. + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + NOTE: It works by *calling super().__init__() but with redefined setters + + """ + super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """ No mask generator """ + return + + def _set_mask_token(self, *args, **kwargs): + """ No mask token """ + self.mask_token = None + return + + def _set_decoder(self, *args, **kwargs): + """ No decoder """ + return + + def _set_prediction_head(self, *args, **kwargs): + """ No 'prediction head' for downstream tasks.""" + return + + def forward(self, img): + """ + img if of size batch_size x 3 x h x w + """ + B, C, H, W = img.size() + img_info = {'height': H, 'width': W} + need_all_layers = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks + out, _, _ = self._encode_image(img, do_mask=False, return_all_blocks=need_all_layers) + return self.head(out, img_info) + + +class CroCoDownstreamBinocular(CroCoNet): + + def __init__(self, + head, + **kwargs): + """ Build network for binocular downstream task + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + """ + super(CroCoDownstreamBinocular, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """ No mask generator """ + return + + def _set_mask_token(self, *args, **kwargs): + """ No mask token """ + self.mask_token = None + return + + def _set_prediction_head(self, *args, **kwargs): + """ No prediction head for downstream tasks, define your own head """ + return + + def encode_image_pairs(self, img1, img2, return_all_blocks=False): + """ run encoder for a pair of images + it is actually ~5% faster to concatenate the images along the batch dimension + than to encode them separately + """ + ## the two commented lines below is the naive version with separate encoding + #out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks) + #out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False) + ## and now the faster version + out, pos, _ = self._encode_image( torch.cat( (img1,img2), dim=0), do_mask=False, return_all_blocks=return_all_blocks ) + if return_all_blocks: + 
out,out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out]))) + out2 = out2[-1] + else: + out,out2 = out.chunk(2, dim=0) + pos,pos2 = pos.chunk(2, dim=0) + return out, out2, pos, pos2 + + def forward(self, img1, img2): + B, C, H, W = img1.size() + img_info = {'height': H, 'width': W} + return_all_blocks = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks + out, out2, pos, pos2 = self.encode_image_pairs(img1, img2, return_all_blocks=return_all_blocks) + if return_all_blocks: + decout = self._decoder(out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks) + decout = out+decout + else: + decout = self._decoder(out, pos, None, out2, pos2, return_all_blocks=return_all_blocks) + return self.head(decout, img_info) \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/models/curope/__init__.py b/imcui/third_party/dust3r/croco/models/curope/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25e3d48a162760260826080f6366838e83e26878 --- /dev/null +++ b/imcui/third_party/dust3r/croco/models/curope/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from .curope2d import cuRoPE2D diff --git a/imcui/third_party/dust3r/croco/models/curope/curope2d.py b/imcui/third_party/dust3r/croco/models/curope/curope2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a49c12f8c529e9a889b5ac20c5767158f238e17d --- /dev/null +++ b/imcui/third_party/dust3r/croco/models/curope/curope2d.py @@ -0,0 +1,40 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch + +try: + import curope as _kernels # run `python setup.py install` +except ModuleNotFoundError: + from . import curope as _kernels # run `python setup.py build_ext --inplace` + + +class cuRoPE2D_func (torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, positions, base, F0=1): + ctx.save_for_backward(positions) + ctx.saved_base = base + ctx.saved_F0 = F0 + # tokens = tokens.clone() # uncomment this if inplace doesn't work + _kernels.rope_2d( tokens, positions, base, F0 ) + ctx.mark_dirty(tokens) + return tokens + + @staticmethod + def backward(ctx, grad_res): + positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 + _kernels.rope_2d( grad_res, positions, base, -F0 ) + ctx.mark_dirty(grad_res) + return grad_res, None, None, None + + +class cuRoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + + def forward(self, tokens, positions): + cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) + return tokens \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/models/curope/setup.py b/imcui/third_party/dust3r/croco/models/curope/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..230632ed05e309200e8f93a3a852072333975009 --- /dev/null +++ b/imcui/third_party/dust3r/croco/models/curope/setup.py @@ -0,0 +1,34 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
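
As a side note on `encode_image_pairs` above: the two views are encoded in a single pass by concatenating them along the batch dimension and splitting the outputs afterwards, which the comment reports to be roughly 5% faster than two separate passes. A minimal sketch of that pattern, with `encode_fn` standing in for `_encode_image(..., do_mask=False)` and covering only the single-output (non `return_all_blocks`) path:

```python
import torch

def encode_pair_batched(encode_fn, img1, img2):
    # One encoder pass over both views, then split the outputs back.
    out, pos, _ = encode_fn(torch.cat((img1, img2), dim=0))
    out1, out2 = out.chunk(2, dim=0)
    pos1, pos2 = pos.chunk(2, dim=0)
    return out1, out2, pos1, pos2
```
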
+ +from setuptools import setup +from torch import cuda +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +# compile for all possible CUDA architectures +all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() +# alternatively, you can list cuda archs that you want, eg: +# all_cuda_archs = [ + # '-gencode', 'arch=compute_70,code=sm_70', + # '-gencode', 'arch=compute_75,code=sm_75', + # '-gencode', 'arch=compute_80,code=sm_80', + # '-gencode', 'arch=compute_86,code=sm_86' +# ] + +setup( + name = 'curope', + ext_modules = [ + CUDAExtension( + name='curope', + sources=[ + "curope.cpp", + "kernels.cu", + ], + extra_compile_args = dict( + nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, + cxx=['-O3']) + ) + ], + cmdclass = { + 'build_ext': BuildExtension + }) diff --git a/imcui/third_party/dust3r/croco/models/dpt_block.py b/imcui/third_party/dust3r/croco/models/dpt_block.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ddfb74e2769ceca88720d4c730e00afd71c763 --- /dev/null +++ b/imcui/third_party/dust3r/croco/models/dpt_block.py @@ -0,0 +1,450 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# DPT head for ViTs +# -------------------------------------------------------- +# References: +# https://github.com/isl-org/DPT +# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from typing import Union, Tuple, Iterable, List, Optional, Dict + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +def make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + scratch.layer_rn = nn.ModuleList([ + scratch.layer1_rn, + scratch.layer2_rn, + scratch.layer3_rn, + scratch.layer4_rn, + ]) + + return scratch + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. 
+ Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + width_ratio=1, + ): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + self.width_ratio = width_ratio + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + if self.width_ratio != 1: + res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear') + + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if self.width_ratio != 1: + # and output.shape[3] < self.width_ratio * output.shape[2] + #size=(image.shape[]) + if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio: + shape = 3 * output.shape[3] + else: + shape = int(self.width_ratio * 2 * output.shape[2]) + output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear') + else: + output = nn.functional.interpolate(output, scale_factor=2, + mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + return output + +def make_fusion_block(features, use_bn, width_ratio=1): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + width_ratio=width_ratio, + ) + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. 
+ Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + +class DPTOutputAdapter(nn.Module): + """DPT output adapter. + + :param num_cahnnels: Number of output channels + :param stride_level: tride level compared to the full-sized image. + E.g. 4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. + :param hooks: Index of intermediate layers + :param layer_dims: Dimension of intermediate layers + :param feature_dim: Feature dimension + :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression + :param use_bn: If set to True, activates batch norm + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + + def __init__(self, + num_channels: int = 1, + stride_level: int = 1, + patch_size: Union[int, Tuple[int, int]] = 16, + main_tasks: Iterable[str] = ('rgb',), + hooks: List[int] = [2, 5, 8, 11], + layer_dims: List[int] = [96, 192, 384, 768], + feature_dim: int = 256, + last_dim: int = 32, + use_bn: bool = False, + dim_tokens_enc: Optional[int] = None, + head_type: str = 'regression', + output_width_ratio=1, + **kwargs): + super().__init__() + self.num_channels = num_channels + self.stride_level = stride_level + self.patch_size = pair(patch_size) + self.main_tasks = main_tasks + self.hooks = hooks + self.layer_dims = layer_dims + self.feature_dim = feature_dim + self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None + self.head_type = head_type + + # Actual patch height and width, taking into account stride of input + self.P_H = max(1, self.patch_size[0] // stride_level) + self.P_W = max(1, self.patch_size[1] // stride_level) + + self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False) + + self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + + if self.head_type == 'regression': + # The "DPTDepthModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0) + ) + elif self.head_type == 'semseg': + # The "DPTSegmentationModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(feature_dim, self.num_channels, kernel_size=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + ) + else: + raise ValueError('DPT head_type must be "regression" or "semseg".') + + if self.dim_tokens_enc is not None: + self.init(dim_tokens_enc=dim_tokens_enc) + + def init(self, dim_tokens_enc=768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. 
+ + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + #print(dim_tokens_enc) + + # Set up activation postprocessing layers + if isinstance(dim_tokens_enc, int): + dim_tokens_enc = 4 * [dim_tokens_enc] + + self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc] + + self.act_1_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[0], + out_channels=self.layer_dims[0], + kernel_size=1, stride=1, padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[0], + out_channels=self.layer_dims[0], + kernel_size=4, stride=4, padding=0, + bias=True, dilation=1, groups=1, + ) + ) + + self.act_2_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[1], + out_channels=self.layer_dims[1], + kernel_size=1, stride=1, padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[1], + out_channels=self.layer_dims[1], + kernel_size=2, stride=2, padding=0, + bias=True, dilation=1, groups=1, + ) + ) + + self.act_3_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[2], + out_channels=self.layer_dims[2], + kernel_size=1, stride=1, padding=0, + ) + ) + + self.act_4_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[3], + out_channels=self.layer_dims[3], + kernel_size=1, stride=1, padding=0, + ), + nn.Conv2d( + in_channels=self.layer_dims[3], + out_channels=self.layer_dims[3], + kernel_size=3, stride=2, padding=1, + ) + ) + + self.act_postprocess = nn.ModuleList([ + self.act_1_postprocess, + self.act_2_postprocess, + self.act_3_postprocess, + self.act_4_postprocess + ]) + + def adapt_tokens(self, encoder_tokens): + # Adapt tokens + x = [] + x.append(encoder_tokens[:, :]) + x = torch.cat(x, dim=-1) + return x + + def forward(self, encoder_tokens: List[torch.Tensor], image_size): + #input_info: Dict): + assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' + H, W = image_size + + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. + layers = [self.adapt_tokens(l) for l in layers] + + # Reshape tokens to spatial representation + layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3]) + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Output head + out = self.head(path_1) + + return out diff --git a/imcui/third_party/dust3r/croco/models/head_downstream.py b/imcui/third_party/dust3r/croco/models/head_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..bd40c91ba244d6c3522c6efd4ed4d724b7bdc650 --- /dev/null +++ b/imcui/third_party/dust3r/croco/models/head_downstream.py @@ -0,0 +1,58 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
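
The DPT adapter above hooks four intermediate token sequences and reshapes each flat sequence back into a spatial grid before the convolutional refinement stages. A minimal sketch of that reshaping step, mirroring the `rearrange` call in `DPTOutputAdapter.forward` (function name and argument layout are illustrative):

```python
import torch
from einops import rearrange

def tokens_to_feature_map(tokens, image_hw, patch_size=16, stride_level=1):
    # tokens: B x (N_H * N_W) x C, in row-major patch order.
    H, W = image_hw
    n_h = H // (stride_level * patch_size)
    n_w = W // (stride_level * patch_size)
    return rearrange(tokens, 'b (nh nw) c -> b c nh nw', nh=n_h, nw=n_w)
```
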
+ +# -------------------------------------------------------- +# Heads for downstream tasks +# -------------------------------------------------------- + +""" +A head is a module where the __init__ defines only the head hyperparameters. +A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes. +The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height' +""" + +import torch +import torch.nn as nn +from .dpt_block import DPTOutputAdapter + + +class PixelwiseTaskWithDPT(nn.Module): + """ DPT module for CroCo. + by default, hooks_idx will be equal to: + * for encoder-only: 4 equally spread layers + * for encoder+decoder: last encoder + 3 equally spread layers of the decoder + """ + + def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768], + output_width_ratio=1, num_channels=1, postprocess=None, **kwargs): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_blocks = True # backbone needs to return all layers + self.postprocess = postprocess + self.output_width_ratio = output_width_ratio + self.num_channels = num_channels + self.hooks_idx = hooks_idx + self.layer_dims = layer_dims + + def setup(self, croconet): + dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels} + if self.hooks_idx is None: + if hasattr(croconet, 'dec_blocks'): # encoder + decoder + step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth] + hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)] + else: # encoder only + step = croconet.enc_depth//4 + hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)] + self.hooks_idx = hooks_idx + print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}') + dpt_args['hooks'] = self.hooks_idx + dpt_args['layer_dims'] = self.layer_dims + self.dpt = DPTOutputAdapter(**dpt_args) + dim_tokens = [croconet.enc_embed_dim if hook0: + pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +#---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +#---------------------------------------------------------- + +try: + from models.curope import cuRoPE2D + RoPE2D = cuRoPE2D +except ImportError: + print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') + + class RoPE2D(torch.nn.Module): + + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D,seq_len,device,dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D,seq_len,device,dtype] = (cos,sin) + return self.cache[D,seq_len,device,dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim==2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert 
tokens.size(3)%2==0, "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:,:,0], cos, sin) + x = self.apply_rope1d(x, positions[:,:,1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/pretrain.py b/imcui/third_party/dust3r/croco/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..2c45e488015ef5380c71d0381ff453fdb860759e --- /dev/null +++ b/imcui/third_party/dust3r/croco/pretrain.py @@ -0,0 +1,254 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Pre-training CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import argparse +import datetime +import json +import numpy as np +import os +import sys +import time +import math +from pathlib import Path +from typing import Iterable + +import torch +import torch.distributed as dist +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +import utils.misc as misc +from utils.misc import NativeScalerWithGradNormCount as NativeScaler +from models.croco import CroCoNet +from models.criterion import MaskedMSE +from datasets.pairs_dataset import PairsDataset + + +def get_args_parser(): + parser = argparse.ArgumentParser('CroCo pre-training', add_help=False) + # model and criterion + parser.add_argument('--model', default='CroCoNet()', type=str, help="string containing the model to build") + parser.add_argument('--norm_pix_loss', default=1, choices=[0,1], help="apply per-patch mean/std normalization before applying the loss") + # dataset + parser.add_argument('--dataset', default='habitat_release', type=str, help="training set") + parser.add_argument('--transforms', default='crop224+acolor', type=str, help="transforms to apply") # in the paper, we also use some homography and rotation, but find later that they were not useful or even harmful + # training + parser.add_argument('--seed', default=0, type=int, help="Random seed") + parser.add_argument('--batch_size', default=64, type=int, help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus") + parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler") + parser.add_argument('--max_epoch', default=400, type=int, help="Stop training at this epoch") + parser.add_argument('--accum_iter', default=1, type=int, help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)") + parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)") + parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') + parser.add_argument('--blr', type=float, default=1.5e-4, 
metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') + parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR') + parser.add_argument('--amp', type=int, default=1, choices=[0,1], help="Use Automatic Mixed Precision for pretraining") + # others + parser.add_argument('--num_workers', default=8, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--save_freq', default=1, type=int, help='frequence (number of epochs) to save checkpoint in checkpoint-last.pth') + parser.add_argument('--keep_freq', default=20, type=int, help='frequence (number of epochs) to save checkpoint in checkpoint-%d.pth') + parser.add_argument('--print_freq', default=20, type=int, help='frequence (number of iterations) to print infos while training') + # paths + parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output") + parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored") + return parser + + + + +def main(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + world_size = misc.get_world_size() + + print("output_dir: "+args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + # auto resume + last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth') + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # fix the seed + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + ## training dataset and loader + print('Building dataset for {:s} with transforms {:s}'.format(args.dataset, args.transforms)) + dataset = PairsDataset(args.dataset, trfs=args.transforms, data_dir=args.data_dir) + if world_size>1: + sampler_train = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + else: + sampler_train = torch.utils.data.RandomSampler(dataset) + data_loader_train = torch.utils.data.DataLoader( + dataset, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + + ## model + print('Loading model: {:s}'.format(args.model)) + model = eval(args.model) + print('Loading criterion: MaskedMSE(norm_pix_loss={:s})'.format(str(bool(args.norm_pix_loss)))) + criterion = MaskedMSE(norm_pix_loss=bool(args.norm_pix_loss)) + + model.to(device) + model_without_ddp = model + print("Model = %s" % str(model_without_ddp)) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch 
size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True) + model_without_ddp = model.module + + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) # following timm: set wd as 0 for bias and norm layers + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) + print(optimizer) + loss_scaler = NativeScaler() + + misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training until {args.max_epoch} epochs") + start_time = time.time() + for epoch in range(args.start_epoch, args.max_epoch): + if world_size>1: + data_loader_train.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + log_writer=log_writer, + args=args + ) + + if args.output_dir and epoch % args.save_freq == 0 : + misc.save_model( + args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, fname='last') + + if args.output_dir and (epoch % args.keep_freq == 0 or epoch + 1 == args.max_epoch) and (epoch>0 or args.max_epoch==1): + misc.save_model( + args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch,} + + if args.output_dir and misc.is_main_process(): + if log_writer is not None: + log_writer.flush() + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + log_writer=None, + args=None): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + accum_iter = args.accum_iter + + optimizer.zero_grad() + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + for data_iter_step, (image1, image2) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + + # we use a per iteration lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + with torch.cuda.amp.autocast(enabled=bool(args.amp)): + out, mask, target = model(image1, image2) + loss = criterion(out, mask, target) + + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + + lr = 
optimizer.param_groups[0]["lr"] + metric_logger.update(lr=lr) + + loss_value_reduce = misc.all_reduce_mean(loss_value) + if log_writer is not None and ((data_iter_step + 1) % (accum_iter*args.print_freq)) == 0: + # x-axis is based on epoch_1000x in the tensorboard, calibrating differences curves when batch size changes + epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) + log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('lr', lr, epoch_1000x) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) diff --git a/imcui/third_party/dust3r/croco/stereoflow/augmentor.py b/imcui/third_party/dust3r/croco/stereoflow/augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..69e6117151988d94cbc4b385e0d88e982133bf10 --- /dev/null +++ b/imcui/third_party/dust3r/croco/stereoflow/augmentor.py @@ -0,0 +1,290 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Data augmentation for training stereo and flow +# -------------------------------------------------------- + +# References +# https://github.com/autonomousvision/unimatch/blob/master/dataloader/stereo/transforms.py +# https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/transforms.py + + +import numpy as np +import random +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torchvision.transforms.functional as FF + +class StereoAugmentor(object): + + def __init__(self, crop_size, scale_prob=0.5, scale_xonly=True, lhth=800., lminscale=0.0, lmaxscale=1.0, hminscale=-0.2, hmaxscale=0.4, scale_interp_nearest=True, rightjitterprob=0.5, v_flip_prob=0.5, color_aug_asym=True, color_choice_prob=0.5): + self.crop_size = crop_size + self.scale_prob = scale_prob + self.scale_xonly = scale_xonly + self.lhth = lhth + self.lminscale = lminscale + self.lmaxscale = lmaxscale + self.hminscale = hminscale + self.hmaxscale = hmaxscale + self.scale_interp_nearest = scale_interp_nearest + self.rightjitterprob = rightjitterprob + self.v_flip_prob = v_flip_prob + self.color_aug_asym = color_aug_asym + self.color_choice_prob = color_choice_prob + + def _random_scale(self, img1, img2, disp): + ch,cw = self.crop_size + h,w = img1.shape[:2] + if self.scale_prob>0. 
and np.random.rand()1.: + scale_x = clip_scale + scale_y = scale_x if not self.scale_xonly else 1.0 + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + disp = cv2.resize(disp, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR if not self.scale_interp_nearest else cv2.INTER_NEAREST) * scale_x + return img1, img2, disp + + def _random_crop(self, img1, img2, disp): + h,w = img1.shape[:2] + ch,cw = self.crop_size + assert ch<=h and cw<=w, (img1.shape, h,w,ch,cw) + offset_x = np.random.randint(w - cw + 1) + offset_y = np.random.randint(h - ch + 1) + img1 = img1[offset_y:offset_y+ch,offset_x:offset_x+cw] + img2 = img2[offset_y:offset_y+ch,offset_x:offset_x+cw] + disp = disp[offset_y:offset_y+ch,offset_x:offset_x+cw] + return img1, img2, disp + + def _random_vflip(self, img1, img2, disp): + # vertical flip + if self.v_flip_prob>0 and np.random.rand() < self.v_flip_prob: + img1 = np.copy(np.flipud(img1)) + img2 = np.copy(np.flipud(img2)) + disp = np.copy(np.flipud(disp)) + return img1, img2, disp + + def _random_rotate_shift_right(self, img2): + if self.rightjitterprob>0. and np.random.rand() 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow = np.inf * np.ones([ht1, wd1, 2], dtype=np.float32) # invalid value every where, before we fill it with the correct ones + flow[yy, xx] = flow1 + return flow + + def spatial_transform(self, img1, img2, flow, dname): + + if np.random.rand() < self.spatial_aug_prob: + # randomly sample scale + ht, wd = img1.shape[:2] + clip_min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + min_scale, max_scale = self.min_scale, self.max_scale + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_x = np.clip(scale_x, clip_min_scale, None) + scale_y = np.clip(scale_y, clip_min_scale, None) + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = self._resize_flow(flow, scale_x, scale_y, factor=2.0 if dname=='Spring' else 1.0) + elif dname=="Spring": + flow = self._resize_flow(flow, 1.0, 1.0, factor=2.0) + + if self.h_flip_prob>0. and np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if self.v_flip_prob>0. 
and np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + # In case no cropping + if img1.shape[0] - self.crop_size[0] > 0: + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + else: + y0 = 0 + if img1.shape[1] - self.crop_size[1] > 0: + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + else: + x0 = 0 + + img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow, dname): + img1, img2, flow = self.spatial_transform(img1, img2, flow, dname) + img1, img2 = self.color_transform(img1, img2) + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + return img1, img2, flow \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/stereoflow/criterion.py b/imcui/third_party/dust3r/croco/stereoflow/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..57792ebeeee34827b317a4d32b7445837bb33f17 --- /dev/null +++ b/imcui/third_party/dust3r/croco/stereoflow/criterion.py @@ -0,0 +1,251 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Losses, metrics per batch, metrics per dataset +# -------------------------------------------------------- + +import torch +from torch import nn +import torch.nn.functional as F + +def _get_gtnorm(gt): + if gt.size(1)==1: # stereo + return gt + # flow + return torch.sqrt(torch.sum(gt**2, dim=1, keepdims=True)) # Bx1xHxW + +############ losses without confidence + +class L1Loss(nn.Module): + + def __init__(self, max_gtnorm=None): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = False + + def _error(self, gt, predictions): + return torch.abs(gt-predictions) + + def forward(self, predictions, gt, inspect=False): + mask = torch.isfinite(gt) + if self.max_gtnorm is not None: + mask *= _get_gtnorm(gt).expand(-1,gt.size(1),-1,-1) which is a constant + + +class LaplacianLossBounded(nn.Module): # used for CroCo-Flow ; in the equation of the paper, we have a=1/b + def __init__(self, max_gtnorm=10000., a=0.25, b=4.): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = True + self.a, self.b = a, b + + def forward(self, predictions, gt, conf): + mask = torch.isfinite(gt) + mask = mask[:,0,:,:] + if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:] which is a constant + +class LaplacianLossBounded2(nn.Module): # used for CroCo-Stereo (except for ETH3D) ; in the equation of the paper, we have a=b + def __init__(self, max_gtnorm=None, a=3.0, b=3.0): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = True + self.a, self.b = a, b + + def forward(self, predictions, gt, conf): + mask = torch.isfinite(gt) + mask = mask[:,0,:,:] + if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:] which is a constant + +############## metrics per batch + +class StereoMetrics(nn.Module): + + def __init__(self, do_quantile=False): + super().__init__() + self.bad_ths = [0.5,1,2,3] + self.do_quantile = do_quantile + + def forward(self, predictions, gt): + B = predictions.size(0) + metrics = {} + gtcopy = gt.clone() + mask = torch.isfinite(gtcopy) + gtcopy[~mask] = 999999.0 # we make 
a copy and put a non-infinite value, such that it does not become nan once multiplied by the mask value 0 + Npx = mask.view(B,-1).sum(dim=1) + L1error = (torch.abs(gtcopy-predictions)*mask).view(B,-1) + L2error = (torch.square(gtcopy-predictions)*mask).view(B,-1) + # avgerr + metrics['avgerr'] = torch.mean(L1error.sum(dim=1)/Npx ) + # rmse + metrics['rmse'] = torch.sqrt(L2error.sum(dim=1)/Npx).mean(dim=0) + # err > t for t in [0.5,1,2,3] + for ths in self.bad_ths: + metrics['bad@{:.1f}'.format(ths)] = (((L1error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100 + return metrics + +class FlowMetrics(nn.Module): + def __init__(self): + super().__init__() + self.bad_ths = [1,3,5] + + def forward(self, predictions, gt): + B = predictions.size(0) + metrics = {} + mask = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite + Npx = mask.view(B,-1).sum(dim=1) + gtcopy = gt.clone() # to compute L1/L2 error, we need to have non-infinite value, the error computed at this locations will be ignored + gtcopy[:,0,:,:][~mask] = 999999.0 + gtcopy[:,1,:,:][~mask] = 999999.0 + L1error = (torch.abs(gtcopy-predictions).sum(dim=1)*mask).view(B,-1) + L2error = (torch.sqrt(torch.sum(torch.square(gtcopy-predictions),dim=1))*mask).view(B,-1) + metrics['L1err'] = torch.mean(L1error.sum(dim=1)/Npx ) + metrics['EPE'] = torch.mean(L2error.sum(dim=1)/Npx ) + for ths in self.bad_ths: + metrics['bad@{:.1f}'.format(ths)] = (((L2error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100 + return metrics + +############## metrics per dataset +## we update the average and maintain the number of pixels while adding data batch per batch +## at the beggining, call reset() +## after each batch, call add_batch(...) +## at the end: call get_results() + +class StereoDatasetMetrics(nn.Module): + + def __init__(self): + super().__init__() + self.bad_ths = [0.5,1,2,3] + + def reset(self): + self.agg_N = 0 # number of pixels so far + self.agg_L1err = torch.tensor(0.0) # L1 error so far + self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels + self._metrics = None + + def add_batch(self, predictions, gt): + assert predictions.size(1)==1, predictions.size() + assert gt.size(1)==1, gt.size() + if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ... 
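        # Spring ground truth is stored at twice the prediction resolution, so the
        # error at each predicted pixel is taken as the minimum over the four
        # co-located ground-truth samples (the 0::2 / 1::2 grids below); invalid
        # ground truth is +inf and is excluded through torch.isfinite.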
+ L1err = torch.minimum( torch.minimum( torch.minimum( + torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1), + torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1)) + valid = torch.isfinite(L1err) + else: + valid = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite + L1err = torch.sum(torch.abs(gt-predictions),dim=1) + N = valid.sum() + Nnew = self.agg_N + N + self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew + self.agg_N = Nnew + for i,th in enumerate(self.bad_ths): + self.agg_Nbad[i] += (L1err[valid]>th).sum().cpu() + + def _compute_metrics(self): + if self._metrics is not None: return + out = {} + out['L1err'] = self.agg_L1err.item() + for i,th in enumerate(self.bad_ths): + out['bad@{:.1f}'.format(th)] = (float(self.agg_Nbad[i]) / self.agg_N).item() * 100.0 + self._metrics = out + + def get_results(self): + self._compute_metrics() # to avoid recompute them multiple times + return self._metrics + +class FlowDatasetMetrics(nn.Module): + + def __init__(self): + super().__init__() + self.bad_ths = [0.5,1,3,5] + self.speed_ths = [(0,10),(10,40),(40,torch.inf)] + + def reset(self): + self.agg_N = 0 # number of pixels so far + self.agg_L1err = torch.tensor(0.0) # L1 error so far + self.agg_L2err = torch.tensor(0.0) # L2 (=EPE) error so far + self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels + self.agg_EPEspeed = [torch.tensor(0.0) for _ in self.speed_ths] # EPE per speed bin so far + self.agg_Nspeed = [0 for _ in self.speed_ths] # N pixels per speed bin so far + self._metrics = None + self.pairname_results = {} + + def add_batch(self, predictions, gt): + assert predictions.size(1)==2, predictions.size() + assert gt.size(1)==2, gt.size() + if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ... 
+ L1err = torch.minimum( torch.minimum( torch.minimum( + torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1), + torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1)) + L2err = torch.minimum( torch.minimum( torch.minimum( + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]-predictions),dim=1)), + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]-predictions),dim=1))), + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]-predictions),dim=1))), + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]-predictions),dim=1))) + valid = torch.isfinite(L1err) + gtspeed = (torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]),dim=1)) +\ + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]),dim=1)) ) / 4.0 # let's just average them + else: + valid = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite + L1err = torch.sum(torch.abs(gt-predictions),dim=1) + L2err = torch.sqrt(torch.sum(torch.square(gt-predictions),dim=1)) + gtspeed = torch.sqrt(torch.sum(torch.square(gt),dim=1)) + N = valid.sum() + Nnew = self.agg_N + N + self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew + self.agg_L2err = float(self.agg_N)/Nnew * self.agg_L2err + L2err[valid].mean().cpu() * float(N)/Nnew + self.agg_N = Nnew + for i,th in enumerate(self.bad_ths): + self.agg_Nbad[i] += (L2err[valid]>th).sum().cpu() + for i,(th1,th2) in enumerate(self.speed_ths): + vv = (gtspeed[valid]>=th1) * (gtspeed[valid] don't use batch_size>1 at test time) + self._prepare_data() + self._load_or_build_cache() + + def prepare_data(self): + """ + to be defined for each dataset + """ + raise NotImplementedError + + def __len__(self): + return len(self.pairnames) # each pairname is typically of the form (str, int1, int2) + + def __getitem__(self, index): + pairname = self.pairnames[index] + + # get filenames + img1name = self.pairname_to_img1name(pairname) + img2name = self.pairname_to_img2name(pairname) + flowname = self.pairname_to_flowname(pairname) if self.pairname_to_flowname is not None else None + + # load images and disparities + img1 = _read_img(img1name) + img2 = _read_img(img2name) + flow = self.load_flow(flowname) if flowname is not None else None + + # apply augmentations + if self.augmentor is not None: + img1, img2, flow = self.augmentor(img1, img2, flow, self.name) + + if self.totensor: + img1 = img_to_tensor(img1) + img2 = img_to_tensor(img2) + if flow is not None: + flow = flow_to_tensor(flow) + else: + flow = torch.tensor([]) # to allow dataloader batching with default collate_gn + pairname = str(pairname) # transform potential tuple to str to be able to batch it + + return img1, img2, flow, pairname + + def __rmul__(self, v): + self.rmul *= v + self.pairnames = v * self.pairnames + return self + + def __str__(self): + return f'{self.__class__.__name__}_{self.split}' + + def __repr__(self): + s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})' + if self.rmul==1: + s+=f'\n\tnum pairs: {len(self.pairnames)}' + else: + s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})' + return s + + def _set_root(self): + self.root = dataset_to_root[self.name] + assert os.path.isdir(self.root), f"could not 
find root directory for dataset {self.name}: {self.root}" + + def _load_or_build_cache(self): + cache_file = osp.join(cache_dir, self.name+'.pkl') + if osp.isfile(cache_file): + with open(cache_file, 'rb') as fid: + self.pairnames = pickle.load(fid)[self.split] + else: + tosave = self._build_cache() + os.makedirs(cache_dir, exist_ok=True) + with open(cache_file, 'wb') as fid: + pickle.dump(tosave, fid) + self.pairnames = tosave[self.split] + +class TartanAirDataset(FlowDataset): + + def _prepare_data(self): + self.name = "TartanAir" + self._set_root() + assert self.split in ['train'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[1])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[2])) + self.pairname_to_flowname = lambda pairname: osp.join(self.root, pairname[0], 'flow/{:06d}_{:06d}_flow.npy'.format(pairname[1],pairname[2])) + self.pairname_to_str = lambda pairname: os.path.join(pairname[0][pairname[0].find('/')+1:], '{:06d}_{:06d}'.format(pairname[1], pairname[2])) + self.load_flow = _read_numpy_flow + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + pairs = [(osp.join(s,s,difficulty,Pxxx),int(a[:6]),int(a[:6])+1) for s in seqs for difficulty in ['Easy','Hard'] for Pxxx in sorted(os.listdir(osp.join(self.root,s,s,difficulty))) for a in sorted(os.listdir(osp.join(self.root,s,s,difficulty,Pxxx,'image_left/')))[:-1]] + assert len(pairs)==306268, "incorrect parsing of pairs in TartanAir" + tosave = {'train': pairs} + return tosave + +class FlyingChairsDataset(FlowDataset): + + def _prepare_data(self): + self.name = "FlyingChairs" + self._set_root() + assert self.split in ['train','val'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, 'data', pairname+'_img1.ppm') + self.pairname_to_img2name = lambda pairname: osp.join(self.root, 'data', pairname+'_img2.ppm') + self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'data', pairname+'_flow.flo') + self.pairname_to_str = lambda pairname: pairname + self.load_flow = _read_flo_file + + def _build_cache(self): + split_file = osp.join(self.root, 'chairs_split.txt') + split_list = np.loadtxt(split_file, dtype=np.int32) + trainpairs = ['{:05d}'.format(i) for i in np.where(split_list==1)[0]+1] + valpairs = ['{:05d}'.format(i) for i in np.where(split_list==2)[0]+1] + assert len(trainpairs)==22232 and len(valpairs)==640, "incorrect parsing of pairs in MPI-Sintel" + tosave = {'train': trainpairs, 'val': valpairs} + return tosave + +class FlyingThingsDataset(FlowDataset): + + def _prepare_data(self): + self.name = "FlyingThings" + self._set_root() + assert self.split in [f'{set_}_{pass_}pass{camstr}' for set_ in ['train','test','test1024'] for camstr in ['','_rightcam'] for pass_ in ['clean','final','all']] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[1])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[2])) + self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'optical_flow', pairname[0], 'OpticalFlowInto{f:s}_{i:04d}_{c:s}.pfm'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' )) + self.pairname_to_str 
= lambda pairname: os.path.join(pairname[3]+'pass', pairname[0], 'Into{f:s}_{i:04d}_{c:s}'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' )) + self.load_flow = _read_pfm_flow + + def _build_cache(self): + tosave = {} + # train and test splits for the different passes + for set_ in ['train', 'test']: + sroot = osp.join(self.root, 'optical_flow', set_.upper()) + fname_to_i = lambda f: int(f[len('OpticalFlowIntoFuture_'):-len('_L.pfm')]) + pp = [(osp.join(set_.upper(), d, s, 'into_future/left'),fname_to_i(fname)) for d in sorted(os.listdir(sroot)) for s in sorted(os.listdir(osp.join(sroot,d))) for fname in sorted(os.listdir(osp.join(sroot,d, s, 'into_future/left')))[:-1]] + pairs = [(a,i,i+1) for a,i in pp] + pairs += [(a.replace('into_future','into_past'),i+1,i) for a,i in pp] + assert len(pairs)=={'train': 40302, 'test': 7866}[set_], "incorrect parsing of pairs Flying Things" + for cam in ['left','right']: + camstr = '' if cam=='left' else f'_{cam}cam' + for pass_ in ['final', 'clean']: + tosave[f'{set_}_{pass_}pass{camstr}'] = [(a.replace('left',cam),i,j,pass_) for a,i,j in pairs] + tosave[f'{set_}_allpass{camstr}'] = tosave[f'{set_}_cleanpass{camstr}'] + tosave[f'{set_}_finalpass{camstr}'] + # test1024: this is the same split as unimatch 'validation' split + # see https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/datasets.py#L229 + test1024_nsamples = 1024 + alltest_nsamples = len(tosave['test_cleanpass']) # 7866 + stride = alltest_nsamples // test1024_nsamples + remove = alltest_nsamples % test1024_nsamples + for cam in ['left','right']: + camstr = '' if cam=='left' else f'_{cam}cam' + for pass_ in ['final','clean']: + tosave[f'test1024_{pass_}pass{camstr}'] = sorted(tosave[f'test_{pass_}pass{camstr}'])[:-remove][::stride] # warning, it was not sorted before + assert len(tosave['test1024_cleanpass'])==1024, "incorrect parsing of pairs in Flying Things" + tosave[f'test1024_allpass{camstr}'] = tosave[f'test1024_cleanpass{camstr}'] + tosave[f'test1024_finalpass{camstr}'] + return tosave + + +class MPISintelDataset(FlowDataset): + + def _prepare_data(self): + self.name = "MPISintel" + self._set_root() + assert self.split in [s+'_'+p for s in ['train','test','subval','subtrain'] for p in ['cleanpass','finalpass','allpass']] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1]+1)) + self.pairname_to_flowname = lambda pairname: None if pairname[0].startswith('test/') else osp.join(self.root, pairname[0].replace('/clean/','/flow/').replace('/final/','/flow/'), 'frame_{:04d}.flo'.format(pairname[1])) + self.pairname_to_str = lambda pairname: osp.join(pairname[0], 'frame_{:04d}'.format(pairname[1])) + self.load_flow = _read_flo_file + + def _build_cache(self): + trainseqs = sorted(os.listdir(self.root+'training/clean')) + trainpairs = [ (osp.join('training/clean', s),i) for s in trainseqs for i in range(1, len(os.listdir(self.root+'training/clean/'+s)))] + subvalseqs = ['temple_2','temple_3'] + subtrainseqs = [s for s in trainseqs if s not in subvalseqs] + subvalpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subvalseqs)] + subtrainpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subtrainseqs)] + testseqs = sorted(os.listdir(self.root+'test/clean')) + testpairs = [ (osp.join('test/clean', s),i) for s in 
testseqs for i in range(1, len(os.listdir(self.root+'test/clean/'+s)))] + assert len(trainpairs)==1041 and len(testpairs)==552 and len(subvalpairs)==98 and len(subtrainpairs)==943, "incorrect parsing of pairs in MPI-Sintel" + tosave = {} + tosave['train_cleanpass'] = trainpairs + tosave['test_cleanpass'] = testpairs + tosave['subval_cleanpass'] = subvalpairs + tosave['subtrain_cleanpass'] = subtrainpairs + for t in ['train','test','subval','subtrain']: + tosave[t+'_finalpass'] = [(p.replace('/clean/','/final/'),i) for p,i in tosave[t+'_cleanpass']] + tosave[t+'_allpass'] = tosave[t+'_cleanpass'] + tosave[t+'_finalpass'] + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, _time): + assert prediction.shape[2]==2 + outfile = os.path.join(outdir, 'submission', self.pairname_to_str(pairname)+'.flo') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlowFile(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split == 'test_allpass' + bundle_exe = "/nfs/data/ffs-3d/datasets/StereoFlow/MPI-Sintel/bundler/linux-x64/bundler" # eg + if os.path.isfile(bundle_exe): + cmd = f'{bundle_exe} "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at: "{outdir}/submission/bundled.lzma"') + else: + print('Could not find bundler executable for submission.') + print('Please download it and run:') + print(f' "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"') + +class SpringDataset(FlowDataset): + + def _prepare_data(self): + self.name = "Spring" + self._set_root() + assert self.split in ['train','test','subtrain','subval'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], pairname[1], 'frame_'+pairname[3], 'frame_{:s}_{:04d}.png'.format(pairname[3], pairname[4])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], pairname[1], 'frame_'+pairname[3], 'frame_{:s}_{:04d}.png'.format(pairname[3], pairname[4]+(1 if pairname[2]=='FW' else -1))) + self.pairname_to_flowname = lambda pairname: None if pairname[0]=='test' else osp.join(self.root, pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5') + self.pairname_to_str = lambda pairname: osp.join(pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}') + self.load_flow = _read_hdf5_flow + + def _build_cache(self): + # train + trainseqs = sorted(os.listdir( osp.join(self.root,'train'))) + trainpairs = [] + for leftright in ['left','right']: + for fwbw in ['FW','BW']: + trainpairs += [('train',s,fwbw,leftright,int(f[len(f'flow_{fwbw}_{leftright}_'):-len('.flo5')])) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,f'flow_{fwbw}_{leftright}')))] + # test + testseqs = sorted(os.listdir( osp.join(self.root,'test'))) + testpairs = [] + for leftright in ['left','right']: + testpairs += [('test',s,'FW',leftright,int(f[len(f'frame_{leftright}_'):-len('.png')])) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,f'frame_{leftright}')))[:-1]] + testpairs += [('test',s,'BW',leftright,int(f[len(f'frame_{leftright}_'):-len('.png')])+1) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,f'frame_{leftright}')))[:-1]] + # subtrain / subval + subtrainpairs = [p for p in trainpairs if p[1]!='0041'] + 
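        # scene '0041' is held out as the subval split; the stereo Spring dataset
        # further below holds out the same scene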
subvalpairs = [p for p in trainpairs if p[1]=='0041'] + assert len(trainpairs)==19852 and len(testpairs)==3960 and len(subtrainpairs)==19472 and len(subvalpairs)==380, "incorrect parsing of pairs in Spring" + tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==3 + assert prediction.shape[2]==2 + assert prediction.dtype==np.float32 + outfile = osp.join(outdir, pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlo5File(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + exe = "{self.root}/flow_subsampling" + if os.path.isfile(exe): + cmd = f'cd "{outdir}/test"; {exe} .' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/test/flow_submission.hdf5') + else: + print('Could not find flow_subsampling executable for submission.') + print('Please download it and run:') + print(f'cd "{outdir}/test"; .') + + +class Kitti12Dataset(FlowDataset): + + def _prepare_data(self): + self.name = "Kitti12" + self._set_root() + assert self.split in ['train','test'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname+'_11.png') + self.pairname_to_flowname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/flow_occ/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/') + self.load_flow = _read_kitti_flow + + def _build_cache(self): + trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)] + testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)] + assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12" + tosave = {'train': trainseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==3 + assert prediction.shape[2]==2 + outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlowKitti(outfile, prediction) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti12_flow_results.zip" .' + print(cmd) + os.system(cmd) + print(f'Done. 
Submission file at {outdir}/kitti12_flow_results.zip') + + +class Kitti15Dataset(FlowDataset): + + def _prepare_data(self): + self.name = "Kitti15" + self._set_root() + assert self.split in ['train','subtrain','subval','test'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname+'_11.png') + self.pairname_to_flowname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/flow_occ/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/') + self.load_flow = _read_kitti_flow + + def _build_cache(self): + trainseqs = ["training/image_2/%06d"%(i) for i in range(200)] + subtrainseqs = trainseqs[:-10] + subvalseqs = trainseqs[-10:] + testseqs = ["testing/image_2/%06d"%(i) for i in range(200)] + assert len(trainseqs)==200 and len(subtrainseqs)==190 and len(subvalseqs)==10 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15" + tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==3 + assert prediction.shape[2]==2 + outfile = os.path.join(outdir, 'flow', pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlowKitti(outfile, prediction) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti15_flow_results.zip" flow' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/kitti15_flow_results.zip') + + +import cv2 +def _read_numpy_flow(filename): + return np.load(filename) + +def _read_pfm_flow(filename): + f, _ = _read_pfm(filename) + assert np.all(f[:,:,2]==0.0) + return np.ascontiguousarray(f[:,:,:2]) + +TAG_FLOAT = 202021.25 # tag to check the sanity of the file +TAG_STRING = 'PIEH' # string containing the tag +MIN_WIDTH = 1 +MAX_WIDTH = 99999 +MIN_HEIGHT = 1 +MAX_HEIGHT = 99999 +def readFlowFile(filename): + """ + readFlowFile() reads a flow file into a 2-band np.array. + if does not exist, an IOError is raised. + if does not finish by '.flo' or the tag, the width, the height or the file's size is illegal, an Expcetion is raised. + ---- PARAMETERS ---- + filename: string containg the name of the file to read a flow + ---- OUTPUTS ---- + a np.array of dimension (height x width x 2) containing the flow of type 'float32' + """ + + # check filename + if not filename.endswith(".flo"): + raise Exception("readFlowFile({:s}): filename must finish with '.flo'".format(filename)) + + # open the file and read it + with open(filename,'rb') as f: + # check tag + tag = struct.unpack('f',f.read(4))[0] + if tag != TAG_FLOAT: + raise Exception("flow_utils.readFlowFile({:s}): wrong tag".format(filename)) + # read dimension + w,h = struct.unpack('ii',f.read(8)) + if w < MIN_WIDTH or w > MAX_WIDTH: + raise Exception("flow_utils.readFlowFile({:s}: illegal width {:d}".format(filename,w)) + if h < MIN_HEIGHT or h > MAX_HEIGHT: + raise Exception("flow_utils.readFlowFile({:s}: illegal height {:d}".format(filename,h)) + flow = np.fromfile(f,'float32') + if not flow.shape == (h*w*2,): + raise Exception("flow_utils.readFlowFile({:s}: illegal size of the file".format(filename)) + flow.shape = (h,w,2) + return flow + +def writeFlowFile(flow,filename): + """ + writeFlowFile(flow,) write flow to the file . + if does not exist, an IOError is raised. 
+ if does not finish with '.flo' or the flow has not 2 bands, an Exception is raised. + ---- PARAMETERS ---- + flow: np.array of dimension (height x width x 2) containing the flow to write + filename: string containg the name of the file to write a flow + """ + + # check filename + if not filename.endswith(".flo"): + raise Exception("flow_utils.writeFlowFile(,{:s}): filename must finish with '.flo'".format(filename)) + + if not flow.shape[2:] == (2,): + raise Exception("flow_utils.writeFlowFile(,{:s}): must have 2 bands".format(filename)) + + + # open the file and write it + with open(filename,'wb') as f: + # write TAG + f.write( TAG_STRING.encode('utf-8') ) + # write dimension + f.write( struct.pack('ii',flow.shape[1],flow.shape[0]) ) + # write the flow + + flow.astype(np.float32).tofile(f) + +_read_flo_file = readFlowFile + +def _read_kitti_flow(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + valid = flow[:, :, 2]>0 + flow = flow[:, :, :2] + flow = (flow - 2 ** 15) / 64.0 + flow[~valid,0] = np.inf + flow[~valid,1] = np.inf + return flow +_read_hd1k_flow = _read_kitti_flow + + +def writeFlowKitti(filename, uv): + uv = 64.0 * uv + 2 ** 15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + +def writeFlo5File(flow, filename): + with h5py.File(filename, "w") as f: + f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5) + +def _read_hdf5_flow(filename): + flow = np.asarray(h5py.File(filename)['flow']) + flow[np.isnan(flow)] = np.inf # make invalid values as +inf + return flow.astype(np.float32) + +# flow visualization +RY = 15 +YG = 6 +GC = 4 +CB = 11 +BM = 13 +MR = 6 +UNKNOWN_THRESH = 1e9 + +def colorTest(): + """ + flow_utils.colorTest(): display an example of image showing the color encoding scheme + """ + import matplotlib.pylab as plt + truerange = 1 + h,w = 151,151 + trange = truerange*1.04 + s2 = round(h/2) + x,y = np.meshgrid(range(w),range(h)) + u = x*trange/s2-trange + v = y*trange/s2-trange + img = _computeColor(np.concatenate((u[:,:,np.newaxis],v[:,:,np.newaxis]),2)/trange/np.sqrt(2)) + plt.imshow(img) + plt.axis('off') + plt.axhline(round(h/2),color='k') + plt.axvline(round(w/2),color='k') + +def flowToColor(flow, maxflow=None, maxmaxflow=None, saturate=False): + """ + flow_utils.flowToColor(flow): return a color code flow field, normalized based on the maximum l2-norm of the flow + flow_utils.flowToColor(flow,maxflow): return a color code flow field, normalized by maxflow + ---- PARAMETERS ---- + flow: flow to display of shape (height x width x 2) + maxflow (default:None): if given, normalize the flow by its value, otherwise by the flow norm + maxmaxflow (default:None): if given, normalize the flow by the max of its value and the flow norm + ---- OUTPUT ---- + an np.array of shape (height x width x 3) of type uint8 containing a color code of the flow + """ + h,w,n = flow.shape + # check size of flow + assert n == 2, "flow_utils.flowToColor(flow): flow must have 2 bands" + # fix unknown flow + unknown_idx = np.max(np.abs(flow),2)>UNKNOWN_THRESH + flow[unknown_idx] = 0.0 + # compute max flow if needed + if maxflow is None: + maxflow = flowMaxNorm(flow) + if maxmaxflow is not None: + maxflow = min(maxmaxflow, maxflow) + # normalize flow + eps = np.spacing(1) # minimum positive float value to avoid division by 0 + # compute the flow + img = _computeColor(flow/(maxflow+eps), 
saturate=saturate) + # put black pixels in unknown location + img[ np.tile( unknown_idx[:,:,np.newaxis],[1,1,3]) ] = 0.0 + return img + +def flowMaxNorm(flow): + """ + flow_utils.flowMaxNorm(flow): return the maximum of the l2-norm of the given flow + ---- PARAMETERS ---- + flow: the flow + + ---- OUTPUT ---- + a float containing the maximum of the l2-norm of the flow + """ + return np.max( np.sqrt( np.sum( np.square( flow ) , 2) ) ) + +def _computeColor(flow, saturate=True): + """ + flow_utils._computeColor(flow): compute color codes for the flow field flow + + ---- PARAMETERS ---- + flow: np.array of dimension (height x width x 2) containing the flow to display + ---- OUTPUTS ---- + an np.array of dimension (height x width x 3) containing the color conversion of the flow + """ + # set nan to 0 + nanidx = np.isnan(flow[:,:,0]) + flow[nanidx] = 0.0 + + # colorwheel + ncols = RY + YG + GC + CB + BM + MR + nchans = 3 + colorwheel = np.zeros((ncols,nchans),'uint8') + col = 0; + #RY + colorwheel[:RY,0] = 255 + colorwheel[:RY,1] = [(255*i) // RY for i in range(RY)] + col += RY + # YG + colorwheel[col:col+YG,0] = [255 - (255*i) // YG for i in range(YG)] + colorwheel[col:col+YG,1] = 255 + col += YG + # GC + colorwheel[col:col+GC,1] = 255 + colorwheel[col:col+GC,2] = [(255*i) // GC for i in range(GC)] + col += GC + # CB + colorwheel[col:col+CB,1] = [255 - (255*i) // CB for i in range(CB)] + colorwheel[col:col+CB,2] = 255 + col += CB + # BM + colorwheel[col:col+BM,0] = [(255*i) // BM for i in range(BM)] + colorwheel[col:col+BM,2] = 255 + col += BM + # MR + colorwheel[col:col+MR,0] = 255 + colorwheel[col:col+MR,2] = [255 - (255*i) // MR for i in range(MR)] + + # compute utility variables + rad = np.sqrt( np.sum( np.square(flow) , 2) ) # magnitude + a = np.arctan2( -flow[:,:,1] , -flow[:,:,0]) / np.pi # angle + fk = (a+1)/2 * (ncols-1) # map [-1,1] to [0,ncols-1] + k0 = np.floor(fk).astype('int') + k1 = k0+1 + k1[k1==ncols] = 0 + f = fk-k0 + + if not saturate: + rad = np.minimum(rad,1) + + # compute the image + img = np.zeros( (flow.shape[0],flow.shape[1],nchans), 'uint8' ) + for i in range(nchans): + tmp = colorwheel[:,i].astype('float') + col0 = tmp[k0]/255 + col1 = tmp[k1]/255 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1-rad[idx]*(1-col[idx]) # increase saturation with radius + col[~idx] *= 0.75 # out of range + img[:,:,i] = (255*col*(1-nanidx.astype('float'))).astype('uint8') + + return img + +# flow dataset getter + +def get_train_dataset_flow(dataset_str, augmentor=True, crop_size=None): + dataset_str = dataset_str.replace('(','Dataset(') + if augmentor: + dataset_str = dataset_str.replace(')',', augmentor=True)') + if crop_size is not None: + dataset_str = dataset_str.replace(')',', crop_size={:s})'.format(str(crop_size))) + return eval(dataset_str) + +def get_test_datasets_flow(dataset_str): + dataset_str = dataset_str.replace('(','Dataset(') + return [eval(s) for s in dataset_str.split('+')] \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/stereoflow/datasets_stereo.py b/imcui/third_party/dust3r/croco/stereoflow/datasets_stereo.py new file mode 100644 index 0000000000000000000000000000000000000000..dbdf841a6650afa71ae5782702902c79eba31a5c --- /dev/null +++ b/imcui/third_party/dust3r/croco/stereoflow/datasets_stereo.py @@ -0,0 +1,674 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
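The `get_train_dataset_flow` and `get_test_datasets_flow` helpers at the end of `datasets_flow.py` above build datasets from a compact string: `'Name(split)'` is rewritten to `'NameDataset(split)'` and evaluated, `+` concatenates datasets, and an integer prefix such as `30*` repeats a pair list through `__rmul__`. A minimal usage sketch, assuming the `croco` directory is on `PYTHONPATH` and the data roots plus pair caches under `./data/stereoflow/` are in place; the splits, repetition factor and crop size below are illustrative examples, not values prescribed by the code:

```python
# Illustrative values only: splits, the 30x repetition and crop_size are examples.
from stereoflow.datasets_flow import get_train_dataset_flow, get_test_datasets_flow

# '(' -> 'Dataset(' turns "MPISintel(...)" into "MPISintelDataset(...)" before eval;
# augmentor=True and crop_size are appended to every constructor call.
train_set = get_train_dataset_flow(
    "MPISintel('subtrain_cleanpass')+30*FlyingChairs('train')",
    augmentor=True,
    crop_size=(368, 496),
)

# One dataset per '+'-separated term, without augmentation (for evaluation).
test_sets = get_test_datasets_flow("MPISintel('subval_cleanpass')+Kitti15('subval')")

img1, img2, flow, pairname = train_set[0]  # normalized tensors + stringified pair name
```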
+ +# -------------------------------------------------------- +# Dataset structure for stereo +# -------------------------------------------------------- + +import sys, os +import os.path as osp +import pickle +import numpy as np +from PIL import Image +import json +import h5py +from glob import glob +import cv2 + +import torch +from torch.utils import data + +from .augmentor import StereoAugmentor + + + +dataset_to_root = { + 'CREStereo': './data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/', + 'SceneFlow': './data/stereoflow//SceneFlow/', + 'ETH3DLowRes': './data/stereoflow/eth3d_lowres/', + 'Booster': './data/stereoflow/booster_gt/', + 'Middlebury2021': './data/stereoflow/middlebury/2021/data/', + 'Middlebury2014': './data/stereoflow/middlebury/2014/', + 'Middlebury2006': './data/stereoflow/middlebury/2006/', + 'Middlebury2005': './data/stereoflow/middlebury/2005/train/', + 'MiddleburyEval3': './data/stereoflow/middlebury/MiddEval3/', + 'Spring': './data/stereoflow/spring/', + 'Kitti15': './data/stereoflow/kitti-stereo-2015/', + 'Kitti12': './data/stereoflow/kitti-stereo-2012/', +} +cache_dir = "./data/stereoflow/datasets_stereo_cache/" + + +in1k_mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1) +in1k_std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1) +def img_to_tensor(img): + img = torch.from_numpy(img).permute(2, 0, 1).float() / 255. + img = (img-in1k_mean)/in1k_std + return img +def disp_to_tensor(disp): + return torch.from_numpy(disp)[None,:,:] + +class StereoDataset(data.Dataset): + + def __init__(self, split, augmentor=False, crop_size=None, totensor=True): + self.split = split + if not augmentor: assert crop_size is None + if crop_size: assert augmentor + self.crop_size = crop_size + self.augmentor_str = augmentor + self.augmentor = StereoAugmentor(crop_size) if augmentor else None + self.totensor = totensor + self.rmul = 1 # keep track of rmul + self.has_constant_resolution = True # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time) + self._prepare_data() + self._load_or_build_cache() + + def prepare_data(self): + """ + to be defined for each dataset + """ + raise NotImplementedError + + def __len__(self): + return len(self.pairnames) + + def __getitem__(self, index): + pairname = self.pairnames[index] + + # get filenames + Limgname = self.pairname_to_Limgname(pairname) + Rimgname = self.pairname_to_Rimgname(pairname) + Ldispname = self.pairname_to_Ldispname(pairname) if self.pairname_to_Ldispname is not None else None + + # load images and disparities + Limg = _read_img(Limgname) + Rimg = _read_img(Rimgname) + disp = self.load_disparity(Ldispname) if Ldispname is not None else None + + # sanity check + if disp is not None: assert np.all(disp>0) or self.name=="Spring", (self.name, pairname, Ldispname) + + # apply augmentations + if self.augmentor is not None: + Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name) + + if self.totensor: + Limg = img_to_tensor(Limg) + Rimg = img_to_tensor(Rimg) + if disp is None: + disp = torch.tensor([]) # to allow dataloader batching with default collate_gn + else: + disp = disp_to_tensor(disp) + + return Limg, Rimg, disp, str(pairname) + + def __rmul__(self, v): + self.rmul *= v + self.pairnames = v * self.pairnames + return self + + def __str__(self): + return f'{self.__class__.__name__}_{self.split}' + + def __repr__(self): + s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})' 
+ if self.rmul==1: + s+=f'\n\tnum pairs: {len(self.pairnames)}' + else: + s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})' + return s + + def _set_root(self): + self.root = dataset_to_root[self.name] + assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}" + + def _load_or_build_cache(self): + cache_file = osp.join(cache_dir, self.name+'.pkl') + if osp.isfile(cache_file): + with open(cache_file, 'rb') as fid: + self.pairnames = pickle.load(fid)[self.split] + else: + tosave = self._build_cache() + os.makedirs(cache_dir, exist_ok=True) + with open(cache_file, 'wb') as fid: + pickle.dump(tosave, fid) + self.pairnames = tosave[self.split] + +class CREStereoDataset(StereoDataset): + + def _prepare_data(self): + self.name = 'CREStereo' + self._set_root() + assert self.split in ['train'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_left.jpg') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'_right.jpg') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname+'_left.disp.png') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_crestereo_disp + + + def _build_cache(self): + allpairs = [s+'/'+f[:-len('_left.jpg')] for s in sorted(os.listdir(self.root)) for f in sorted(os.listdir(self.root+'/'+s)) if f.endswith('_left.jpg')] + assert len(allpairs)==200000, "incorrect parsing of pairs in CreStereo" + tosave = {'train': allpairs} + return tosave + +class SceneFlowDataset(StereoDataset): + + def _prepare_data(self): + self.name = "SceneFlow" + self._set_root() + assert self.split in ['train_finalpass','train_cleanpass','train_allpass','test_finalpass','test_cleanpass','test_allpass','test1of100_cleanpass','test1of100_finalpass'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/left/','/right/') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname).replace('/frames_finalpass/','/disparity/').replace('/frames_cleanpass/','/disparity/')[:-4]+'.pfm' + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_sceneflow_disp + + def _build_cache(self): + trainpairs = [] + # driving + pairs = sorted(glob(self.root+'Driving/frames_finalpass/*/*/*/left/*.png')) + pairs = list(map(lambda x: x[len(self.root):], pairs)) + assert len(pairs) == 4400, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + # monkaa + pairs = sorted(glob(self.root+'Monkaa/frames_finalpass/*/left/*.png')) + pairs = list(map(lambda x: x[len(self.root):], pairs)) + assert len(pairs) == 8664, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + # flyingthings + pairs = sorted(glob(self.root+'FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png')) + pairs = list(map(lambda x: x[len(self.root):], pairs)) + assert len(pairs) == 22390, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow" + testpairs = sorted(glob(self.root+'FlyingThings/frames_finalpass/TEST/*/*/left/*.png')) + testpairs = list(map(lambda x: x[len(self.root):], testpairs)) + assert len(testpairs) == 4370, "incorrect parsing of pairs in SceneFlow" + test1of100pairs = testpairs[::100] + assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow" + # all + tosave = {'train_finalpass': trainpairs, 
+ 'train_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), trainpairs)), + 'test_finalpass': testpairs, + 'test_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), testpairs)), + 'test1of100_finalpass': test1of100pairs, + 'test1of100_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), test1of100pairs)), + } + tosave['train_allpass'] = tosave['train_finalpass']+tosave['train_cleanpass'] + tosave['test_allpass'] = tosave['test_finalpass']+tosave['test_cleanpass'] + return tosave + +class Md21Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2021" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/im0','/im1')) + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp0.pfm') + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury_disp + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + #trainpairs += [s+'/im0.png'] # we should remove it, it is included as such in other lightings + trainpairs += [s+'/ambient/'+b+'/'+a for b in sorted(os.listdir(osp.join(self.root,s,'ambient'))) for a in sorted(os.listdir(osp.join(self.root,s,'ambient',b))) if a.startswith('im0')] + assert len(trainpairs)==355 + subtrainpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in seqs[:-2])] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in seqs[-2:])] + assert len(subtrainpairs)==335 and len(subvalpairs)==20, "incorrect parsing of pairs in Middlebury 2021" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class Md14Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2014" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'im0.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'disp0.pfm') + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury_disp + self.has_constant_resolution = False + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + trainpairs += [s+'/im1.png',s+'/im1E.png',s+'/im1L.png'] + assert len(trainpairs)==138 + valseqs = ['Umbrella-imperfect','Vintage-perfect'] + assert all(s in seqs for s in valseqs) + subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)] + assert len(subtrainpairs)==132 and len(subvalpairs)==6, "incorrect parsing of pairs in Middlebury 2014" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class Md06Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2006" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png') + 
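        # Middlebury 2005/2006 pairs are (view1.png, view5.png) per illumination
        # and exposure; the ground-truth disparity disp1.png sits at the scene
        # root and is shared by all lighting conditions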
self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png') + self.load_disparity = _read_middlebury20052006_disp + self.has_constant_resolution = False + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + for i in ['Illum1','Illum2','Illum3']: + for e in ['Exp0','Exp1','Exp2']: + trainpairs.append(osp.join(s,i,e,'view1.png')) + assert len(trainpairs)==189 + valseqs = ['Rocks1','Wood2'] + assert all(s in seqs for s in valseqs) + subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)] + assert len(subtrainpairs)==171 and len(subvalpairs)==18, "incorrect parsing of pairs in Middlebury 2006" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class Md05Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2005" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png') + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury20052006_disp + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + for i in ['Illum1','Illum2','Illum3']: + for e in ['Exp0','Exp1','Exp2']: + trainpairs.append(osp.join(s,i,e,'view1.png')) + assert len(trainpairs)==54, "incorrect parsing of pairs in Middlebury 2005" + valseqs = ['Reindeer'] + assert all(s in seqs for s in valseqs) + subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)] + assert len(subtrainpairs)==45 and len(subvalpairs)==9, "incorrect parsing of pairs in Middlebury 2005" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class MdEval3Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "MiddleburyEval3" + self._set_root() + assert self.split in [s+'_'+r for s in ['train','subtrain','subval','test','all'] for r in ['full','half','quarter']] + if self.split.endswith('_full'): + self.root = self.root.replace('/MiddEval3','/MiddEval3_F') + elif self.split.endswith('_half'): + self.root = self.root.replace('/MiddEval3','/MiddEval3_H') + else: + assert self.split.endswith('_quarter') + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png') + self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname, 'disp0GT.pfm') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_middlebury_disp + # for submission only + self.submission_methodname = "CroCo-Stereo" + self.submission_sresolution = 'F' if self.split.endswith('_full') else ('H' if self.split.endswith('_half') else 'Q') + + def _build_cache(self): + trainpairs = ['train/'+s for s in sorted(os.listdir(self.root+'train/'))] + testpairs = ['test/'+s for s in sorted(os.listdir(self.root+'test/'))] + subvalpairs = trainpairs[-1:] + subtrainpairs = trainpairs[:-1] + 
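        # 15 training and 15 test scenes; the last training scene is held out as
        # subval, and the 'all' splits simply concatenate the train and test lists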
allpairs = trainpairs+testpairs + assert len(trainpairs)==15 and len(testpairs)==15 and len(subvalpairs)==1 and len(subtrainpairs)==14 and len(allpairs)==30, "incorrect parsing of pairs in Middlebury Eval v3" + tosave = {} + for r in ['full','half','quarter']: + tosave.update(**{'train_'+r: trainpairs, 'subtrain_'+r: subtrainpairs, 'subval_'+r: subvalpairs, 'test_'+r: testpairs, 'all_'+r: allpairs}) + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, pairname.split('/')[0].replace('train','training')+self.submission_sresolution, pairname.split('/')[1], 'disp0'+self.submission_methodname+'.pfm') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writePFM(outfile, prediction) + timefile = os.path.join( os.path.dirname(outfile), "time"+self.submission_methodname+'.txt') + with open(timefile, 'w') as fid: + fid.write(str(time)) + + def finalize_submission(self, outdir): + cmd = f'cd {outdir}/; zip -r "{self.submission_methodname}.zip" .' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/{self.submission_methodname}.zip') + +class ETH3DLowResDataset(StereoDataset): + + def _prepare_data(self): + self.name = "ETH3DLowRes" + self._set_root() + assert self.split in ['train','test','subtrain','subval','all'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png') + self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: None if pairname.startswith('test/') else osp.join(self.root, pairname.replace('train/','train_gt/'), 'disp0GT.pfm') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_eth3d_disp + self.has_constant_resolution = False + + def _build_cache(self): + trainpairs = ['train/' + s for s in sorted(os.listdir(self.root+'train/'))] + testpairs = ['test/' + s for s in sorted(os.listdir(self.root+'test/'))] + assert len(trainpairs) == 27 and len(testpairs) == 20, "incorrect parsing of pairs in ETH3D Low Res" + subvalpairs = ['train/delivery_area_3s','train/electro_3l','train/playground_3l'] + assert all(p in trainpairs for p in subvalpairs) + subtrainpairs = [p for p in trainpairs if not p in subvalpairs] + assert len(subvalpairs)==3 and len(subtrainpairs)==24, "incorrect parsing of pairs in ETH3D Low Res" + tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs, 'all': trainpairs+testpairs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, 'low_res_two_view', pairname.split('/')[1]+'.pfm') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writePFM(outfile, prediction) + timefile = outfile[:-4]+'.txt' + with open(timefile, 'w') as fid: + fid.write('runtime '+str(time)) + + def finalize_submission(self, outdir): + cmd = f'cd {outdir}/; zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view' + print(cmd) + os.system(cmd) + print(f'Done. 
Submission file at {outdir}/eth3d_low_res_two_view_results.zip') + +class BoosterDataset(StereoDataset): + + def _prepare_data(self): + self.name = "Booster" + self._set_root() + assert self.split in ['train_balanced','test_balanced','subtrain_balanced','subval_balanced'] # we use only the balanced version + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/camera_00/','/camera_02/') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), '../disp_00.npy') # same images with different colors, same gt per sequence + self.pairname_to_str = lambda pairname: pairname[:-4].replace('/camera_00/','/') + self.load_disparity = _read_booster_disp + + + def _build_cache(self): + trainseqs = sorted(os.listdir(self.root+'train/balanced')) + trainpairs = ['train/balanced/'+s+'/camera_00/'+imname for s in trainseqs for imname in sorted(os.listdir(self.root+'train/balanced/'+s+'/camera_00/'))] + testpairs = ['test/balanced/'+s+'/camera_00/'+imname for s in sorted(os.listdir(self.root+'test/balanced')) for imname in sorted(os.listdir(self.root+'test/balanced/'+s+'/camera_00/'))] + assert len(trainpairs) == 228 and len(testpairs) == 191 + subtrainpairs = [p for p in trainpairs if any(s in p for s in trainseqs[:-2])] + subvalpairs = [p for p in trainpairs if any(s in p for s in trainseqs[-2:])] + # warning: if we do validation split, we should split scenes!!! + tosave = {'train_balanced': trainpairs, 'test_balanced': testpairs, 'subtrain_balanced': subtrainpairs, 'subval_balanced': subvalpairs,} + return tosave + +class SpringDataset(StereoDataset): + + def _prepare_data(self): + self.name = "Spring" + self._set_root() + assert self.split in ['train', 'test', 'subtrain', 'subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'.png').replace('frame_right','').replace('frame_left','frame_right').replace('','frame_left') + self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_hdf5_disp + + def _build_cache(self): + trainseqs = sorted(os.listdir( osp.join(self.root,'train'))) + trainpairs = [osp.join('train',s,'frame_left',f[:-4]) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,'frame_left')))] + testseqs = sorted(os.listdir( osp.join(self.root,'test'))) + testpairs = [osp.join('test',s,'frame_left',f[:-4]) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,'frame_left')))] + testpairs += [p.replace('frame_left','frame_right') for p in testpairs] + """maxnorm = {'0001': 32.88, '0002': 228.5, '0004': 298.2, '0005': 142.5, '0006': 113.6, '0007': 27.3, '0008': 554.5, '0009': 155.6, '0010': 126.1, '0011': 87.6, '0012': 303.2, '0013': 24.14, '0014': 82.56, '0015': 98.44, '0016': 156.9, '0017': 28.17, '0018': 21.03, '0020': 178.0, '0021': 58.06, '0022': 354.2, '0023': 8.79, '0024': 97.06, '0025': 55.16, '0026': 91.9, '0027': 156.6, '0030': 200.4, '0032': 58.66, '0033': 373.5, '0036': 149.4, '0037': 5.625, '0038': 37.0, '0039': 12.2, '0041': 453.5, '0043': 457.0, '0044': 379.5, '0045': 161.8, '0047': 105.44} # => let'use 0041""" + subtrainpairs = [p for p in trainpairs if 
p.split('/')[1]!='0041'] + subvalpairs = [p for p in trainpairs if p.split('/')[1]=='0041'] + assert len(trainpairs)==5000 and len(testpairs)==2000 and len(subtrainpairs)==4904 and len(subvalpairs)==96, "incorrect parsing of pairs in Spring" + tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeDsp5File(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + exe = "{self.root}/disp1_subsampling" + if os.path.isfile(exe): + cmd = f'cd "{outdir}/test"; {exe} .' + print(cmd) + os.system(cmd) + else: + print('Could not find disp1_subsampling executable for submission.') + print('Please download it and run:') + print(f'cd "{outdir}/test"; .') + +class Kitti12Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Kitti12" + self._set_root() + assert self.split in ['train','test'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/colored_1/')+'_10.png') + self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/disp_occ/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/') + self.load_disparity = _read_kitti_disp + + def _build_cache(self): + trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)] + testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)] + assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12" + tosave = {'train': trainseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + img = (prediction * 256).astype('uint16') + Image.fromarray(img).save(outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti12_results.zip" .' + print(cmd) + os.system(cmd) + print(f'Done. 
Submission file at {outdir}/kitti12_results.zip') + +class Kitti15Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Kitti15" + self._set_root() + assert self.split in ['train','subtrain','subval','test'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/image_3/')+'_10.png') + self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/disp_occ_0/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/') + self.load_disparity = _read_kitti_disp + + def _build_cache(self): + trainseqs = ["training/image_2/%06d"%(i) for i in range(200)] + subtrainseqs = trainseqs[:-5] + subvalseqs = trainseqs[-5:] + testseqs = ["testing/image_2/%06d"%(i) for i in range(200)] + assert len(trainseqs)==200 and len(subtrainseqs)==195 and len(subvalseqs)==5 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15" + tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, 'disp_0', pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + img = (prediction * 256).astype('uint16') + Image.fromarray(img).save(outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti15_results.zip" disp_0' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/kitti15_results.zip') + + +### auxiliary functions + +def _read_img(filename): + # convert to RGB for scene flow finalpass data + img = np.asarray(Image.open(filename).convert('RGB')) + return img + +def _read_booster_disp(filename): + disp = np.load(filename) + disp[disp==0.0] = np.inf + return disp + +def _read_png_disp(filename, coef=1.0): + disp = np.asarray(Image.open(filename)) + disp = disp.astype(np.float32) / coef + disp[disp==0.0] = np.inf + return disp + +def _read_pfm_disp(filename): + disp = np.ascontiguousarray(_read_pfm(filename)[0]) + disp[disp<=0] = np.inf # eg /nfs/data/ffs-3d/datasets/middlebury/2014/Shopvac-imperfect/disp0.pfm + return disp + +def _read_npy_disp(filename): + return np.load(filename) + +def _read_crestereo_disp(filename): return _read_png_disp(filename, coef=32.0) +def _read_middlebury20052006_disp(filename): return _read_png_disp(filename, coef=1.0) +def _read_kitti_disp(filename): return _read_png_disp(filename, coef=256.0) +_read_sceneflow_disp = _read_pfm_disp +_read_eth3d_disp = _read_pfm_disp +_read_middlebury_disp = _read_pfm_disp +_read_carla_disp = _read_pfm_disp +_read_tartanair_disp = _read_npy_disp + +def _read_hdf5_disp(filename): + disp = np.asarray(h5py.File(filename)['disparity']) + disp[np.isnan(disp)] = np.inf # make invalid values as +inf + #disp[disp==0.0] = np.inf # make invalid values as +inf + return disp.astype(np.float32) + +import re +def _read_pfm(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == 'PF': + color = True + elif header.decode("ascii") == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) + 
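    # Layout of the PFM format parsed here: line 1 is "PF" (3-channel color) or "Pf"
    # (single-channel), line 2 is "<width> <height>", line 3 is a scale factor whose
    # sign encodes endianness (negative = little-endian), followed by the raw float32
    # samples stored bottom-to-top (hence the np.flipud further down).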
if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + +def writePFM(file, image, scale=1): + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n' if color else 'Pf\n'.encode()) + file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + file.write('%f\n'.encode() % scale) + + image.tofile(file) + +def writeDsp5File(disp, filename): + with h5py.File(filename, "w") as f: + f.create_dataset("disparity", data=disp, compression="gzip", compression_opts=5) + + +# disp visualization + +def vis_disparity(disp, m=None, M=None): + if m is None: m = disp.min() + if M is None: M = disp.max() + disp_vis = (disp - m) / (M-m) * 255.0 + disp_vis = disp_vis.astype("uint8") + disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) + return disp_vis + +# dataset getter + +def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None): + dataset_str = dataset_str.replace('(','Dataset(') + if augmentor: + dataset_str = dataset_str.replace(')',', augmentor=True)') + if crop_size is not None: + dataset_str = dataset_str.replace(')',', crop_size={:s})'.format(str(crop_size))) + return eval(dataset_str) + +def get_test_datasets_stereo(dataset_str): + dataset_str = dataset_str.replace('(','Dataset(') + return [eval(s) for s in dataset_str.split('+')] \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/stereoflow/engine.py b/imcui/third_party/dust3r/croco/stereoflow/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..c057346b99143bf6b9c4666a58215b2b91aca7a6 --- /dev/null +++ b/imcui/third_party/dust3r/croco/stereoflow/engine.py @@ -0,0 +1,280 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
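# A note on the dataset-string helpers at the end of datasets_stereo.py above:
# get_train_dataset_stereo() / get_test_datasets_stereo() build datasets by rewriting a
# compact spec string and eval()-ing it against the Dataset classes defined in that file.
# A minimal illustration (the spec below is only an example):
#
#   spec = "Kitti15('train')+Kitti12('train')"
#   # get_test_datasets_stereo turns "Kitti15(" into "Kitti15Dataset(" (same for Kitti12),
#   # then evaluates each '+'-separated term, yielding
#   #   [Kitti15Dataset('train'), Kitti12Dataset('train')]
#   # get_train_dataset_stereo instead evaluates the whole string as a single (possibly
#   # '+'-concatenated) dataset, after appending ", augmentor=True" and an optional
#   # ", crop_size=<crop>" before each closing parenthesis.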
+ +# -------------------------------------------------------- +# Main function for training one epoch or testing +# -------------------------------------------------------- + +import math +import sys +from typing import Iterable +import numpy as np +import torch +import torchvision + +from utils import misc as misc + + +def split_prediction_conf(predictions, with_conf=False): + if not with_conf: + return predictions, None + conf = predictions[:,-1:,:,:] + predictions = predictions[:,:-1,:,:] + return predictions, conf + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, metrics: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + log_writer=None, print_freq = 20, + args=None): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + + accum_iter = args.accum_iter + + optimizer.zero_grad() + + details = {} + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + if args.img_per_epoch: + iter_per_epoch = args.img_per_epoch // args.batch_size + int(args.img_per_epoch % args.batch_size > 0) + assert len(data_loader) >= iter_per_epoch, 'Dataset is too small for so many iterations' + len_data_loader = iter_per_epoch + else: + len_data_loader, iter_per_epoch = len(data_loader), None + + for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_logger.log_every(data_loader, print_freq, header, max_iter=iter_per_epoch)): + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) + + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, data_iter_step / len_data_loader + epoch, args) + + with torch.cuda.amp.autocast(enabled=bool(args.amp)): + prediction = model(image1, image2) + prediction, conf = split_prediction_conf(prediction, criterion.with_conf) + batch_metrics = metrics(prediction.detach(), gt) + loss = criterion(prediction, gt) if conf is None else criterion(prediction, gt, conf) + + loss_value = loss.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + for k,v in batch_metrics.items(): + metric_logger.update(**{k: v.item()}) + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(lr=lr) + + #if args.dsitributed: loss_value_reduce = misc.all_reduce_mean(loss_value) + time_to_log = ((data_iter_step + 1) % (args.tboard_log_step * accum_iter) == 0 or data_iter_step == len_data_loader-1) + loss_value_reduce = misc.all_reduce_mean(loss_value) + if log_writer is not None and time_to_log: + epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000) + # We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. 
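            # Worked example of the epoch_1000x axis used just below: with epoch=3 and
            # data_iter_step/len_data_loader = 0.25, epoch_1000x = int(3.25 * 1000) = 3250.
            # The x-axis therefore counts thousandths of an epoch rather than iterations,
            # so curves logged with different batch sizes (and hence different numbers of
            # iterations per epoch) stay aligned.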
+ log_writer.add_scalar('train/loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('lr', lr, epoch_1000x) + for k,v in batch_metrics.items(): + log_writer.add_scalar('train/'+k, v.item(), epoch_1000x) + + # gather the stats from all processes + #if args.distributed: metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def validate_one_epoch(model: torch.nn.Module, + criterion: torch.nn.Module, + metrics: torch.nn.Module, + data_loaders: list[Iterable], + device: torch.device, + epoch: int, + log_writer=None, + args=None): + + model.eval() + metric_loggers = [] + header = 'Epoch: [{}]'.format(epoch) + print_freq = 20 + + conf_mode = args.tile_conf_mode + crop = args.crop + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + results = {} + dnames = [] + image1, image2, gt, prediction = None, None, None, None + for didx, data_loader in enumerate(data_loaders): + dname = str(data_loader.dataset) + dnames.append(dname) + metric_loggers.append(misc.MetricLogger(delimiter=" ")) + for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_loggers[didx].log_every(data_loader, print_freq, header)): + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) + if dname.startswith('Spring'): + assert gt.size(2)==image1.size(2)*2 and gt.size(3)==image1.size(3)*2 + gt = (gt[:,:,0::2,0::2] + gt[:,:,0::2,1::2] + gt[:,:,1::2,0::2] + gt[:,:,1::2,1::2] ) / 4.0 # we approximate the gt based on the 2x upsampled ones + + with torch.inference_mode(): + prediction, tiled_loss, c = tiled_pred(model, criterion, image1, image2, gt, conf_mode=conf_mode, overlap=args.val_overlap, crop=crop, with_conf=criterion.with_conf) + batch_metrics = metrics(prediction.detach(), gt) + loss = criterion(prediction.detach(), gt) if not criterion.with_conf else criterion(prediction.detach(), gt, c) + loss_value = loss.item() + metric_loggers[didx].update(loss_tiled=tiled_loss.item()) + metric_loggers[didx].update(**{f'loss': loss_value}) + for k,v in batch_metrics.items(): + metric_loggers[didx].update(**{dname+'_' + k: v.item()}) + + results = {k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()} + if len(dnames)>1: + for k in batch_metrics.keys(): + results['AVG_'+k] = sum(results[dname+'_'+k] for dname in dnames) / len(dnames) + + if log_writer is not None : + epoch_1000x = int((1 + epoch) * 1000) + for k,v in results.items(): + log_writer.add_scalar('val/'+k, v, epoch_1000x) + + print("Averaged stats:", results) + return results + +import torch.nn.functional as F +def _resize_img(img, new_size): + return F.interpolate(img, size=new_size, mode='bicubic', align_corners=False) +def _resize_stereo_or_flow(data, new_size): + assert data.ndim==4 + assert data.size(1) in [1,2] + scale_x = new_size[1]/float(data.size(3)) + out = F.interpolate(data, size=new_size, mode='bicubic', align_corners=False) + out[:,0,:,:] *= scale_x + if out.size(1)==2: + scale_y = new_size[0]/float(data.size(2)) + out[:,1,:,:] *= scale_y + print(scale_x, new_size, data.shape) + return out + + +@torch.no_grad() +def tiled_pred(model, criterion, img1, img2, gt, + overlap=0.5, bad_crop_thr=0.05, + downscale=False, crop=512, ret='loss', + conf_mode='conf_expsigmoid_10_5', with_conf=False, + return_time=False): + + # for each image, we are going to run inference on many 
overlapping patches + # then, all predictions will be weighted-averaged + if gt is not None: + B, C, H, W = gt.shape + else: + B, _, H, W = img1.shape + C = model.head.num_channels-int(with_conf) + win_height, win_width = crop[0], crop[1] + + # upscale to be larger than the crop + do_change_scale = H= window and 0 <= overlap < 1, (total, window, overlap) + num_windows = 1 + int(np.ceil( (total - window) / ((1-overlap) * window) )) + offsets = np.linspace(0, total-window, num_windows).round().astype(int) + yield from (slice(x, x+window) for x in offsets) + +def _crop(img, sy, sx): + B, THREE, H, W = img.shape + if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W: + return img[:,:,sy,sx] + l, r = max(0,-sx.start), max(0,sx.stop-W) + t, b = max(0,-sy.start), max(0,sy.stop-H) + img = torch.nn.functional.pad(img, (l,r,t,b), mode='constant') + return img[:, :, slice(sy.start+t,sy.stop+t), slice(sx.start+l,sx.stop+l)] \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/stereoflow/test.py b/imcui/third_party/dust3r/croco/stereoflow/test.py new file mode 100644 index 0000000000000000000000000000000000000000..0248e56664c769752595af251e1eadcfa3a479d9 --- /dev/null +++ b/imcui/third_party/dust3r/croco/stereoflow/test.py @@ -0,0 +1,216 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Main test function +# -------------------------------------------------------- + +import os +import argparse +import pickle +from PIL import Image +import numpy as np +from tqdm import tqdm + +import torch +from torch.utils.data import DataLoader + +import utils.misc as misc +from models.croco_downstream import CroCoDownstreamBinocular +from models.head_downstream import PixelwiseTaskWithDPT + +from stereoflow.criterion import * +from stereoflow.datasets_stereo import get_test_datasets_stereo +from stereoflow.datasets_flow import get_test_datasets_flow +from stereoflow.engine import tiled_pred + +from stereoflow.datasets_stereo import vis_disparity +from stereoflow.datasets_flow import flowToColor + +def get_args_parser(): + parser = argparse.ArgumentParser('Test CroCo models on stereo/flow', add_help=False) + # important argument + parser.add_argument('--model', required=True, type=str, help='Path to the model to evaluate') + parser.add_argument('--dataset', required=True, type=str, help="test dataset (there can be multiple dataset separated by a +)") + # tiling + parser.add_argument('--tile_conf_mode', type=str, default='', help='Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint') + parser.add_argument('--tile_overlap', type=float, default=0.7, help='overlap between tiles') + # save (it will automatically go to _/_) + parser.add_argument('--save', type=str, nargs='+', default=[], + help='what to save: \ + metrics (pickle file), \ + pred (raw prediction save as torch tensor), \ + visu (visualization in png of each prediction), \ + err10 (visualization in png of the error clamp at 10 for each prediction), \ + submission (submission file)') + # other (no impact) + parser.add_argument('--num_workers', default=4, type=int) + return parser + + +def _load_model_and_criterion(model_path, do_load_metrics, device): + print('loading model from', model_path) + assert os.path.isfile(model_path) + ckpt = torch.load(model_path, 'cpu') + + ckpt_args = ckpt['args'] + task = ckpt_args.task + 
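# The tiled inference in engine.py above slides overlapping crops over the full image and
# weight-averages the per-crop predictions; the window-offset helper is partly truncated
# in this diff. A readable restatement of the offset computation that is visible above
# (the function name, signature and assert are reconstructed here, so treat them as an
# approximation rather than the original definition):
#
#   import numpy as np
#
#   def overlapping_window_slices(total, window, overlap):
#       # enough windows so consecutive starts advance by at most (1 - overlap) * window
#       num_windows = 1 + int(np.ceil((total - window) / ((1 - overlap) * window)))
#       offsets = np.linspace(0, total - window, num_windows).round().astype(int)
#       yield from (slice(x, x + window) for x in offsets)
#
#   # e.g. total=1024, window=512, overlap=0.5 -> slices starting at 0, 256 and 512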
tile_conf_mode = ckpt_args.tile_conf_mode + num_channels = {'stereo': 1, 'flow': 2}[task] + with_conf = eval(ckpt_args.criterion).with_conf + if with_conf: num_channels += 1 + print('head: PixelwiseTaskWithDPT()') + head = PixelwiseTaskWithDPT() + head.num_channels = num_channels + print('croco_args:', ckpt_args.croco_args) + model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args) + msg = model.load_state_dict(ckpt['model'], strict=True) + model.eval() + model = model.to(device) + + if do_load_metrics: + if task=='stereo': + metrics = StereoDatasetMetrics().to(device) + else: + metrics = FlowDatasetMetrics().to(device) + else: + metrics = None + + return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode + + +def _save_batch(pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None): + + for i in range(len(pairnames)): + + pairname = eval(pairnames[i]) if pairnames[i].startswith('(') else pairnames[i] # unbatch pairname + fname = os.path.join(outdir, dataset.pairname_to_str(pairname)) + os.makedirs(os.path.dirname(fname), exist_ok=True) + + predi = pred[i,...] + if gt is not None: gti = gt[i,...] + + if 'pred' in save: + torch.save(predi.squeeze(0).cpu(), fname+'_pred.pth') + + if 'visu' in save: + if task=='stereo': + disparity = predi.permute((1,2,0)).squeeze(2).cpu().numpy() + m,M = None + if gt is not None: + mask = torch.isfinite(gti) + m = gt[mask].min() + M = gt[mask].max() + img_disparity = vis_disparity(disparity, m=m, M=M) + Image.fromarray(img_disparity).save(fname+'_pred.png') + else: + # normalize flowToColor according to the maxnorm of gt (or prediction if not available) + flowNorm = torch.sqrt(torch.sum( (gti if gt is not None else predi)**2, dim=0)).max().item() + imgflow = flowToColor(predi.permute((1,2,0)).cpu().numpy(), maxflow=flowNorm) + Image.fromarray(imgflow).save(fname+'_pred.png') + + if 'err10' in save: + assert gt is not None + L2err = torch.sqrt(torch.sum( (gti-predi)**2, dim=0)) + valid = torch.isfinite(gti[0,:,:]) + L2err[~valid] = 0.0 + L2err = torch.clamp(L2err, max=10.0) + red = (L2err*255.0/10.0).to(dtype=torch.uint8)[:,:,None] + zer = torch.zeros_like(red) + imgerr = torch.cat( (red,zer,zer), dim=2).cpu().numpy() + Image.fromarray(imgerr).save(fname+'_err10.png') + + if 'submission' in save: + assert submission_dir is not None + predi_np = predi.permute(1,2,0).squeeze(2).cpu().numpy() # transform into HxWx2 for flow or HxW for stereo + dataset.submission_save_pairname(pairname, predi_np, submission_dir, time) + +def main(args): + + # load the pretrained model and metrics + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + model, metrics, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion(args.model, 'metrics' in args.save, device) + if args.tile_conf_mode=='': args.tile_conf_mode = tile_conf_mode + + # load the datasets + datasets = (get_test_datasets_stereo if task=='stereo' else get_test_datasets_flow)(args.dataset) + dataloaders = [DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) for dataset in datasets] + + # run + for i,dataloader in enumerate(dataloaders): + dataset = datasets[i] + dstr = args.dataset.split('+')[i] + + outdir = args.model+'_'+misc.filename(dstr) + if 'metrics' in args.save and len(args.save)==1: + fname = os.path.join(outdir, f'conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl') + if os.path.isfile(fname) and len(args.save)==1: + print(' metrics already compute in 
'+fname) + with open(fname, 'rb') as fid: + results = pickle.load(fid) + for k,v in results.items(): + print('{:s}: {:.3f}'.format(k, v)) + continue + + if 'submission' in args.save: + dirname = f'submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}' + submission_dir = os.path.join(outdir, dirname) + else: + submission_dir = None + + print('') + print('saving {:s} in {:s}'.format('+'.join(args.save), outdir)) + print(repr(dataset)) + + if metrics is not None: + metrics.reset() + + for data_iter_step, (image1, image2, gt, pairnames) in enumerate(tqdm(dataloader)): + + do_flip = (task=='stereo' and dstr.startswith('Spring') and any("right" in p for p in pairnames)) # we flip the images and will flip the prediction after as we assume img1 is on the left + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) if gt.numel()>0 else None # special case for test time + if do_flip: + assert all("right" in p for p in pairnames) + image1 = image1.flip(dims=[3]) # this is already the right frame, let's flip it + image2 = image2.flip(dims=[3]) + gt = gt # that is ok + + with torch.inference_mode(): + pred, _, _, time = tiled_pred(model, None, image1, image2, None if dataset.name=='Spring' else gt, conf_mode=args.tile_conf_mode, overlap=args.tile_overlap, crop=cropsize, with_conf=with_conf, return_time=True) + + if do_flip: + pred = pred.flip(dims=[3]) + + if metrics is not None: + metrics.add_batch(pred, gt) + + if any(k in args.save for k in ['pred','visu','err10','submission']): + _save_batch(pred, gt, pairnames, dataset, task, args.save, outdir, time, submission_dir=submission_dir) + + + # print + if metrics is not None: + results = metrics.get_results() + for k,v in results.items(): + print('{:s}: {:.3f}'.format(k, v)) + + # save if needed + if 'metrics' in args.save: + os.makedirs(os.path.dirname(fname), exist_ok=True) + with open(fname, 'wb') as fid: + pickle.dump(results, fid) + print('metrics saved in', fname) + + # finalize submission if needed + if 'submission' in args.save: + dataset.finalize_submission(submission_dir) + + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/stereoflow/train.py b/imcui/third_party/dust3r/croco/stereoflow/train.py new file mode 100644 index 0000000000000000000000000000000000000000..91f2414ffbe5ecd547d31c0e2455478d402719d6 --- /dev/null +++ b/imcui/third_party/dust3r/croco/stereoflow/train.py @@ -0,0 +1,253 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
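# A note on the do_flip branch in stereoflow/test.py above: the model assumes img1 is the
# left view. For Spring right-camera pairs both inputs are mirrored horizontally
# (tensor.flip(dims=[3]) on BxCxHxW tensors), which lets the right view play the role of
# a left view; the predicted disparity map is then mirrored back with the same flip.
# Because disparity is a purely horizontal offset between rectified views, its magnitude
# is preserved by this mirroring, so no value correction is needed beyond the flip.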
+ +# -------------------------------------------------------- +# Main training function +# -------------------------------------------------------- + +import argparse +import datetime +import json +import numpy as np +import os +import sys +import time + +import torch +import torch.distributed as dist +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from torch.utils.data import DataLoader + +import utils +import utils.misc as misc +from utils.misc import NativeScalerWithGradNormCount as NativeScaler +from models.croco_downstream import CroCoDownstreamBinocular, croco_args_from_ckpt +from models.pos_embed import interpolate_pos_embed +from models.head_downstream import PixelwiseTaskWithDPT + +from stereoflow.datasets_stereo import get_train_dataset_stereo, get_test_datasets_stereo +from stereoflow.datasets_flow import get_train_dataset_flow, get_test_datasets_flow +from stereoflow.engine import train_one_epoch, validate_one_epoch +from stereoflow.criterion import * + + +def get_args_parser(): + # prepare subparsers + parser = argparse.ArgumentParser('Finetuning CroCo models on stereo or flow', add_help=False) + subparsers = parser.add_subparsers(title="Task (stereo or flow)", dest="task", required=True) + parser_stereo = subparsers.add_parser('stereo', help='Training stereo model') + parser_flow = subparsers.add_parser('flow', help='Training flow model') + def add_arg(name_or_flags, default=None, default_stereo=None, default_flow=None, **kwargs): + if default is not None: assert default_stereo is None and default_flow is None, "setting default makes default_stereo and default_flow disabled" + parser_stereo.add_argument(name_or_flags, default=default if default is not None else default_stereo, **kwargs) + parser_flow.add_argument(name_or_flags, default=default if default is not None else default_flow, **kwargs) + # output dir + add_arg('--output_dir', required=True, type=str, help='path where to save, if empty, automatically created') + # model + add_arg('--crop', type=int, nargs = '+', default_stereo=[352, 704], default_flow=[320, 384], help = "size of the random image crops used during training.") + add_arg('--pretrained', required=True, type=str, help="Load pretrained model (required as croco arguments come from there)") + # criterion + add_arg('--criterion', default_stereo='LaplacianLossBounded2()', default_flow='LaplacianLossBounded()', type=str, help='string to evaluate to get criterion') + add_arg('--bestmetric', default_stereo='avgerr', default_flow='EPE', type=str) + # dataset + add_arg('--dataset', type=str, required=True, help="training set") + # training + add_arg('--seed', default=0, type=int, help='seed') + add_arg('--batch_size', default_stereo=6, default_flow=8, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') + add_arg('--epochs', default=32, type=int, help='number of training epochs') + add_arg('--img_per_epoch', type=int, default=None, help='Fix the number of images seen in an epoch (None means use all training pairs)') + add_arg('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') + add_arg('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)') + add_arg('--lr', type=float, default_stereo=3e-5, default_flow=2e-5, metavar='LR', help='learning rate (absolute lr)') + add_arg('--min_lr', 
type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') + add_arg('--warmup_epochs', type=int, default=1, metavar='N', help='epochs to warmup LR') + add_arg('--optimizer', default='AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))', type=str, + help="Optimizer from torch.optim [ default: AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) ]") + add_arg('--amp', default=0, type=int, choices=[0,1], help='enable automatic mixed precision training') + # validation + add_arg('--val_dataset', type=str, default='', help="Validation sets, multiple separated by + (empty string means that no validation is performed)") + add_arg('--tile_conf_mode', type=str, default_stereo='conf_expsigmoid_15_3', default_flow='conf_expsigmoid_10_5', help='Weights for tile aggregation') + add_arg('--val_overlap', default=0.7, type=float, help='Overlap value for the tiling') + # others + add_arg('--num_workers', default=8, type=int) + add_arg('--eval_every', type=int, default=1, help='Val loss evaluation frequency') + add_arg('--save_every', type=int, default=1, help='Save checkpoint frequency') + add_arg('--start_from', type=str, default=None, help='Start training using weights from an other model (eg for finetuning)') + add_arg('--tboard_log_step', type=int, default=100, help='Log to tboard every so many steps') + add_arg('--dist_url', default='env://', help='url used to set up distributed training') + + return parser + + +def main(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + num_tasks = misc.get_world_size() + + assert os.path.isfile(args.pretrained) + print("output_dir: "+args.output_dir) + os.makedirs(args.output_dir, exist_ok=True) + + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + cudnn.benchmark = True + + # Metrics / criterion + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + metrics = (StereoMetrics if args.task=='stereo' else FlowMetrics)().to(device) + criterion = eval(args.criterion).to(device) + print('Criterion: ', args.criterion) + + # Prepare model + assert os.path.isfile(args.pretrained) + ckpt = torch.load(args.pretrained, 'cpu') + croco_args = croco_args_from_ckpt(ckpt) + croco_args['img_size'] = (args.crop[0], args.crop[1]) + print('Croco args: '+str(croco_args)) + args.croco_args = croco_args # saved for test time + # prepare head + num_channels = {'stereo': 1, 'flow': 2}[args.task] + if criterion.with_conf: num_channels += 1 + print(f'Building head PixelwiseTaskWithDPT() with {num_channels} channel(s)') + head = PixelwiseTaskWithDPT() + head.num_channels = num_channels + # build model and load pretrained weights + model = CroCoDownstreamBinocular(head, **croco_args) + interpolate_pos_embed(model, ckpt['model']) + msg = model.load_state_dict(ckpt['model'], strict=False) + print(msg) + + total_params = sum(p.numel() for p in model.parameters()) + total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Total params: {total_params}") + print(f"Total params trainable: {total_params_trainable}") + model_without_ddp = model.to(device) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + print("lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], static_graph=True) + 
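        # model_without_ddp (next line) keeps a handle on the unwrapped module so that
        # checkpoints are saved/loaded without the 'module.' prefix that
        # DistributedDataParallel adds to parameter names. static_graph=True tells DDP
        # that the computation graph (the set of parameters used) is identical across
        # iterations, which enables some communication optimizations.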
model_without_ddp = model.module + + # following timm: set wd as 0 for bias and norm layers + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) + optimizer = eval(f"torch.optim.{args.optimizer}") + print(optimizer) + loss_scaler = NativeScaler() + + # automatic restart + last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth') + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + if not args.resume and args.start_from: + print(f"Starting from an other model's weights: {args.start_from}") + best_so_far = None + args.start_epoch = 0 + ckpt = torch.load(args.start_from, 'cpu') + msg = model_without_ddp.load_state_dict(ckpt['model'], strict=False) + print(msg) + else: + best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + + if best_so_far is None: best_so_far = np.inf + + # tensorboard + log_writer = None + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir, purge_step=args.start_epoch*1000) + + # dataset and loader + print('Building Train Data loader for dataset: ', args.dataset) + train_dataset = (get_train_dataset_stereo if args.task=='stereo' else get_train_dataset_flow)(args.dataset, crop_size=args.crop) + def _print_repr_dataset(d): + if isinstance(d, torch.utils.data.dataset.ConcatDataset): + for dd in d.datasets: + _print_repr_dataset(dd) + else: + print(repr(d)) + _print_repr_dataset(train_dataset) + print(' total length:', len(train_dataset)) + if args.distributed: + sampler_train = torch.utils.data.DistributedSampler( + train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_train = torch.utils.data.RandomSampler(train_dataset) + data_loader_train = torch.utils.data.DataLoader( + train_dataset, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + if args.val_dataset=='': + data_loaders_val = None + else: + print('Building Val Data loader for datasets: ', args.val_dataset) + val_datasets = (get_test_datasets_stereo if args.task=='stereo' else get_test_datasets_flow)(args.val_dataset) + for val_dataset in val_datasets: print(repr(val_dataset)) + data_loaders_val = [DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) for val_dataset in val_datasets] + bestmetric = ("AVG_" if len(data_loaders_val)>1 else str(data_loaders_val[0].dataset)+'_')+args.bestmetric + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + # Training Loop + for epoch in range(args.start_epoch, args.epochs): + + if args.distributed: data_loader_train.sampler.set_epoch(epoch) + + # Train + epoch_start = time.time() + train_stats = train_one_epoch(model, criterion, metrics, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args) + epoch_time = time.time() - epoch_start + + if args.distributed: dist.barrier() + + # Validation (current naive implementation runs the validation on every gpu ... not smart ...) 
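        # Both default best-metrics are error measures (avgerr for stereo, EPE for flow),
        # so lower is better: best_so_far starts at +inf and a 'best' checkpoint is
        # written below whenever the validation value drops to or below the running best.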
+ if data_loaders_val is not None and args.eval_every > 0 and (epoch+1) % args.eval_every == 0: + val_epoch_start = time.time() + val_stats = validate_one_epoch(model, criterion, metrics, data_loaders_val, device, epoch, log_writer=log_writer, args=args) + val_epoch_time = time.time() - val_epoch_start + + val_best = val_stats[bestmetric] + + # Save best of all + if val_best <= best_so_far: + best_so_far = val_best + misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, best_so_far=best_so_far, fname='best') + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + **{f'val_{k}': v for k, v in val_stats.items()}} + else: + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch,} + + if args.distributed: dist.barrier() + + # Save stuff + if args.output_dir and ((epoch+1) % args.save_every == 0 or epoch + 1 == args.epochs): + misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, best_so_far=best_so_far, fname='last') + + if args.output_dir: + if log_writer is not None: + log_writer.flush() + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) \ No newline at end of file diff --git a/imcui/third_party/dust3r/croco/utils/misc.py b/imcui/third_party/dust3r/croco/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..132e102a662c987dce5282633cb8730b0e0d5c2d --- /dev/null +++ b/imcui/third_party/dust3r/croco/utils/misc.py @@ -0,0 +1,463 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +import math +import json +from collections import defaultdict, deque +from pathlib import Path +import numpy as np + +import torch +import torch.distributed as dist +from torch import inf + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, max_iter=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable) + space_fmt = ':' + str(len(str(len_iterable))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for it,obj in enumerate(iterable): + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len_iterable - 1: + eta_seconds = iter_time.global_avg * (len_iterable - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len_iterable, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len_iterable, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + if max_iter and it >= max_iter: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len_iterable)) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or 
(get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + nodist = args.nodist if hasattr(args,'nodist') else False + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ and not nodist: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self, enabled=True): + self._scaler = torch.cuda.amp.GradScaler(enabled=enabled) + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) 
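    # get_grad_norm_ mirrors the norm computation of torch.nn.utils.clip_grad_norm_ but
    # only measures the total gradient norm without clipping; NativeScalerWithGradNormCount
    # above calls it on the unscaled gradients when no clip_grad value is given, so the
    # norm can still be returned for logging or inspection.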
+ device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + + + +def save_model(args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None): + output_dir = Path(args.output_dir) + if fname is None: fname = str(epoch) + checkpoint_path = output_dir / ('checkpoint-%s.pth' % fname) + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': loss_scaler.state_dict(), + 'args': args, + 'epoch': epoch, + } + if best_so_far is not None: to_save['best_so_far'] = best_so_far + print(f'>> Saving model to {checkpoint_path} ...') + save_on_master(to_save, checkpoint_path) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + args.start_epoch = 0 + best_so_far = None + if args.resume is not None: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + print("Resume checkpoint %s" % args.resume) + model_without_ddp.load_state_dict(checkpoint['model'], strict=False) + args.start_epoch = checkpoint['epoch'] + 1 + optimizer.load_state_dict(checkpoint['optimizer']) + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + if 'best_so_far' in checkpoint: + best_so_far = checkpoint['best_so_far'] + print(" & best_so_far={:g}".format(best_so_far)) + else: + print("") + print("With optim & sched! start_epoch={:d}".format(args.start_epoch), end='') + return best_so_far + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + +def _replace(text, src, tgt, rm=''): + """ Advanced string replacement. + Given a text: + - replace all elements in src by the corresponding element in tgt + - remove all elements in rm + """ + if len(tgt) == 1: + tgt = tgt * len(src) + assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len" + for s,t in zip(src, tgt): + text = text.replace(s,t) + for c in rm: + text = text.replace(c,'') + return text + +def filename( obj ): + """ transform a python obj or cmd into a proper filename. 
+ - \1 gets replaced by slash '/' + - \2 gets replaced by comma ',' + """ + if not isinstance(obj, str): + obj = repr(obj) + obj = str(obj).replace('()','') + obj = _replace(obj, '_,(*/\1\2','-__x%/,', rm=' )\'"') + assert all(len(s) < 256 for s in obj.split(os.sep)), 'filename too long (>256 characters):\n'+obj + return obj + +def _get_num_layer_for_vit(var_name, enc_depth, dec_depth): + if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("enc_blocks"): + layer_id = int(var_name.split('.')[1]) + return layer_id + 1 + elif var_name.startswith('decoder_embed') or var_name.startswith('enc_norm'): # part of the last black + return enc_depth + elif var_name.startswith('dec_blocks'): + layer_id = int(var_name.split('.')[1]) + return enc_depth + layer_id + 1 + elif var_name.startswith('dec_norm'): # part of the last block + return enc_depth + dec_depth + elif any(var_name.startswith(k) for k in ['head','prediction_head']): + return enc_depth + dec_depth + 1 + else: + raise NotImplementedError(var_name) + +def get_parameter_groups(model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]): + parameter_group_names = {} + parameter_group_vars = {} + enc_depth, dec_depth = None, None + # prepare layer decay values + assert layer_decay==1.0 or 0.> wrote {fpath}') + + print(f'Loaded {len(list_subscenes)} sub-scenes') + + # separate scenes + list_scenes = defaultdict(list) + for scene in list_subscenes: + scene, id = os.path.split(scene) + list_scenes[scene].append(id) + + list_scenes = list(list_scenes.items()) + print(f'from {len(list_scenes)} scenes in total') + + np.random.shuffle(list_scenes) + train_scenes = list_scenes[len(list_scenes) // 10:] + val_scenes = list_scenes[:len(list_scenes) // 10] + + def write_scene_list(scenes, n, fpath): + sub_scenes = [os.path.join(scene, id) for scene, ids in scenes for id in ids] + np.random.shuffle(sub_scenes) + + if len(sub_scenes) < n: + return + + with open(fpath, 'w') as f: + f.write('\n'.join(sub_scenes[:n])) + print(f'>> wrote {fpath}') + + for n in n_scenes: + write_scene_list(train_scenes, n, os.path.join(habitat_root, f'Habitat_{n}_scenes_train.txt')) + write_scene_list(val_scenes, n // 10, os.path.join(habitat_root, f'Habitat_{n//10}_scenes_val.txt')) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--root", required=True) + parser.add_argument("--n_scenes", nargs='+', default=[1_000, 10_000, 100_000, 1_000_000], type=int) + + args = parser.parse_args() + find_all_scenes(args.root, args.n_scenes) diff --git a/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/__init__.py b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
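# get_parameter_groups above is partly truncated in this diff, but together with
# _get_num_layer_for_vit it implements layer-wise learning-rate decay over the ViT
# encoder/decoder blocks. A minimal sketch of the usual convention (as in BEiT/timm);
# the truncated implementation may differ in details, so treat this as an assumption:
#
#   max_id = enc_depth + dec_depth + 1                     # id assigned to the heads
#   layer_id = _get_num_layer_for_vit(name, enc_depth, dec_depth)
#   lr_scale = layer_decay ** (max_id - layer_id)          # deeper layers keep more lr
#   # biases, norm parameters and anything in skip_list additionally get weight_decay=0,
#   # and parameters are grouped by (decay-or-not, layer_id) so each group carries its
#   # own lr_scale and weight_decay.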
diff --git a/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..4a31f1174a234b900ecaa76705fa271baf8a5669 --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py @@ -0,0 +1,170 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Render environment maps from 3D meshes using the Habitat Sim simulator. +# -------------------------------------------------------- +import numpy as np +import habitat_sim +import math +from habitat_renderer import projections + +# OpenCV to habitat camera convention transformation +R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0) + +CUBEMAP_FACE_LABELS = ["left", "front", "right", "back", "up", "down"] +# Expressed while considering Habitat coordinates systems +CUBEMAP_FACE_ORIENTATIONS_ROTVEC = [ + [0, math.pi / 2, 0], # Left + [0, 0, 0], # Front + [0, - math.pi / 2, 0], # Right + [0, math.pi, 0], # Back + [math.pi / 2, 0, 0], # Up + [-math.pi / 2, 0, 0],] # Down + +class NoNaviguableSpaceError(RuntimeError): + def __init__(self, *args): + super().__init__(*args) + +class HabitatEnvironmentMapRenderer: + def __init__(self, + scene, + navmesh, + scene_dataset_config_file, + render_equirectangular=False, + equirectangular_resolution=(512, 1024), + render_cubemap=False, + cubemap_resolution=(512, 512), + render_depth=False, + gpu_id=0): + self.scene = scene + self.navmesh = navmesh + self.scene_dataset_config_file = scene_dataset_config_file + self.gpu_id = gpu_id + + self.render_equirectangular = render_equirectangular + self.equirectangular_resolution = equirectangular_resolution + self.equirectangular_projection = projections.EquirectangularProjection(*equirectangular_resolution) + # 3D unit ray associated to each pixel of the equirectangular map + equirectangular_rays = projections.get_projection_rays(self.equirectangular_projection) + # Not needed, but just in case. + equirectangular_rays /= np.linalg.norm(equirectangular_rays, axis=-1, keepdims=True) + # Depth map created by Habitat are produced by warping a cubemap, + # so the values do not correspond to distance to the center and need some scaling. 
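        # Geometric note on the scale factors computed on the next line: Habitat warps
        # its equirectangular depth from per-face cubemap z-depth, i.e. depth measured
        # along the dominant axis of each ray rather than along the ray itself. For a
        # unit ray r that dominant component is max(|r_x|, |r_y|, |r_z|), so multiplying
        # the raw depth by 1 / max-component (done in render_viewpoint below) converts it
        # into euclidean distance from the viewpoint.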
+ self.equirectangular_depth_scale_factors = 1.0 / np.max(np.abs(equirectangular_rays), axis=-1) + + self.render_cubemap = render_cubemap + self.cubemap_resolution = cubemap_resolution + + self.render_depth = render_depth + + self.seed = None + self._lazy_initialization() + + def _lazy_initialization(self): + # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly + if self.seed == None: + # Re-seed numpy generator + np.random.seed() + self.seed = np.random.randint(2**32-1) + sim_cfg = habitat_sim.SimulatorConfiguration() + sim_cfg.scene_id = self.scene + if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "": + sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file + sim_cfg.random_seed = self.seed + sim_cfg.load_semantic_mesh = False + sim_cfg.gpu_device_id = self.gpu_id + + sensor_specifications = [] + + # Add cubemaps + if self.render_cubemap: + for face_id, orientation in enumerate(CUBEMAP_FACE_ORIENTATIONS_ROTVEC): + rgb_sensor_spec = habitat_sim.CameraSensorSpec() + rgb_sensor_spec.uuid = f"color_cubemap_{CUBEMAP_FACE_LABELS[face_id]}" + rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR + rgb_sensor_spec.resolution = self.cubemap_resolution + rgb_sensor_spec.hfov = 90 + rgb_sensor_spec.position = [0.0, 0.0, 0.0] + rgb_sensor_spec.orientation = orientation + sensor_specifications.append(rgb_sensor_spec) + + if self.render_depth: + depth_sensor_spec = habitat_sim.CameraSensorSpec() + depth_sensor_spec.uuid = f"depth_cubemap_{CUBEMAP_FACE_LABELS[face_id]}" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.cubemap_resolution + depth_sensor_spec.hfov = 90 + depth_sensor_spec.position = [0.0, 0.0, 0.0] + depth_sensor_spec.orientation = orientation + sensor_specifications.append(depth_sensor_spec) + + # Add equirectangular map + if self.render_equirectangular: + rgb_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec() + rgb_sensor_spec.uuid = "color_equirectangular" + rgb_sensor_spec.resolution = self.equirectangular_resolution + rgb_sensor_spec.position = [0.0, 0.0, 0.0] + sensor_specifications.append(rgb_sensor_spec) + + if self.render_depth: + depth_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec() + depth_sensor_spec.uuid = "depth_equirectangular" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.equirectangular_resolution + depth_sensor_spec.position = [0.0, 0.0, 0.0] + depth_sensor_spec.orientation + sensor_specifications.append(depth_sensor_spec) + + agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=sensor_specifications) + + cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg]) + self.sim = habitat_sim.Simulator(cfg) + if self.navmesh is not None and self.navmesh != "": + # Use pre-computed navmesh (the one generated automatically does some weird stuffs like going on top of the roof) + # See https://youtu.be/kunFMRJAu2U?t=1522 regarding navmeshes + self.sim.pathfinder.load_nav_mesh(self.navmesh) + + # Check that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + # Try to compute a navmesh + navmesh_settings = habitat_sim.NavMeshSettings() + navmesh_settings.set_defaults() + self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True) + + # Check that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})") + 
+ self.agent = self.sim.initialize_agent(agent_id=0) + + def close(self): + if hasattr(self, 'sim'): + self.sim.close() + + def __del__(self): + self.close() + + def render_viewpoint(self, viewpoint_position): + agent_state = habitat_sim.AgentState() + agent_state.position = viewpoint_position + # agent_state.rotation = viewpoint_orientation + self.agent.set_state(agent_state) + viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0) + + try: + # Depth map values have been obtained using cubemap rendering internally, + # so they do not really correspond to distance to the viewpoint in practice + # and they need some scaling + viewpoint_observations["depth_equirectangular"] *= self.equirectangular_depth_scale_factors + except KeyError: + pass + + data = dict(observations=viewpoint_observations, position=viewpoint_position) + return data + + def up_direction(self): + return np.asarray(habitat_sim.geo.UP).tolist() + + def R_cam_to_world(self): + return R_OPENCV2HABITAT.tolist() diff --git a/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b86238b44a5cdd7a2e30b9d64773c2388f9711c3 --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py @@ -0,0 +1,93 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Generate pairs of crops from a dataset of environment maps. +# -------------------------------------------------------- +import os +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # noqa +import cv2 +import collections +from habitat_renderer import projections, projections_conversions +from habitat_renderer.habitat_sim_envmaps_renderer import HabitatEnvironmentMapRenderer + +ViewpointData = collections.namedtuple("ViewpointData", ["colormap", "distancemap", "pointmap", "position"]) + +class HabitatMultiviewCrops: + def __init__(self, + scene, + navmesh, + scene_dataset_config_file, + equirectangular_resolution=(400, 800), + crop_resolution=(240, 320), + pixel_jittering_iterations=5, + jittering_noise_level=1.0): + self.crop_resolution = crop_resolution + + self.pixel_jittering_iterations = pixel_jittering_iterations + self.jittering_noise_level = jittering_noise_level + + # Instanciate the low resolution habitat sim renderer + self.lowres_envmap_renderer = HabitatEnvironmentMapRenderer(scene=scene, + navmesh=navmesh, + scene_dataset_config_file=scene_dataset_config_file, + equirectangular_resolution=equirectangular_resolution, + render_depth=True, + render_equirectangular=True) + self.R_cam_to_world = np.asarray(self.lowres_envmap_renderer.R_cam_to_world()) + self.up_direction = np.asarray(self.lowres_envmap_renderer.up_direction()) + + # Projection applied by each environment map + self.envmap_height, self.envmap_width = self.lowres_envmap_renderer.equirectangular_resolution + base_projection = projections.EquirectangularProjection(self.envmap_height, self.envmap_width) + self.envmap_projection = projections.RotatedProjection(base_projection, self.R_cam_to_world.T) + # 3D Rays map associated to each envmap + self.envmap_rays = projections.get_projection_rays(self.envmap_projection) + + def compute_pointmap(self, distancemap, position): + # Point 
cloud associated to each ray + return self.envmap_rays * distancemap[:, :, None] + position + + def render_viewpoint_data(self, position): + data = self.lowres_envmap_renderer.render_viewpoint(np.asarray(position)) + colormap = data['observations']['color_equirectangular'][..., :3] # Ignore the alpha channel + distancemap = data['observations']['depth_equirectangular'] + pointmap = self.compute_pointmap(distancemap, position) + return ViewpointData(colormap=colormap, distancemap=distancemap, pointmap=pointmap, position=position) + + def extract_cropped_camera(self, projection, color_image, distancemap, pointmap, voxelmap=None): + remapper = projections_conversions.RemapProjection(input_projection=self.envmap_projection, output_projection=projection, + pixel_jittering_iterations=self.pixel_jittering_iterations, jittering_noise_level=self.jittering_noise_level) + cropped_color_image = remapper.convert( + color_image, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False) + cropped_distancemap = remapper.convert( + distancemap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, single_map=True) + cropped_pointmap = remapper.convert(pointmap, interpolation=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_WRAP, single_map=True) + cropped_voxelmap = (None if voxelmap is None else + remapper.convert(voxelmap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, single_map=True)) + # Convert the distance map into a depth map + cropped_depthmap = np.asarray( + cropped_distancemap / np.linalg.norm(remapper.output_rays, axis=-1), dtype=cropped_distancemap.dtype) + + return cropped_color_image, cropped_depthmap, cropped_pointmap, cropped_voxelmap + +def perspective_projection_to_dict(persp_projection, position): + """ + Serialization-like function.""" + camera_params = dict(camera_intrinsics=projections.colmap_to_opencv_intrinsics(persp_projection.base_projection.K).tolist(), + size=(persp_projection.base_projection.width, persp_projection.base_projection.height), + R_cam2world=persp_projection.R_to_base_projection.T.tolist(), + t_cam2world=position) + return camera_params + + +def dict_to_perspective_projection(camera_params): + K = projections.opencv_to_colmap_intrinsics(np.asarray(camera_params["camera_intrinsics"])) + size = camera_params["size"] + R_cam2world = np.asarray(camera_params["R_cam2world"]) + projection = projections.PerspectiveProjection(K, height=size[1], width=size[0]) + projection = projections.RotatedProjection(projection, R_to_base_projection=R_cam2world.T) + position = camera_params["t_cam2world"] + return projection, position \ No newline at end of file diff --git a/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/projections.py b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/projections.py new file mode 100644 index 0000000000000000000000000000000000000000..4db1f79d23e23a8ba144b4357c4d4daf10cf8fab --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/projections.py @@ -0,0 +1,151 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Various 3D/2D projection utils, useful to sample virtual cameras. 
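+# Every projection class below exposes the same two methods,
+#   project(rays) -> (u, v) pixel coordinates
+#   unproject(u, v) -> 3D rays,
+# so the equirectangular and perspective models can be wrapped interchangeably (e.g. by RotatedProjection).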
+# -------------------------------------------------------- +import numpy as np + +class EquirectangularProjection: + """ + Convention for the central pixel of the equirectangular map similar to OpenCV perspective model: + +X from left to right + +Y from top to bottom + +Z going outside the camera + EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5)) + """ + + def __init__(self, height, width): + self.height = height + self.width = width + self.u_scaling = (2 * np.pi) / self.width + self.v_scaling = np.pi / self.height + + def unproject(self, u, v): + """ + Args: + u, v: 2D coordinates + Returns: + unnormalized 3D rays. + """ + longitude = self.u_scaling * u - np.pi + minus_latitude = self.v_scaling * v - np.pi/2 + + cos_latitude = np.cos(minus_latitude) + x, z = np.sin(longitude) * cos_latitude, np.cos(longitude) * cos_latitude + y = np.sin(minus_latitude) + + rays = np.stack([x, y, z], axis=-1) + return rays + + def project(self, rays): + """ + Args: + rays: Bx3 array of 3D rays. + Returns: + u, v: tuple of 2D coordinates. + """ + rays = rays / np.linalg.norm(rays, axis=-1, keepdims=True) + x, y, z = [rays[..., i] for i in range(3)] + + longitude = np.arctan2(x, z) + minus_latitude = np.arcsin(y) + + u = (longitude + np.pi) * (1.0 / self.u_scaling) + v = (minus_latitude + np.pi/2) * (1.0 / self.v_scaling) + return u, v + + +class PerspectiveProjection: + """ + OpenCV convention: + World space: + +X from left to right + +Y from top to bottom + +Z going outside the camera + Pixel space: + +u from left to right + +v from top to bottom + EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5)). + """ + + def __init__(self, K, height, width): + self.height = height + self.width = width + self.K = K + self.Kinv = np.linalg.inv(K) + + def project(self, rays): + uv_homogeneous = np.einsum("ik, ...k -> ...i", self.K, rays) + uv = uv_homogeneous[..., :2] / uv_homogeneous[..., 2, None] + return uv[..., 0], uv[..., 1] + + def unproject(self, u, v): + uv_homogeneous = np.stack((u, v, np.ones_like(u)), axis=-1) + rays = np.einsum("ik, ...k -> ...i", self.Kinv, uv_homogeneous) + return rays + + +class RotatedProjection: + def __init__(self, base_projection, R_to_base_projection): + self.base_projection = base_projection + self.R_to_base_projection = R_to_base_projection + + @property + def width(self): + return self.base_projection.width + + @property + def height(self): + return self.base_projection.height + + def project(self, rays): + if self.R_to_base_projection is not None: + rays = np.einsum("ik, ...k -> ...i", self.R_to_base_projection, rays) + return self.base_projection.project(rays) + + def unproject(self, u, v): + rays = self.base_projection.unproject(u, v) + if self.R_to_base_projection is not None: + rays = np.einsum("ik, ...k -> ...i", self.R_to_base_projection.T, rays) + return rays + +def get_projection_rays(projection, noise_level=0): + """ + Return a 2D map of 3D rays corresponding to the projection. + If noise_level > 0, add some jittering noise to these rays. 
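+    Pixel centers are sampled at (u + 0.5, v + 0.5), following the top-left-corner (0, 0) convention used by the projection classes.
+
+    Illustrative usage (sizes are arbitrary placeholders, not taken from the original documentation):
+        proj = EquirectangularProjection(height=400, width=800)
+        rays = get_projection_rays(proj)  # array of shape (400, 800, 3)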
+ """ + grid_u, grid_v = np.meshgrid(0.5 + np.arange(projection.width), 0.5 + np.arange(projection.height)) + if noise_level > 0: + grid_u += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_u.shape), projection.width) + grid_v += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_v.shape), projection.height) + return projection.unproject(grid_u, grid_v) + +def compute_camera_intrinsics(height, width, hfov): + f = width/2 / np.tan(hfov/2 * np.pi/180) + cu, cv = width/2, height/2 + return f, cu, cv + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + return K + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K \ No newline at end of file diff --git a/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcfed4066bbac62fa4254ea6417bf429b098b75 --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py @@ -0,0 +1,45 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Remap data from one projection to an other +# -------------------------------------------------------- +import numpy as np +import cv2 +from habitat_renderer import projections + +class RemapProjection: + def __init__(self, input_projection, output_projection, pixel_jittering_iterations=0, jittering_noise_level=0): + """ + Some naive random jittering can be introduced in the remapping to mitigate aliasing artecfacts. 
+ """ + assert jittering_noise_level >= 0 + assert pixel_jittering_iterations >= 0 + + maps = [] + # Initial map + self.output_rays = projections.get_projection_rays(output_projection) + map_u, map_v = input_projection.project(self.output_rays) + map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) + maps.append((map_u, map_v)) + + for _ in range(pixel_jittering_iterations): + # Define multiple mappings using some coordinates jittering to mitigate aliasing effects + crop_rays = projections.get_projection_rays(output_projection, jittering_noise_level) + map_u, map_v = input_projection.project(crop_rays) + map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) + maps.append((map_u, map_v)) + self.maps = maps + + def convert(self, img, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False): + remapped = [] + for map_u, map_v in self.maps: + res = cv2.remap(img, map_u, map_v, interpolation=interpolation, borderMode=borderMode) + remapped.append(res) + if single_map: + break + if len(remapped) == 1: + res = remapped[0] + else: + res = np.asarray(np.mean(remapped, axis=0), dtype=img.dtype) + return res diff --git a/imcui/third_party/dust3r/datasets_preprocess/habitat/preprocess_habitat.py b/imcui/third_party/dust3r/datasets_preprocess/habitat/preprocess_habitat.py new file mode 100644 index 0000000000000000000000000000000000000000..cacbe2467a8e9629c2472b0e05fc0cf8326367e2 --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/habitat/preprocess_habitat.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# main executable for preprocessing habitat +# export METADATA_DIR="/path/to/habitat/5views_v1_512x512_metadata" +# export SCENES_DIR="/path/to/habitat/data/scene_datasets/" +# export OUTPUT_DIR="data/habitat_processed" +# export PYTHONPATH=$(pwd) +# python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16 +# -------------------------------------------------------- +import os +import glob +import json +import os + +import PIL.Image +import json +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # noqa +import cv2 +from habitat_renderer import multiview_crop_generator +from tqdm import tqdm + + +def preprocess_metadata(metadata_filename, + scenes_dir, + output_dir, + crop_resolution=[512, 512], + equirectangular_resolution=None, + fix_existing_dataset=False): + # Load data + with open(metadata_filename, "r") as f: + metadata = json.load(f) + + if metadata["scene_dataset_config_file"] == "": + scene = os.path.join(scenes_dir, metadata["scene"]) + scene_dataset_config_file = "" + else: + scene = metadata["scene"] + scene_dataset_config_file = os.path.join(scenes_dir, metadata["scene_dataset_config_file"]) + navmesh = None + + # Use 4 times the crop size as resolution for rendering the environment map. + max_res = max(crop_resolution) + + if equirectangular_resolution == None: + # Use 4 times the crop size as resolution for rendering the environment map. 
+ max_res = max(crop_resolution) + equirectangular_resolution = (4*max_res, 8*max_res) + + print("equirectangular_resolution:", equirectangular_resolution) + + if os.path.exists(output_dir) and not fix_existing_dataset: + raise FileExistsError(output_dir) + + # Lazy initialization + highres_dataset = None + + for batch_label, batch in tqdm(metadata["view_batches"].items()): + for view_label, view_params in batch.items(): + + assert view_params["size"] == crop_resolution + label = f"{batch_label}_{view_label}" + + output_camera_params_filename = os.path.join(output_dir, f"{label}_camera_params.json") + if fix_existing_dataset and os.path.isfile(output_camera_params_filename): + # Skip generation if we are fixing a dataset and the corresponding output file already exists + continue + + # Lazy initialization + if highres_dataset is None: + highres_dataset = multiview_crop_generator.HabitatMultiviewCrops(scene=scene, + navmesh=navmesh, + scene_dataset_config_file=scene_dataset_config_file, + equirectangular_resolution=equirectangular_resolution, + crop_resolution=crop_resolution,) + os.makedirs(output_dir, exist_ok=bool(fix_existing_dataset)) + + # Generate a higher resolution crop + original_projection, position = multiview_crop_generator.dict_to_perspective_projection(view_params) + # Render an envmap at the given position + viewpoint_data = highres_dataset.render_viewpoint_data(position) + + projection = original_projection + colormap, depthmap, pointmap, _ = highres_dataset.extract_cropped_camera( + projection, viewpoint_data.colormap, viewpoint_data.distancemap, viewpoint_data.pointmap) + + camera_params = multiview_crop_generator.perspective_projection_to_dict(projection, position) + + # Color image + PIL.Image.fromarray(colormap).save(os.path.join(output_dir, f"{label}.jpeg")) + # Depth image + cv2.imwrite(os.path.join(output_dir, f"{label}_depth.exr"), + depthmap, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + # Camera parameters + with open(output_camera_params_filename, "w") as f: + json.dump(camera_params, f) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_dir", required=True) + parser.add_argument("--scenes_dir", required=True) + parser.add_argument("--output_dir", required=True) + parser.add_argument("--metadata_filename", default="") + + args = parser.parse_args() + + if args.metadata_filename == "": + # Walk through the metadata dir to generate commandlines + for filename in glob.iglob(os.path.join(args.metadata_dir, "**/metadata.json"), recursive=True): + output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(filename), args.metadata_dir)) + if not os.path.exists(output_dir): + commandline = f"python {__file__} --metadata_filename={filename} --metadata_dir={args.metadata_dir} --scenes_dir={args.scenes_dir} --output_dir={output_dir}" + print(commandline) + else: + preprocess_metadata(metadata_filename=args.metadata_filename, + scenes_dir=args.scenes_dir, + output_dir=args.output_dir) diff --git a/imcui/third_party/dust3r/datasets_preprocess/path_to_root.py b/imcui/third_party/dust3r/datasets_preprocess/path_to_root.py new file mode 100644 index 0000000000000000000000000000000000000000..6e076a17a408d0a9e043fbda2d73f1592e7cb71a --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/path_to_root.py @@ -0,0 +1,13 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# DUSt3R repo root import +# -------------------------------------------------------- + +import sys +import os.path as path +HERE_PATH = path.normpath(path.dirname(__file__)) +DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../')) +# workaround for sibling import +sys.path.insert(0, DUST3R_REPO_PATH) diff --git a/imcui/third_party/dust3r/datasets_preprocess/preprocess_arkitscenes.py b/imcui/third_party/dust3r/datasets_preprocess/preprocess_arkitscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..5dbc103a82d646293e1d81f5132683e2b08cd879 --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/preprocess_arkitscenes.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to pre-process the arkitscenes dataset. +# Usage: +# python3 datasets_preprocess/preprocess_arkitscenes.py --arkitscenes_dir /path/to/arkitscenes --precomputed_pairs /path/to/arkitscenes_pairs +# -------------------------------------------------------- +import os +import json +import os.path as osp +import decimal +import argparse +import math +from bisect import bisect_left +from PIL import Image +import numpy as np +import quaternion +from scipy import interpolate +import cv2 + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--arkitscenes_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/arkitscenes_processed') + return parser + + +def value_to_decimal(value, decimal_places): + decimal.getcontext().rounding = decimal.ROUND_HALF_UP # define rounding method + return decimal.Decimal(str(float(value))).quantize(decimal.Decimal('1e-{}'.format(decimal_places))) + + +def closest(value, sorted_list): + index = bisect_left(sorted_list, value) + if index == 0: + return sorted_list[0] + elif index == len(sorted_list): + return sorted_list[-1] + else: + value_before = sorted_list[index - 1] + value_after = sorted_list[index] + if value_after - value < value - value_before: + return value_after + else: + return value_before + + +def get_up_vectors(pose_device_to_world): + return np.matmul(pose_device_to_world, np.array([[0.0], [-1.0], [0.0], [0.0]])) + + +def get_right_vectors(pose_device_to_world): + return np.matmul(pose_device_to_world, np.array([[1.0], [0.0], [0.0], [0.0]])) + + +def read_traj(traj_path): + quaternions = [] + poses = [] + timestamps = [] + poses_p_to_w = [] + with open(traj_path) as f: + traj_lines = f.readlines() + for line in traj_lines: + tokens = line.split() + assert len(tokens) == 7 + traj_timestamp = float(tokens[0]) + + timestamps_decimal_value = value_to_decimal(traj_timestamp, 3) + timestamps.append(float(timestamps_decimal_value)) # for spline interpolation + + angle_axis = [float(tokens[1]), float(tokens[2]), float(tokens[3])] + r_w_to_p, _ = cv2.Rodrigues(np.asarray(angle_axis)) + t_w_to_p = np.asarray([float(tokens[4]), float(tokens[5]), float(tokens[6])]) + + pose_w_to_p = np.eye(4) + pose_w_to_p[:3, :3] = r_w_to_p + pose_w_to_p[:3, 3] = t_w_to_p + + pose_p_to_w = np.linalg.inv(pose_w_to_p) + + r_p_to_w_as_quat = quaternion.from_rotation_matrix(pose_p_to_w[:3, :3]) + t_p_to_w = pose_p_to_w[:3, 3] + poses_p_to_w.append(pose_p_to_w) + poses.append(t_p_to_w) + quaternions.append(r_p_to_w_as_quat) 
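+    # All returned poses/quaternions are device(camera)-to-world: the trajectory file stores world-to-device
+    # transforms (angle-axis rotation + translation), which are inverted above.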
+ return timestamps, poses, quaternions, poses_p_to_w + + +def main(rootdir, pairsdir, outdir): + os.makedirs(outdir, exist_ok=True) + + subdirs = ['Test', 'Training'] + for subdir in subdirs: + if not osp.isdir(osp.join(rootdir, subdir)): + continue + # STEP 1: list all scenes + outsubdir = osp.join(outdir, subdir) + os.makedirs(outsubdir, exist_ok=True) + listfile = osp.join(pairsdir, subdir, 'scene_list.json') + with open(listfile, 'r') as f: + scene_dirs = json.load(f) + + valid_scenes = [] + for scene_subdir in scene_dirs: + out_scene_subdir = osp.join(outsubdir, scene_subdir) + os.makedirs(out_scene_subdir, exist_ok=True) + + scene_dir = osp.join(rootdir, subdir, scene_subdir) + depth_dir = osp.join(scene_dir, 'lowres_depth') + rgb_dir = osp.join(scene_dir, 'vga_wide') + intrinsics_dir = osp.join(scene_dir, 'vga_wide_intrinsics') + traj_path = osp.join(scene_dir, 'lowres_wide.traj') + + # STEP 2: read selected_pairs.npz + selected_pairs_path = osp.join(pairsdir, subdir, scene_subdir, 'selected_pairs.npz') + selected_npz = np.load(selected_pairs_path) + selection, pairs = selected_npz['selection'], selected_npz['pairs'] + selected_sky_direction_scene = str(selected_npz['sky_direction_scene'][0]) + if len(selection) == 0 or len(pairs) == 0: + # not a valid scene + continue + valid_scenes.append(scene_subdir) + + # STEP 3: parse the scene and export the list of valid (K, pose, rgb, depth) and convert images + scene_metadata_path = osp.join(out_scene_subdir, 'scene_metadata.npz') + if osp.isfile(scene_metadata_path): + continue + else: + print(f'parsing {scene_subdir}') + # loads traj + timestamps, poses, quaternions, poses_cam_to_world = read_traj(traj_path) + + poses = np.array(poses) + quaternions = np.array(quaternions, dtype=np.quaternion) + quaternions = quaternion.unflip_rotors(quaternions) + timestamps = np.array(timestamps) + + selected_images = [(basename, basename.split(".png")[0].split("_")[1]) for basename in selection] + timestamps_selected = [float(frame_id) for _, frame_id in selected_images] + + sky_direction_scene, trajectories, intrinsics, images = convert_scene_metadata(scene_subdir, + intrinsics_dir, + timestamps, + quaternions, + poses, + poses_cam_to_world, + selected_images, + timestamps_selected) + assert selected_sky_direction_scene == sky_direction_scene + + os.makedirs(os.path.join(out_scene_subdir, 'vga_wide'), exist_ok=True) + os.makedirs(os.path.join(out_scene_subdir, 'lowres_depth'), exist_ok=True) + assert isinstance(sky_direction_scene, str) + for basename in images: + img_out = os.path.join(out_scene_subdir, 'vga_wide', basename.replace('.png', '.jpg')) + depth_out = os.path.join(out_scene_subdir, 'lowres_depth', basename) + if osp.isfile(img_out) and osp.isfile(depth_out): + continue + + vga_wide_path = osp.join(rgb_dir, basename) + depth_path = osp.join(depth_dir, basename) + + img = Image.open(vga_wide_path) + depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) + + # rotate the image + if sky_direction_scene == 'RIGHT': + try: + img = img.transpose(Image.Transpose.ROTATE_90) + except Exception: + img = img.transpose(Image.ROTATE_90) + depth = cv2.rotate(depth, cv2.ROTATE_90_COUNTERCLOCKWISE) + elif sky_direction_scene == 'LEFT': + try: + img = img.transpose(Image.Transpose.ROTATE_270) + except Exception: + img = img.transpose(Image.ROTATE_270) + depth = cv2.rotate(depth, cv2.ROTATE_90_CLOCKWISE) + elif sky_direction_scene == 'DOWN': + try: + img = img.transpose(Image.Transpose.ROTATE_180) + except Exception: + img = img.transpose(Image.ROTATE_180) + 
depth = cv2.rotate(depth, cv2.ROTATE_180) + + W, H = img.size + if not osp.isfile(img_out): + img.save(img_out) + + depth = cv2.resize(depth, (W, H), interpolation=cv2.INTER_NEAREST_EXACT) + if not osp.isfile(depth_out): # avoid destroying the base dataset when you mess up the paths + cv2.imwrite(depth_out, depth) + + # save at the end + np.savez(scene_metadata_path, + trajectories=trajectories, + intrinsics=intrinsics, + images=images, + pairs=pairs) + + outlistfile = osp.join(outsubdir, 'scene_list.json') + with open(outlistfile, 'w') as f: + json.dump(valid_scenes, f) + + # STEP 5: concat all scene_metadata.npz into a single file + scene_data = {} + for scene_subdir in valid_scenes: + scene_metadata_path = osp.join(outsubdir, scene_subdir, 'scene_metadata.npz') + with np.load(scene_metadata_path) as data: + trajectories = data['trajectories'] + intrinsics = data['intrinsics'] + images = data['images'] + pairs = data['pairs'] + scene_data[scene_subdir] = {'trajectories': trajectories, + 'intrinsics': intrinsics, + 'images': images, + 'pairs': pairs} + offset = 0 + counts = [] + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + pairs = [] + for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()): + num_imgs = data['images'].shape[0] + img_pairs = data['pairs'] + + scenes.append(scene_subdir) + sceneids.extend([scene_idx] * num_imgs) + + images.append(data['images']) + + K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0) + K[:, 0, 0] = [fx for _, _, fx, _, _, _ in data['intrinsics']] + K[:, 1, 1] = [fy for _, _, _, fy, _, _ in data['intrinsics']] + K[:, 0, 2] = [hw for _, _, _, _, hw, _ in data['intrinsics']] + K[:, 1, 2] = [hh for _, _, _, _, _, hh in data['intrinsics']] + + intrinsics.append(K) + trajectories.append(data['trajectories']) + + # offset pairs + img_pairs[:, 0:2] += offset + pairs.append(img_pairs) + counts.append(offset) + + offset += num_imgs + + images = np.concatenate(images, axis=0) + intrinsics = np.concatenate(intrinsics, axis=0) + trajectories = np.concatenate(trajectories, axis=0) + pairs = np.concatenate(pairs, axis=0) + np.savez(osp.join(outsubdir, 'all_metadata.npz'), + counts=counts, + scenes=scenes, + sceneids=sceneids, + images=images, + intrinsics=intrinsics, + trajectories=trajectories, + pairs=pairs) + + +def convert_scene_metadata(scene_subdir, intrinsics_dir, + timestamps, quaternions, poses, poses_cam_to_world, + selected_images, timestamps_selected): + # find scene orientation + sky_direction_scene, rotated_to_cam = find_scene_orientation(poses_cam_to_world) + + # find/compute pose for selected timestamps + # most images have a valid timestamp / exact pose associated + timestamps_selected = np.array(timestamps_selected) + spline = interpolate.interp1d(timestamps, poses, kind='linear', axis=0) + interpolated_rotations = quaternion.squad(quaternions, timestamps, timestamps_selected) + interpolated_positions = spline(timestamps_selected) + + trajectories = [] + intrinsics = [] + images = [] + for i, (basename, frame_id) in enumerate(selected_images): + intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{frame_id}.pincam") + if not osp.exists(intrinsic_fn): + intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{float(frame_id) - 0.001:.3f}.pincam") + if not osp.exists(intrinsic_fn): + intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{float(frame_id) + 0.001:.3f}.pincam") + assert osp.exists(intrinsic_fn) + w, h, fx, fy, hw, hh = np.loadtxt(intrinsic_fn) # PINHOLE + + pose = np.eye(4) + 
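+        # pose is cam-to-world: rotation from the SQUAD-interpolated quaternions, translation from the linear spline.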
pose[:3, :3] = quaternion.as_rotation_matrix(interpolated_rotations[i]) + pose[:3, 3] = interpolated_positions[i] + + images.append(basename) + if sky_direction_scene == 'RIGHT' or sky_direction_scene == 'LEFT': + intrinsics.append([h, w, fy, fx, hh, hw]) # swapped intrinsics + else: + intrinsics.append([w, h, fx, fy, hw, hh]) + trajectories.append(pose @ rotated_to_cam) # pose_cam_to_world @ rotated_to_cam = rotated(cam) to world + + return sky_direction_scene, trajectories, intrinsics, images + + +def find_scene_orientation(poses_cam_to_world): + if len(poses_cam_to_world) > 0: + up_vector = sum(get_up_vectors(p) for p in poses_cam_to_world) / len(poses_cam_to_world) + right_vector = sum(get_right_vectors(p) for p in poses_cam_to_world) / len(poses_cam_to_world) + up_world = np.array([[0.0], [0.0], [1.0], [0.0]]) + else: + up_vector = np.array([[0.0], [-1.0], [0.0], [0.0]]) + right_vector = np.array([[1.0], [0.0], [0.0], [0.0]]) + up_world = np.array([[0.0], [0.0], [1.0], [0.0]]) + + # value between 0, 180 + device_up_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world), + up_vector), -1.0, 1.0)).item() * 180.0 / np.pi + device_right_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world), + right_vector), -1.0, 1.0)).item() * 180.0 / np.pi + + up_closest_to_90 = abs(device_up_to_world_up_angle - 90.0) < abs(device_right_to_world_up_angle - 90.0) + if up_closest_to_90: + assert abs(device_up_to_world_up_angle - 90.0) < 45.0 + # LEFT + if device_right_to_world_up_angle > 90.0: + sky_direction_scene = 'LEFT' + cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi / 2.0]) + else: + # note that in metadata.csv RIGHT does not exist, but again it's not accurate... + # well, turns out there are scenes oriented like this + # for example Training/41124801 + sky_direction_scene = 'RIGHT' + cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, -math.pi / 2.0]) + else: + # right is close to 90 + assert abs(device_right_to_world_up_angle - 90.0) < 45.0 + if device_up_to_world_up_angle > 90.0: + sky_direction_scene = 'DOWN' + cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi]) + else: + sky_direction_scene = 'UP' + cam_to_rotated_q = quaternion.quaternion(1, 0, 0, 0) + cam_to_rotated = np.eye(4) + cam_to_rotated[:3, :3] = quaternion.as_rotation_matrix(cam_to_rotated_q) + rotated_to_cam = np.linalg.inv(cam_to_rotated) + return sky_direction_scene, rotated_to_cam + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.arkitscenes_dir, args.precomputed_pairs, args.output_dir) diff --git a/imcui/third_party/dust3r/datasets_preprocess/preprocess_blendedMVS.py b/imcui/third_party/dust3r/datasets_preprocess/preprocess_blendedMVS.py new file mode 100644 index 0000000000000000000000000000000000000000..d22793793c1219ebb1b3ba8eff51226c2b13f657 --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/preprocess_blendedMVS.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# Preprocessing code for the BlendedMVS dataset +# dataset at https://github.com/YoYo000/BlendedMVS +# 1) Download BlendedMVS.zip +# 2) Download BlendedMVS+.zip +# 3) Download BlendedMVS++.zip +# 4) Unzip everything in the same /path/to/tmp/blendedMVS/ directory +# 5) python datasets_preprocess/preprocess_blendedMVS.py --blendedmvs_dir /path/to/tmp/blendedMVS/ +# -------------------------------------------------------- +import os +import os.path as osp +import re +from tqdm import tqdm +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +import path_to_root # noqa +from dust3r.utils.parallel import parallel_threads +from dust3r.datasets.utils import cropping # noqa + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--blendedmvs_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/blendedmvs_processed') + return parser + + +def main(db_root, pairs_path, output_dir): + print('>> Listing all sequences') + sequences = [f for f in os.listdir(db_root) if len(f) == 24] + # should find 502 scenes + assert sequences, f'did not found any sequences at {db_root}' + print(f' (found {len(sequences)} sequences)') + + for i, seq in enumerate(tqdm(sequences)): + out_dir = osp.join(output_dir, seq) + os.makedirs(out_dir, exist_ok=True) + + # generate the crops + root = osp.join(db_root, seq) + cam_dir = osp.join(root, 'cams') + func_args = [(root, f[:-8], out_dir) for f in os.listdir(cam_dir) if not f.startswith('pair')] + parallel_threads(load_crop_and_save, func_args, star_args=True, leave=False) + + # verify that all pairs are there + pairs = np.load(pairs_path) + for seqh, seql, img1, img2, score in tqdm(pairs): + for view_index in [img1, img2]: + impath = osp.join(output_dir, f"{seqh:08x}{seql:016x}", f"{view_index:08n}.jpg") + assert osp.isfile(impath), f'missing image at {impath=}' + + print(f'>> Done, saved everything in {output_dir}/') + + +def load_crop_and_save(root, img, out_dir): + if osp.isfile(osp.join(out_dir, img + '.npz')): + return # already done + + # load everything + intrinsics_in, R_camin2world, t_camin2world = _load_pose(osp.join(root, 'cams', img + '_cam.txt')) + color_image_in = cv2.cvtColor(cv2.imread(osp.join(root, 'blended_images', img + + '.jpg'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + depthmap_in = load_pfm_file(osp.join(root, 'rendered_depth_maps', img + '.pfm')) + + # do the crop + H, W = color_image_in.shape[:2] + assert H * 4 == W * 3 + image, depthmap, intrinsics_out, R_in2out = _crop_image(intrinsics_in, color_image_in, depthmap_in, (512, 384)) + + # write everything + image.save(osp.join(out_dir, img + '.jpg'), quality=80) + cv2.imwrite(osp.join(out_dir, img + '.exr'), depthmap) + + # New camera parameters + R_camout2world = R_camin2world @ R_in2out.T + t_camout2world = t_camin2world + np.savez(osp.join(out_dir, img + '.npz'), intrinsics=intrinsics_out, + R_cam2world=R_camout2world, t_cam2world=t_camout2world) + + +def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(800, 800)): + image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( + color_image_in, depthmap_in, intrinsics_in, resolution_out) + R_in2out = np.eye(3) + return image, depthmap, intrinsics_out, R_in2out + + +def _load_pose(path, ret_44=False): + f = open(path) + RT = np.loadtxt(f, skiprows=1, max_rows=4, dtype=np.float32) + assert RT.shape == (4, 4) + 
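+    # The *_cam.txt file stores a 4x4 extrinsic block (world-to-camera) followed by a 3x3 intrinsic block;
+    # the skiprows/max_rows arguments of the two np.loadtxt calls skip the section headers between them.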
RT = np.linalg.inv(RT) # world2cam to cam2world + + K = np.loadtxt(f, skiprows=2, max_rows=3, dtype=np.float32) + assert K.shape == (3, 3) + + if ret_44: + return K, RT + return K, RT[:3, :3], RT[:3, 3] # , depth_uint8_to_f32 + + +def load_pfm_file(file_path): + with open(file_path, 'rb') as file: + header = file.readline().decode('UTF-8').strip() + + if header == 'PF': + is_color = True + elif header == 'Pf': + is_color = False + else: + raise ValueError('The provided file is not a valid PFM file.') + + dimensions = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('UTF-8')) + if dimensions: + img_width, img_height = map(int, dimensions.groups()) + else: + raise ValueError('Invalid PFM header format.') + + endian_scale = float(file.readline().decode('UTF-8').strip()) + if endian_scale < 0: + dtype = '= img_size * 3/4, and max dimension will be >= img_size")) + return parser + + +def convert_ndc_to_pinhole(focal_length, principal_point, image_size): + focal_length = np.array(focal_length) + principal_point = np.array(principal_point) + image_size_wh = np.array([image_size[1], image_size[0]]) + half_image_size = image_size_wh / 2 + rescale = half_image_size.min() + principal_point_px = half_image_size - principal_point * rescale + focal_length_px = focal_length * rescale + fx, fy = focal_length_px[0], focal_length_px[1] + cx, cy = principal_point_px[0], principal_point_px[1] + K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32) + return K + + +def opencv_from_cameras_projection(R, T, focal, p0, image_size): + R = torch.from_numpy(R)[None, :, :] + T = torch.from_numpy(T)[None, :] + focal = torch.from_numpy(focal)[None, :] + p0 = torch.from_numpy(p0)[None, :] + image_size = torch.from_numpy(image_size)[None, :] + + R_pytorch3d = R.clone() + T_pytorch3d = T.clone() + focal_pytorch3d = focal + p0_pytorch3d = p0 + T_pytorch3d[:, :2] *= -1 + R_pytorch3d[:, :, :2] *= -1 + tvec = T_pytorch3d + R = R_pytorch3d.permute(0, 2, 1) + + # Retype the image_size correctly and flip to width, height. + image_size_wh = image_size.to(R).flip(dims=(1,)) + + # NDC to screen conversion. 
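+    # In the PyTorch3D-style NDC convention used here, the shorter image side spans [-1, 1], so min(W, H) / 2
+    # converts NDC units to pixels; the principal point is negated and shifted by the image center, and the
+    # focal length is rescaled by the same factor.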
+ scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0 + scale = scale.expand(-1, 2) + c0 = image_size_wh / 2.0 + + principal_point = -p0_pytorch3d * scale + c0 + focal_length = focal_pytorch3d * scale + + camera_matrix = torch.zeros_like(R) + camera_matrix[:, :2, 2] = principal_point + camera_matrix[:, 2, 2] = 1.0 + camera_matrix[:, 0, 0] = focal_length[:, 0] + camera_matrix[:, 1, 1] = focal_length[:, 1] + return R[0], tvec[0], camera_matrix[0] + + +def get_set_list(category_dir, split, is_single_sequence_subset=False): + listfiles = os.listdir(osp.join(category_dir, "set_lists")) + if is_single_sequence_subset: + # not all objects have manyview_dev + subset_list_files = [f for f in listfiles if "manyview_dev" in f] + else: + subset_list_files = [f for f in listfiles if f"fewview_train" in f] + + sequences_all = [] + for subset_list_file in subset_list_files: + with open(osp.join(category_dir, "set_lists", subset_list_file)) as f: + subset_lists_data = json.load(f) + sequences_all.extend(subset_lists_data[split]) + + return sequences_all + + +def prepare_sequences(category, co3d_dir, output_dir, img_size, split, min_quality, max_num_sequences_per_object, + seed, is_single_sequence_subset=False): + random.seed(seed) + category_dir = osp.join(co3d_dir, category) + category_output_dir = osp.join(output_dir, category) + sequences_all = get_set_list(category_dir, split, is_single_sequence_subset) + sequences_numbers = sorted(set(seq_name for seq_name, _, _ in sequences_all)) + + frame_file = osp.join(category_dir, "frame_annotations.jgz") + sequence_file = osp.join(category_dir, "sequence_annotations.jgz") + + with gzip.open(frame_file, "r") as fin: + frame_data = json.loads(fin.read()) + with gzip.open(sequence_file, "r") as fin: + sequence_data = json.loads(fin.read()) + + frame_data_processed = {} + for f_data in frame_data: + sequence_name = f_data["sequence_name"] + frame_data_processed.setdefault(sequence_name, {})[f_data["frame_number"]] = f_data + + good_quality_sequences = set() + for seq_data in sequence_data: + if seq_data["viewpoint_quality_score"] > min_quality: + good_quality_sequences.add(seq_data["sequence_name"]) + + sequences_numbers = [seq_name for seq_name in sequences_numbers if seq_name in good_quality_sequences] + if len(sequences_numbers) < max_num_sequences_per_object: + selected_sequences_numbers = sequences_numbers + else: + selected_sequences_numbers = random.sample(sequences_numbers, max_num_sequences_per_object) + + selected_sequences_numbers_dict = {seq_name: [] for seq_name in selected_sequences_numbers} + sequences_all = [(seq_name, frame_number, filepath) + for seq_name, frame_number, filepath in sequences_all + if seq_name in selected_sequences_numbers_dict] + + for seq_name, frame_number, filepath in tqdm(sequences_all): + frame_idx = int(filepath.split('/')[-1][5:-4]) + selected_sequences_numbers_dict[seq_name].append(frame_idx) + mask_path = filepath.replace("images", "masks").replace(".jpg", ".png") + frame_data = frame_data_processed[seq_name][frame_number] + focal_length = frame_data["viewpoint"]["focal_length"] + principal_point = frame_data["viewpoint"]["principal_point"] + image_size = frame_data["image"]["size"] + K = convert_ndc_to_pinhole(focal_length, principal_point, image_size) + R, tvec, camera_intrinsics = opencv_from_cameras_projection(np.array(frame_data["viewpoint"]["R"]), + np.array(frame_data["viewpoint"]["T"]), + np.array(focal_length), + np.array(principal_point), + np.array(image_size)) + + frame_data = 
frame_data_processed[seq_name][frame_number] + depth_path = os.path.join(co3d_dir, frame_data["depth"]["path"]) + assert frame_data["depth"]["scale_adjustment"] == 1.0 + image_path = os.path.join(co3d_dir, filepath) + mask_path_full = os.path.join(co3d_dir, mask_path) + + input_rgb_image = PIL.Image.open(image_path).convert('RGB') + input_mask = plt.imread(mask_path_full) + + with PIL.Image.open(depth_path) as depth_pil: + # the image is stored with 16-bit depth but PIL reads it as I (32 bit). + # we cast it to uint16, then reinterpret as float16, then cast to float32 + input_depthmap = ( + np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0]))) + depth_mask = np.stack((input_depthmap, input_mask), axis=-1) + H, W = input_depthmap.shape + + camera_intrinsics = camera_intrinsics.numpy() + cx, cy = camera_intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap( + input_rgb_image, depth_mask, camera_intrinsics, crop_bbox) + + # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384 + scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + if max(output_resolution) < img_size: + # let's put the max dimension to img_size + scale_final = (img_size / max(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap( + input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution) + input_depthmap = depth_mask[:, :, 0] + input_mask = depth_mask[:, :, 1] + + # generate and adjust camera pose + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = R + camera_pose[:3, 3] = tvec + camera_pose = np.linalg.inv(camera_pose) + + # save crop images and depth, metadata + save_img_path = os.path.join(output_dir, filepath) + save_depth_path = os.path.join(output_dir, frame_data["depth"]["path"]) + save_mask_path = os.path.join(output_dir, mask_path) + os.makedirs(os.path.split(save_img_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True) + + input_rgb_image.save(save_img_path) + scaled_depth_map = (input_depthmap / np.max(input_depthmap) * 65535).astype(np.uint16) + cv2.imwrite(save_depth_path, scaled_depth_map) + cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8)) + + save_meta_path = save_img_path.replace('jpg', 'npz') + np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics, + camera_pose=camera_pose, maximum_depth=np.max(input_depthmap)) + + return selected_sequences_numbers_dict + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + assert args.co3d_dir != args.output_dir + if args.category is None: + if args.single_sequence_subset: + categories = SINGLE_SEQUENCE_CATEGORIES + else: + categories = CATEGORIES + else: + categories = [args.category] + os.makedirs(args.output_dir, exist_ok=True) + + for split in ['train', 'test']: + selected_sequences_path = os.path.join(args.output_dir, 
f'selected_seqs_{split}.json') + if os.path.isfile(selected_sequences_path): + continue + + all_selected_sequences = {} + for category in categories: + category_output_dir = osp.join(args.output_dir, category) + os.makedirs(category_output_dir, exist_ok=True) + category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(category_selected_sequences_path): + with open(category_selected_sequences_path, 'r') as fid: + category_selected_sequences = json.load(fid) + else: + print(f"Processing {split} - category = {category}") + category_selected_sequences = prepare_sequences( + category=category, + co3d_dir=args.co3d_dir, + output_dir=args.output_dir, + img_size=args.img_size, + split=split, + min_quality=args.min_quality, + max_num_sequences_per_object=args.num_sequences_per_object, + seed=args.seed + CATEGORIES_IDX[category], + is_single_sequence_subset=args.single_sequence_subset + ) + with open(category_selected_sequences_path, 'w') as file: + json.dump(category_selected_sequences, file) + + all_selected_sequences[category] = category_selected_sequences + with open(selected_sequences_path, 'w') as file: + json.dump(all_selected_sequences, file) diff --git a/imcui/third_party/dust3r/datasets_preprocess/preprocess_megadepth.py b/imcui/third_party/dust3r/datasets_preprocess/preprocess_megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..b07c0c5dff0cfd828f9ce4fd204cf2eaa22487f1 --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/preprocess_megadepth.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Preprocessing code for the MegaDepth dataset +# dataset at https://www.cs.cornell.edu/projects/megadepth/ +# -------------------------------------------------------- +import os +import os.path as osp +import collections +from tqdm import tqdm +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 +import h5py + +import path_to_root # noqa +from dust3r.utils.parallel import parallel_threads +from dust3r.datasets.utils import cropping # noqa + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--megadepth_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/megadepth_processed') + return parser + + +def main(db_root, pairs_path, output_dir): + os.makedirs(output_dir, exist_ok=True) + + # load all pairs + data = np.load(pairs_path, allow_pickle=True) + scenes = data['scenes'] + images = data['images'] + pairs = data['pairs'] + + # enumerate all unique images + todo = collections.defaultdict(set) + for scene, im1, im2, score in pairs: + todo[scene].add(im1) + todo[scene].add(im2) + + # for each scene, load intrinsics and then parallel crops + for scene, im_idxs in tqdm(todo.items(), desc='Overall'): + scene, subscene = scenes[scene].split() + out_dir = osp.join(output_dir, scene, subscene) + os.makedirs(out_dir, exist_ok=True) + + # load all camera params + _, pose_w2cam, intrinsics = _load_kpts_and_poses(db_root, scene, subscene, intrinsics=True) + + in_dir = osp.join(db_root, scene, 'dense' + subscene) + args = [(in_dir, img, intrinsics[img], pose_w2cam[img], out_dir) + for img in [images[im_id] for im_id in im_idxs]] + parallel_threads(resize_one_image, args, 
star_args=True, front_num=0, leave=False, desc=f'{scene}/{subscene}') + + # save pairs + print('Done! prepared all pairs in', output_dir) + + +def resize_one_image(root, tag, K_pre_rectif, pose_w2cam, out_dir): + if osp.isfile(osp.join(out_dir, tag + '.npz')): + return + + # load image + img = cv2.cvtColor(cv2.imread(osp.join(root, 'imgs', tag), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + H, W = img.shape[:2] + + # load depth + with h5py.File(osp.join(root, 'depths', osp.splitext(tag)[0] + '.h5'), 'r') as hd5: + depthmap = np.asarray(hd5['depth']) + + # rectify = undistort the intrinsics + imsize_pre, K_pre, distortion = K_pre_rectif + imsize_post = img.shape[1::-1] + K_post = cv2.getOptimalNewCameraMatrix(K_pre, distortion, imsize_pre, alpha=0, + newImgSize=imsize_post, centerPrincipalPoint=True)[0] + + # downscale + img_out, depthmap_out, intrinsics_out, R_in2out = _downscale_image(K_post, img, depthmap, resolution_out=(800, 600)) + + # write everything + img_out.save(osp.join(out_dir, tag + '.jpg'), quality=90) + cv2.imwrite(osp.join(out_dir, tag + '.exr'), depthmap_out) + + camout2world = np.linalg.inv(pose_w2cam) + camout2world[:3, :3] = camout2world[:3, :3] @ R_in2out.T + np.savez(osp.join(out_dir, tag + '.npz'), intrinsics=intrinsics_out, cam2world=camout2world) + + +def _downscale_image(camera_intrinsics, image, depthmap, resolution_out=(512, 384)): + H, W = image.shape[:2] + resolution_out = sorted(resolution_out)[::+1 if W < H else -1] + + image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( + image, depthmap, camera_intrinsics, resolution_out, force=False) + R_in2out = np.eye(3) + + return image, depthmap, intrinsics_out, R_in2out + + +def _load_kpts_and_poses(root, scene_id, subscene, z_only=False, intrinsics=False): + if intrinsics: + with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'cameras.txt'), 'r') as f: + raw = f.readlines()[3:] # skip the header + + camera_intrinsics = {} + for camera in raw: + camera = camera.split(' ') + width, height, focal, cx, cy, k0 = [float(elem) for elem in camera[2:]] + K = np.eye(3) + K[0, 0] = focal + K[1, 1] = focal + K[0, 2] = cx + K[1, 2] = cy + camera_intrinsics[int(camera[0])] = ((int(width), int(height)), K, (k0, 0, 0, 0)) + + with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'images.txt'), 'r') as f: + raw = f.read().splitlines()[4:] # skip the header + + extract_pose = colmap_raw_pose_to_principal_axis if z_only else colmap_raw_pose_to_RT + + poses = {} + points3D_idxs = {} + camera = [] + + for image, points in zip(raw[:: 2], raw[1:: 2]): + image = image.split(' ') + points = points.split(' ') + + image_id = image[-1] + camera.append(int(image[-2])) + + # find the principal axis + raw_pose = [float(elem) for elem in image[1: -2]] + poses[image_id] = extract_pose(raw_pose) + + current_points3D_idxs = {int(i) for i in points[2:: 3] if i != '-1'} + assert -1 not in current_points3D_idxs, bb() + points3D_idxs[image_id] = current_points3D_idxs + + if intrinsics: + image_intrinsics = {im_id: camera_intrinsics[cam] for im_id, cam in zip(poses, camera)} + return points3D_idxs, poses, image_intrinsics + else: + return points3D_idxs, poses + + +def colmap_raw_pose_to_principal_axis(image_pose): + qvec = image_pose[: 4] + qvec = qvec / np.linalg.norm(qvec) + w, x, y, z = qvec + z_axis = np.float32([ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ]) + return z_axis + + +def colmap_raw_pose_to_RT(image_pose): + qvec = image_pose[: 4] + qvec = qvec / 
np.linalg.norm(qvec) + w, x, y, z = qvec + R = np.array([ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ] + ]) + # principal_axis.append(R[2, :]) + t = image_pose[4: 7] + # World-to-Camera pose + current_pose = np.eye(4) + current_pose[: 3, : 3] = R + current_pose[: 3, 3] = t + return current_pose + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.megadepth_dir, args.precomputed_pairs, args.output_dir) diff --git a/imcui/third_party/dust3r/datasets_preprocess/preprocess_scannetpp.py b/imcui/third_party/dust3r/datasets_preprocess/preprocess_scannetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..03f2ff44a76b0d89011d8092e4dc395233f4d7bd --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/preprocess_scannetpp.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to pre-process the scannet++ dataset. +# Usage: +# python3 datasets_preprocess/preprocess_scannetpp.py --scannetpp_dir /path/to/scannetpp --precomputed_pairs /path/to/scannetpp_pairs --pyopengl-platform egl +# -------------------------------------------------------- +import os +import argparse +import os.path as osp +import re +from tqdm import tqdm +import json +from scipy.spatial.transform import Rotation +import pyrender +import trimesh +import trimesh.exchange.ply +import numpy as np +import cv2 +import PIL.Image as Image + +from dust3r.datasets.utils.cropping import rescale_image_depthmap +import dust3r.utils.geometry as geometry + +inv = np.linalg.inv +norm = np.linalg.norm +REGEXPR_DSLR = re.compile(r'^DSC(?P\d+).JPG$') +REGEXPR_IPHONE = re.compile(r'frame_(?P\d+).jpg$') + +DEBUG_VIZ = None # 'iou' +if DEBUG_VIZ is not None: + import matplotlib.pyplot as plt # noqa + + +OPENGL_TO_OPENCV = np.float32([[1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1]]) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--scannetpp_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/scannetpp_processed') + parser.add_argument('--target_resolution', default=920, type=int, help="images resolution") + parser.add_argument('--pyopengl-platform', type=str, default='', help='PyOpenGL env variable') + return parser + + +def pose_from_qwxyz_txyz(elems): + qw, qx, qy, qz, tx, ty, tz = map(float, elems) + pose = np.eye(4) + pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix() + pose[:3, 3] = (tx, ty, tz) + return np.linalg.inv(pose) # returns cam2world + + +def get_frame_number(name, cam_type='dslr'): + if cam_type == 'dslr': + regex_expr = REGEXPR_DSLR + elif cam_type == 'iphone': + regex_expr = REGEXPR_IPHONE + else: + raise NotImplementedError(f'wrong {cam_type=} for get_frame_number') + matches = re.match(regex_expr, name) + return matches['frameid'] + + +def load_sfm(sfm_dir, cam_type='dslr'): + # load cameras + with open(osp.join(sfm_dir, 'cameras.txt'), 'r') as f: + raw = f.read().splitlines()[3:] # skip header + + intrinsics = {} + for camera in tqdm(raw, position=1, leave=False): + camera = camera.split(' ') + 
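+        # COLMAP cameras.txt line: CAMERA_ID MODEL WIDTH HEIGHT PARAMS[...]; keep the model name plus the
+        # numeric parameters, which are unpacked later in undistort_images()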
intrinsics[int(camera[0])] = [camera[1]] + [float(cam) for cam in camera[2:]] + + # load images + with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + img_idx = {} + img_infos = {} + for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2, position=1, leave=False): + image = image.split(' ') + points = points.split(' ') + + idx = image[0] + img_name = image[-1] + assert img_name not in img_idx, 'duplicate db image: ' + img_name + img_idx[img_name] = idx # register image name + + current_points2D = {int(i): (float(x), float(y)) + for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'} + img_infos[idx] = dict(intrinsics=intrinsics[int(image[-2])], + path=img_name, + frame_id=get_frame_number(img_name, cam_type), + cam_to_world=pose_from_qwxyz_txyz(image[1: -2]), + sparse_pts2d=current_points2D) + + # load 3D points + with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + points3D = {} + observations = {idx: [] for idx in img_infos.keys()} + for point in tqdm(raw, position=1, leave=False): + point = point.split() + point_3d_idx = int(point[0]) + points3D[point_3d_idx] = tuple(map(float, point[1:4])) + if len(point) > 8: + for idx, point_2d_idx in zip(point[8::2], point[9::2]): + observations[idx].append((point_3d_idx, int(point_2d_idx))) + + return img_idx, img_infos, points3D, observations + + +def subsample_img_infos(img_infos, num_images, allowed_name_subset=None): + img_infos_val = [(idx, val) for idx, val in img_infos.items()] + if allowed_name_subset is not None: + img_infos_val = [(idx, val) for idx, val in img_infos_val if val['path'] in allowed_name_subset] + + if len(img_infos_val) > num_images: + img_infos_val = sorted(img_infos_val, key=lambda x: x[1]['frame_id']) + kept_idx = np.round(np.linspace(0, len(img_infos_val) - 1, num_images)).astype(int).tolist() + img_infos_val = [img_infos_val[idx] for idx in kept_idx] + return {idx: val for idx, val in img_infos_val} + + +def undistort_images(intrinsics, rgb, mask): + camera_type = intrinsics[0] + + width = int(intrinsics[1]) + height = int(intrinsics[2]) + fx = intrinsics[3] + fy = intrinsics[4] + cx = intrinsics[5] + cy = intrinsics[6] + distortion = np.array(intrinsics[7:]) + + K = np.zeros([3, 3]) + K[0, 0] = fx + K[0, 2] = cx + K[1, 1] = fy + K[1, 2] = cy + K[2, 2] = 1 + + K = geometry.colmap_to_opencv_intrinsics(K) + if camera_type == "OPENCV_FISHEYE": + assert len(distortion) == 4 + + new_K = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify( + K, + distortion, + (width, height), + np.eye(3), + balance=0.0, + ) + # Make the cx and cy to be the center of the image + new_K[0, 2] = width / 2.0 + new_K[1, 2] = height / 2.0 + + map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1) + else: + new_K, _ = cv2.getOptimalNewCameraMatrix(K, distortion, (width, height), 1, (width, height), True) + map1, map2 = cv2.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1) + + undistorted_image = cv2.remap(rgb, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) + undistorted_mask = cv2.remap(mask, map1, map2, interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, borderValue=255) + new_K = geometry.opencv_to_colmap_intrinsics(new_K) + return width, height, 
new_K, undistorted_image, undistorted_mask + + +def process_scenes(root, pairsdir, output_dir, target_resolution): + os.makedirs(output_dir, exist_ok=True) + + # default values from + # https://github.com/scannetpp/scannetpp/blob/main/common/configs/render.yml + znear = 0.05 + zfar = 20.0 + + listfile = osp.join(pairsdir, 'scene_list.json') + with open(listfile, 'r') as f: + scenes = json.load(f) + + # for each of these, we will select some dslr images and some iphone images + # we will undistort them and render their depth + renderer = pyrender.OffscreenRenderer(0, 0) + for scene in tqdm(scenes, position=0, leave=True): + data_dir = os.path.join(root, 'data', scene) + dir_dslr = os.path.join(data_dir, 'dslr') + dir_iphone = os.path.join(data_dir, 'iphone') + dir_scans = os.path.join(data_dir, 'scans') + + assert os.path.isdir(data_dir) and os.path.isdir(dir_dslr) \ + and os.path.isdir(dir_iphone) and os.path.isdir(dir_scans) + + output_dir_scene = os.path.join(output_dir, scene) + scene_metadata_path = osp.join(output_dir_scene, 'scene_metadata.npz') + if osp.isfile(scene_metadata_path): + continue + + pairs_dir_scene = os.path.join(pairsdir, scene) + pairs_dir_scene_selected_pairs = os.path.join(pairs_dir_scene, 'selected_pairs.npz') + assert osp.isfile(pairs_dir_scene_selected_pairs) + selected_npz = np.load(pairs_dir_scene_selected_pairs) + selection, pairs = selected_npz['selection'], selected_npz['pairs'] + + # set up the output paths + output_dir_scene_rgb = os.path.join(output_dir_scene, 'images') + output_dir_scene_depth = os.path.join(output_dir_scene, 'depth') + os.makedirs(output_dir_scene_rgb, exist_ok=True) + os.makedirs(output_dir_scene_depth, exist_ok=True) + + ply_path = os.path.join(dir_scans, 'mesh_aligned_0.05.ply') + + sfm_dir_dslr = os.path.join(dir_dslr, 'colmap') + rgb_dir_dslr = os.path.join(dir_dslr, 'resized_images') + mask_dir_dslr = os.path.join(dir_dslr, 'resized_anon_masks') + + sfm_dir_iphone = os.path.join(dir_iphone, 'colmap') + rgb_dir_iphone = os.path.join(dir_iphone, 'rgb') + mask_dir_iphone = os.path.join(dir_iphone, 'rgb_masks') + + # load the mesh + with open(ply_path, 'rb') as f: + mesh_kwargs = trimesh.exchange.ply.load_ply(f) + mesh_scene = trimesh.Trimesh(**mesh_kwargs) + + # read colmap reconstruction, we will only use the intrinsics and pose here + img_idx_dslr, img_infos_dslr, points3D_dslr, observations_dslr = load_sfm(sfm_dir_dslr, cam_type='dslr') + dslr_paths = { + "in_colmap": sfm_dir_dslr, + "in_rgb": rgb_dir_dslr, + "in_mask": mask_dir_dslr, + } + + img_idx_iphone, img_infos_iphone, points3D_iphone, observations_iphone = load_sfm( + sfm_dir_iphone, cam_type='iphone') + iphone_paths = { + "in_colmap": sfm_dir_iphone, + "in_rgb": rgb_dir_iphone, + "in_mask": mask_dir_iphone, + } + + mesh = pyrender.Mesh.from_trimesh(mesh_scene, smooth=False) + pyrender_scene = pyrender.Scene() + pyrender_scene.add(mesh) + + selection_dslr = [imgname + '.JPG' for imgname in selection if imgname.startswith('DSC')] + selection_iphone = [imgname + '.jpg' for imgname in selection if imgname.startswith('frame_')] + + # resize the image to a more manageable size and render depth + for selection_cam, img_idx, img_infos, paths_data in [(selection_dslr, img_idx_dslr, img_infos_dslr, dslr_paths), + (selection_iphone, img_idx_iphone, img_infos_iphone, iphone_paths)]: + rgb_dir = paths_data['in_rgb'] + mask_dir = paths_data['in_mask'] + for imgname in tqdm(selection_cam, position=1, leave=False): + imgidx = img_idx[imgname] + img_infos_idx = img_infos[imgidx] + rgb = 
np.array(Image.open(os.path.join(rgb_dir, img_infos_idx['path']))) + mask = np.array(Image.open(os.path.join(mask_dir, img_infos_idx['path'][:-3] + 'png'))) + + _, _, K, rgb, mask = undistort_images(img_infos_idx['intrinsics'], rgb, mask) + + # rescale_image_depthmap assumes opencv intrinsics + intrinsics = geometry.colmap_to_opencv_intrinsics(K) + image, mask, intrinsics = rescale_image_depthmap( + rgb, mask, intrinsics, (target_resolution, target_resolution * 3.0 / 4)) + + W, H = image.size + intrinsics = geometry.opencv_to_colmap_intrinsics(intrinsics) + + # update inpace img_infos_idx + img_infos_idx['intrinsics'] = intrinsics + rgb_outpath = os.path.join(output_dir_scene_rgb, img_infos_idx['path'][:-3] + 'jpg') + image.save(rgb_outpath) + + depth_outpath = os.path.join(output_dir_scene_depth, img_infos_idx['path'][:-3] + 'png') + # render depth image + renderer.viewport_width, renderer.viewport_height = W, H + fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 1], intrinsics[0, 2], intrinsics[1, 2] + camera = pyrender.camera.IntrinsicsCamera(fx, fy, cx, cy, znear=znear, zfar=zfar) + camera_node = pyrender_scene.add(camera, pose=img_infos_idx['cam_to_world'] @ OPENGL_TO_OPENCV) + + depth = renderer.render(pyrender_scene, flags=pyrender.RenderFlags.DEPTH_ONLY) + pyrender_scene.remove_node(camera_node) # dont forget to remove camera + + depth = (depth * 1000).astype('uint16') + # invalidate depth from mask before saving + depth_mask = (mask < 255) + depth[depth_mask] = 0 + Image.fromarray(depth).save(depth_outpath) + + trajectories = [] + intrinsics = [] + for imgname in selection: + if imgname.startswith('DSC'): + imgidx = img_idx_dslr[imgname + '.JPG'] + img_infos_idx = img_infos_dslr[imgidx] + elif imgname.startswith('frame_'): + imgidx = img_idx_iphone[imgname + '.jpg'] + img_infos_idx = img_infos_iphone[imgidx] + else: + raise ValueError('invalid image name') + + intrinsics.append(img_infos_idx['intrinsics']) + trajectories.append(img_infos_idx['cam_to_world']) + + intrinsics = np.stack(intrinsics, axis=0) + trajectories = np.stack(trajectories, axis=0) + # save metadata for this scene + np.savez(scene_metadata_path, + trajectories=trajectories, + intrinsics=intrinsics, + images=selection, + pairs=pairs) + + del img_infos + del pyrender_scene + + # concat all scene_metadata.npz into a single file + scene_data = {} + for scene_subdir in scenes: + scene_metadata_path = osp.join(output_dir, scene_subdir, 'scene_metadata.npz') + with np.load(scene_metadata_path) as data: + trajectories = data['trajectories'] + intrinsics = data['intrinsics'] + images = data['images'] + pairs = data['pairs'] + scene_data[scene_subdir] = {'trajectories': trajectories, + 'intrinsics': intrinsics, + 'images': images, + 'pairs': pairs} + + offset = 0 + counts = [] + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + pairs = [] + for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()): + num_imgs = data['images'].shape[0] + img_pairs = data['pairs'] + + scenes.append(scene_subdir) + sceneids.extend([scene_idx] * num_imgs) + + images.append(data['images']) + + intrinsics.append(data['intrinsics']) + trajectories.append(data['trajectories']) + + # offset pairs + img_pairs[:, 0:2] += offset + pairs.append(img_pairs) + counts.append(offset) + + offset += num_imgs + + images = np.concatenate(images, axis=0) + intrinsics = np.concatenate(intrinsics, axis=0) + trajectories = np.concatenate(trajectories, axis=0) + pairs = np.concatenate(pairs, axis=0) + 
np.savez(osp.join(output_dir, 'all_metadata.npz'), + counts=counts, + scenes=scenes, + sceneids=sceneids, + images=images, + intrinsics=intrinsics, + trajectories=trajectories, + pairs=pairs) + print('all done') + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + if args.pyopengl_platform.strip(): + os.environ['PYOPENGL_PLATFORM'] = args.pyopengl_platform + process_scenes(args.scannetpp_dir, args.precomputed_pairs, args.output_dir, args.target_resolution) diff --git a/imcui/third_party/dust3r/datasets_preprocess/preprocess_staticthings3d.py b/imcui/third_party/dust3r/datasets_preprocess/preprocess_staticthings3d.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3eec16321c14b12291699f1fee492b5a7d8b1c --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/preprocess_staticthings3d.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Preprocessing code for the StaticThings3D dataset +# dataset at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d +# 1) Download StaticThings3D in /path/to/StaticThings3D/ +# with the script at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/scripts/download_staticthings3d.sh +# --> depths.tar.bz2 frames_finalpass.tar.bz2 poses.tar.bz2 frames_cleanpass.tar.bz2 intrinsics.tar.bz2 +# 2) unzip everything in the same /path/to/StaticThings3D/ directory +# 5) python datasets_preprocess/preprocess_staticthings3d.py --StaticThings3D_dir /path/to/tmp/StaticThings3D/ +# -------------------------------------------------------- +import os +import os.path as osp +import re +from tqdm import tqdm +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +import path_to_root # noqa +from dust3r.utils.parallel import parallel_threads +from dust3r.datasets.utils import cropping # noqa + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--StaticThings3D_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/staticthings3d_processed') + return parser + + +def main(db_root, pairs_path, output_dir): + all_scenes = _list_all_scenes(db_root) + + # crop images + args = [(db_root, osp.join(split, subsplit, seq), camera, f'{n:04d}', output_dir) + for split, subsplit, seq in all_scenes for camera in ['left', 'right'] for n in range(6, 16)] + parallel_threads(load_crop_and_save, args, star_args=True, front_num=1) + + # verify that all images are there + CAM = {b'l': 'left', b'r': 'right'} + pairs = np.load(pairs_path) + for scene, seq, cam1, im1, cam2, im2 in tqdm(pairs): + seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}') + for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]: + for ext in ['clean', 'final']: + impath = osp.join(output_dir, seq_path, cam, f"{idx:04n}_{ext}.jpg") + assert osp.isfile(impath), f'missing an image at {impath=}' + + print(f'>> Saved all data to {output_dir}!') + + +def load_crop_and_save(db_root, relpath_, camera, num, out_dir): + relpath = osp.join(relpath_, camera, num) + if osp.isfile(osp.join(out_dir, relpath + '.npz')): + return + os.makedirs(osp.join(out_dir, relpath_, camera), exist_ok=True) + + # load everything + intrinsics_in = readFloat(osp.join(db_root, 'intrinsics', relpath_, num 
+ '.float3')) + cam2world = np.linalg.inv(readFloat(osp.join(db_root, 'poses', relpath + '.float3'))) + depthmap_in = readFloat(osp.join(db_root, 'depths', relpath + '.float3')) + img_clean = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_cleanpass', + relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + img_final = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_finalpass', + relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + + # do the crop + assert img_clean.shape[:2] == (540, 960) + assert img_final.shape[:2] == (540, 960) + (clean_out, final_out), depthmap, intrinsics_out, R_in2out = _crop_image( + intrinsics_in, (img_clean, img_final), depthmap_in, (512, 384)) + + # write everything + clean_out.save(osp.join(out_dir, relpath + '_clean.jpg'), quality=80) + final_out.save(osp.join(out_dir, relpath + '_final.jpg'), quality=80) + cv2.imwrite(osp.join(out_dir, relpath + '.exr'), depthmap) + + # New camera parameters + cam2world[:3, :3] = cam2world[:3, :3] @ R_in2out.T + np.savez(osp.join(out_dir, relpath + '.npz'), intrinsics=intrinsics_out, cam2world=cam2world) + + +def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(512, 512)): + image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( + color_image_in, depthmap_in, intrinsics_in, resolution_out) + R_in2out = np.eye(3) + return image, depthmap, intrinsics_out, R_in2out + + +def _list_all_scenes(path): + print('>> Listing all scenes') + + res = [] + for split in ['TRAIN']: + for subsplit in 'ABC': + for seq in os.listdir(osp.join(path, 'intrinsics', split, subsplit)): + res.append((split, subsplit, seq)) + print(f' (found ({len(res)}) scenes)') + assert res, f'Did not find anything at {path=}' + return res + + +def readFloat(name): + with open(name, 'rb') as f: + if (f.readline().decode("utf-8")) != 'float\n': + raise Exception('float file %s did not contain keyword' % name) + + dim = int(f.readline()) + + dims = [] + count = 1 + for i in range(0, dim): + d = int(f.readline()) + dims.append(d) + count *= d + + dims = list(reversed(dims)) + data = np.fromfile(f, np.float32, count).reshape(dims) + return data # Hxw or CxHxW NxCxHxW + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.StaticThings3D_dir, args.precomputed_pairs, args.output_dir) diff --git a/imcui/third_party/dust3r/datasets_preprocess/preprocess_waymo.py b/imcui/third_party/dust3r/datasets_preprocess/preprocess_waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..203f337330a7e06e61d2fb9dd99647063967922d --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/preprocess_waymo.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# Preprocessing code for the WayMo Open dataset +# dataset at https://github.com/waymo-research/waymo-open-dataset +# 1) Accept the license +# 2) download all training/*.tfrecord files from Perception Dataset, version 1.4.2 +# 3) put all .tfrecord files in '/path/to/waymo_dir' +# 4) install the waymo_open_dataset package with +# `python3 -m pip install gcsfs waymo-open-dataset-tf-2-12-0==1.6.4` +# 5) execute this script as `python preprocess_waymo.py --waymo_dir /path/to/waymo_dir` +# -------------------------------------------------------- +import sys +import os +import os.path as osp +import shutil +import json +from tqdm import tqdm +import PIL.Image +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +import tensorflow.compat.v1 as tf +tf.enable_eager_execution() + +import path_to_root # noqa +from dust3r.utils.geometry import geotrf, inv +from dust3r.utils.image import imread_cv2 +from dust3r.utils.parallel import parallel_processes as parallel_map +from dust3r.datasets.utils import cropping +from dust3r.viz import show_raw_pointcloud + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--waymo_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/waymo_processed') + parser.add_argument('--workers', type=int, default=1) + return parser + + +def main(waymo_root, pairs_path, output_dir, workers=1): + extract_frames(waymo_root, output_dir, workers=workers) + make_crops(output_dir, workers=workers) + + # make sure all pairs are there + with np.load(pairs_path) as data: + scenes = data['scenes'] + frames = data['frames'] + pairs = data['pairs'] # array of (scene_id, img1_id, img2_id) + + for scene_id, im1_id, im2_id in pairs: + for im_id in (im1_id, im2_id): + path = osp.join(output_dir, scenes[scene_id], frames[im_id] + '.jpg') + assert osp.isfile(path), f'Missing a file at {path=}\nDid you download all .tfrecord files?' + + shutil.rmtree(osp.join(output_dir, 'tmp')) + print('Done!
all data generated at', output_dir) + + +def _list_sequences(db_root): + print('>> Looking for sequences in', db_root) + res = sorted(f for f in os.listdir(db_root) if f.endswith('.tfrecord')) + print(f' found {len(res)} sequences') + return res + + +def extract_frames(db_root, output_dir, workers=8): + sequences = _list_sequences(db_root) + output_dir = osp.join(output_dir, 'tmp') + print('>> outputing result to', output_dir) + args = [(db_root, output_dir, seq) for seq in sequences] + parallel_map(process_one_seq, args, star_args=True, workers=workers) + + +def process_one_seq(db_root, output_dir, seq): + out_dir = osp.join(output_dir, seq) + os.makedirs(out_dir, exist_ok=True) + calib_path = osp.join(out_dir, 'calib.json') + if osp.isfile(calib_path): + return + + try: + with tf.device('/CPU:0'): + calib, frames = extract_frames_one_seq(osp.join(db_root, seq)) + except RuntimeError: + print(f'/!\\ Error with sequence {seq} /!\\', file=sys.stderr) + return # nothing is saved + + for f, (frame_name, views) in enumerate(tqdm(frames, leave=False)): + for cam_idx, view in views.items(): + img = PIL.Image.fromarray(view.pop('img')) + img.save(osp.join(out_dir, f'{f:05d}_{cam_idx}.jpg')) + np.savez(osp.join(out_dir, f'{f:05d}_{cam_idx}.npz'), **view) + + with open(calib_path, 'w') as f: + json.dump(calib, f) + + +def extract_frames_one_seq(filename): + from waymo_open_dataset import dataset_pb2 as open_dataset + from waymo_open_dataset.utils import frame_utils + + print('>> Opening', filename) + dataset = tf.data.TFRecordDataset(filename, compression_type='') + + calib = None + frames = [] + + for data in tqdm(dataset, leave=False): + frame = open_dataset.Frame() + frame.ParseFromString(bytearray(data.numpy())) + + content = frame_utils.parse_range_image_and_camera_projection(frame) + range_images, camera_projections, _, range_image_top_pose = content + + views = {} + frames.append((frame.context.name, views)) + + # once in a sequence, read camera calibration info + if calib is None: + calib = [] + for cam in frame.context.camera_calibrations: + calib.append((cam.name, + dict(width=cam.width, + height=cam.height, + intrinsics=list(cam.intrinsic), + extrinsics=list(cam.extrinsic.transform)))) + + # convert LIDAR to pointcloud + points, cp_points = frame_utils.convert_range_image_to_point_cloud( + frame, + range_images, + camera_projections, + range_image_top_pose) + + # 3d points in vehicle frame. + points_all = np.concatenate(points, axis=0) + cp_points_all = np.concatenate(cp_points, axis=0) + + # The distance between lidar points and vehicle frame origin. 
+ cp_points_all_tensor = tf.constant(cp_points_all, dtype=tf.int32) + + for i, image in enumerate(frame.images): + # select relevant 3D points for this view + mask = tf.equal(cp_points_all_tensor[..., 0], image.name) + cp_points_msk_tensor = tf.cast(tf.gather_nd(cp_points_all_tensor, tf.where(mask)), dtype=tf.float32) + + pose = np.asarray(image.pose.transform).reshape(4, 4) + timestamp = image.pose_timestamp + + rgb = tf.image.decode_jpeg(image.image).numpy() + + pix = cp_points_msk_tensor[..., 1:3].numpy().round().astype(np.int16) + pts3d = points_all[mask.numpy()] + + views[image.name] = dict(img=rgb, pose=pose, pixels=pix, pts3d=pts3d, timestamp=timestamp) + + if not 'show full point cloud': + show_raw_pointcloud([v['pts3d'] for v in views.values()], [v['img'] for v in views.values()]) + + return calib, frames + + +def make_crops(output_dir, workers=16, **kw): + tmp_dir = osp.join(output_dir, 'tmp') + sequences = _list_sequences(tmp_dir) + args = [(tmp_dir, output_dir, seq) for seq in sequences] + parallel_map(crop_one_seq, args, star_args=True, workers=workers, front_num=0) + + +def crop_one_seq(input_dir, output_dir, seq, resolution=512): + seq_dir = osp.join(input_dir, seq) + out_dir = osp.join(output_dir, seq) + if osp.isfile(osp.join(out_dir, '00100_1.jpg')): + return + os.makedirs(out_dir, exist_ok=True) + + # load calibration file + try: + with open(osp.join(seq_dir, 'calib.json')) as f: + calib = json.load(f) + except IOError: + print(f'/!\\ Error: Missing calib.json in sequence {seq} /!\\', file=sys.stderr) + return + + axes_transformation = np.array([ + [0, -1, 0, 0], + [0, 0, -1, 0], + [1, 0, 0, 0], + [0, 0, 0, 1]]) + + cam_K = {} + cam_distortion = {} + cam_res = {} + cam_to_car = {} + for cam_idx, cam_info in calib: + cam_idx = str(cam_idx) + cam_res[cam_idx] = (W, H) = (cam_info['width'], cam_info['height']) + f1, f2, cx, cy, k1, k2, p1, p2, k3 = cam_info['intrinsics'] + cam_K[cam_idx] = np.asarray([(f1, 0, cx), (0, f2, cy), (0, 0, 1)]) + cam_distortion[cam_idx] = np.asarray([k1, k2, p1, p2, k3]) + cam_to_car[cam_idx] = np.asarray(cam_info['extrinsics']).reshape(4, 4) # cam-to-vehicle + + frames = sorted(f[:-3] for f in os.listdir(seq_dir) if f.endswith('.jpg')) + + # from dust3r.viz import SceneViz + # viz = SceneViz() + + for frame in tqdm(frames, leave=False): + cam_idx = frame[-2] # cam index + assert cam_idx in '12345', f'bad {cam_idx=} in {frame=}' + data = np.load(osp.join(seq_dir, frame + 'npz')) + car_to_world = data['pose'] + W, H = cam_res[cam_idx] + + # load depthmap + pos2d = data['pixels'].round().astype(np.uint16) + x, y = pos2d.T + pts3d = data['pts3d'] # already in the car frame + pts3d = geotrf(axes_transformation @ inv(cam_to_car[cam_idx]), pts3d) + # X=LEFT_RIGHT y=ALTITUDE z=DEPTH + + # load image + image = imread_cv2(osp.join(seq_dir, frame + 'jpg')) + + # downscale image + output_resolution = (resolution, 1) if W > H else (1, resolution) + image, _, intrinsics2 = cropping.rescale_image_depthmap(image, None, cam_K[cam_idx], output_resolution) + image.save(osp.join(out_dir, frame + 'jpg'), quality=80) + + # save as an EXR file? 
yes it's smaller (and easier to load) + W, H = image.size + depthmap = np.zeros((H, W), dtype=np.float32) + pos2d = geotrf(intrinsics2 @ inv(cam_K[cam_idx]), pos2d).round().astype(np.int16) + x, y = pos2d.T + depthmap[y.clip(min=0, max=H - 1), x.clip(min=0, max=W - 1)] = pts3d[:, 2] + cv2.imwrite(osp.join(out_dir, frame + 'exr'), depthmap) + + # save camera parametes + cam2world = car_to_world @ cam_to_car[cam_idx] @ inv(axes_transformation) + np.savez(osp.join(out_dir, frame + 'npz'), intrinsics=intrinsics2, + cam2world=cam2world, distortion=cam_distortion[cam_idx]) + + # viz.add_rgbd(np.asarray(image), depthmap, intrinsics2, cam2world) + # viz.show() + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.waymo_dir, args.precomputed_pairs, args.output_dir, workers=args.workers) diff --git a/imcui/third_party/dust3r/datasets_preprocess/preprocess_wildrgbd.py b/imcui/third_party/dust3r/datasets_preprocess/preprocess_wildrgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..ff3f0f7abb7d9ef43bba6a7c6cd6f4e652a8f510 --- /dev/null +++ b/imcui/third_party/dust3r/datasets_preprocess/preprocess_wildrgbd.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to pre-process the WildRGB-D dataset. +# Usage: +# python3 datasets_preprocess/preprocess_wildrgbd.py --wildrgbd_dir /path/to/wildrgbd +# -------------------------------------------------------- + +import argparse +import random +import json +import os +import os.path as osp + +import PIL.Image +import numpy as np +import cv2 + +from tqdm.auto import tqdm +import matplotlib.pyplot as plt + +import path_to_root # noqa +import dust3r.datasets.utils.cropping as cropping # noqa +from dust3r.utils.image import imread_cv2 + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", type=str, default="data/wildrgbd_processed") + parser.add_argument("--wildrgbd_dir", type=str, required=True) + parser.add_argument("--train_num_sequences_per_object", type=int, default=50) + parser.add_argument("--test_num_sequences_per_object", type=int, default=10) + parser.add_argument("--num_frames", type=int, default=100) + parser.add_argument("--seed", type=int, default=42) + + parser.add_argument("--img_size", type=int, default=512, + help=("lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size")) + return parser + + +def get_set_list(category_dir, split): + listfiles = ["camera_eval_list.json", "nvs_list.json"] + + sequences_all = {s: {k: set() for k in listfiles} for s in ['train', 'val']} + for listfile in listfiles: + with open(osp.join(category_dir, listfile)) as f: + subset_lists_data = json.load(f) + for s in ['train', 'val']: + sequences_all[s][listfile].update(subset_lists_data[s]) + train_intersection = set.intersection(*list(sequences_all['train'].values())) + if split == "train": + return train_intersection + else: + all_seqs = set.union(*list(sequences_all['train'].values()), *list(sequences_all['val'].values())) + return all_seqs.difference(train_intersection) + + +def prepare_sequences(category, wildrgbd_dir, output_dir, img_size, split, max_num_sequences_per_object, + output_num_frames, seed): + random.seed(seed) + category_dir = osp.join(wildrgbd_dir, category) + category_output_dir = osp.join(output_dir, category) + sequences_all 
= get_set_list(category_dir, split) + sequences_all = sorted(sequences_all) + + sequences_all_tmp = [] + for seq_name in sequences_all: + scene_dir = osp.join(wildrgbd_dir, category_dir, seq_name) + if not os.path.isdir(scene_dir): + print(f'{scene_dir} does not exist, skipped') + continue + sequences_all_tmp.append(seq_name) + sequences_all = sequences_all_tmp + if len(sequences_all) <= max_num_sequences_per_object: + selected_sequences = sequences_all + else: + selected_sequences = random.sample(sequences_all, max_num_sequences_per_object) + + selected_sequences_numbers_dict = {} + for seq_name in tqdm(selected_sequences, leave=False): + scene_dir = osp.join(category_dir, seq_name) + scene_output_dir = osp.join(category_output_dir, seq_name) + with open(osp.join(scene_dir, 'metadata'), 'r') as f: + metadata = json.load(f) + + K = np.array(metadata["K"]).reshape(3, 3).T + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + w, h = metadata["w"], metadata["h"] + + camera_intrinsics = np.array( + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + ) + camera_to_world_path = os.path.join(scene_dir, 'cam_poses.txt') + camera_to_world_content = np.genfromtxt(camera_to_world_path) + camera_to_world = camera_to_world_content[:, 1:].reshape(-1, 4, 4) + + frame_idx = camera_to_world_content[:, 0] + num_frames = frame_idx.shape[0] + assert num_frames >= output_num_frames + assert np.all(frame_idx == np.arange(num_frames)) + + # selected_sequences_numbers_dict[seq_name] = num_frames + + selected_frames = np.round(np.linspace(0, num_frames - 1, output_num_frames)).astype(int).tolist() + selected_sequences_numbers_dict[seq_name] = selected_frames + + for frame_id in tqdm(selected_frames): + depth_path = os.path.join(scene_dir, 'depth', f'{frame_id:0>5d}.png') + masks_path = os.path.join(scene_dir, 'masks', f'{frame_id:0>5d}.png') + rgb_path = os.path.join(scene_dir, 'rgb', f'{frame_id:0>5d}.png') + + input_rgb_image = PIL.Image.open(rgb_path).convert('RGB') + input_mask = plt.imread(masks_path) + input_depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float64) + depth_mask = np.stack((input_depthmap, input_mask), axis=-1) + H, W = input_depthmap.shape + + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = int(cx - min_margin_x), int(cy - min_margin_y) + r, b = int(cx + min_margin_x), int(cy + min_margin_y) + crop_bbox = (l, t, r, b) + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap( + input_rgb_image, depth_mask, camera_intrinsics, crop_bbox) + + # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384 + scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + if max(output_resolution) < img_size: + # let's put the max dimension to img_size + scale_final = (img_size / max(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap( + input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution) + input_depthmap = depth_mask[:, :, 0] + input_mask = depth_mask[:, :, 1] + + camera_pose = camera_to_world[frame_id] + + # save crop images and depth, metadata + save_img_path = os.path.join(scene_output_dir, 'rgb', f'{frame_id:0>5d}.jpg') + save_depth_path = os.path.join(scene_output_dir, 'depth', f'{frame_id:0>5d}.png') + 
save_mask_path = os.path.join(scene_output_dir, 'masks', f'{frame_id:0>5d}.png') + os.makedirs(os.path.split(save_img_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True) + + input_rgb_image.save(save_img_path) + cv2.imwrite(save_depth_path, input_depthmap.astype(np.uint16)) + cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8)) + + save_meta_path = os.path.join(scene_output_dir, 'metadata', f'{frame_id:0>5d}.npz') + os.makedirs(os.path.split(save_meta_path)[0], exist_ok=True) + np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics, + camera_pose=camera_pose) + + return selected_sequences_numbers_dict + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + assert args.wildrgbd_dir != args.output_dir + + categories = sorted([ + dirname for dirname in os.listdir(args.wildrgbd_dir) + if os.path.isdir(os.path.join(args.wildrgbd_dir, dirname, 'scenes')) + ]) + + os.makedirs(args.output_dir, exist_ok=True) + + splits_num_sequences_per_object = [args.train_num_sequences_per_object, args.test_num_sequences_per_object] + for split, num_sequences_per_object in zip(['train', 'test'], splits_num_sequences_per_object): + selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(selected_sequences_path): + continue + all_selected_sequences = {} + for category in categories: + category_output_dir = osp.join(args.output_dir, category) + os.makedirs(category_output_dir, exist_ok=True) + category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(category_selected_sequences_path): + with open(category_selected_sequences_path, 'r') as fid: + category_selected_sequences = json.load(fid) + else: + print(f"Processing {split} - category = {category}") + category_selected_sequences = prepare_sequences( + category=category, + wildrgbd_dir=args.wildrgbd_dir, + output_dir=args.output_dir, + img_size=args.img_size, + split=split, + max_num_sequences_per_object=num_sequences_per_object, + output_num_frames=args.num_frames, + seed=args.seed + int("category".encode('ascii').hex(), 16), + ) + with open(category_selected_sequences_path, 'w') as file: + json.dump(category_selected_sequences, file) + + all_selected_sequences[category] = category_selected_sequences + with open(selected_sequences_path, 'w') as file: + json.dump(all_selected_sequences, file) diff --git a/imcui/third_party/dust3r/demo.py b/imcui/third_party/dust3r/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..326c6e5a49d5d352b4afb5445cee5d22571c3bdd --- /dev/null +++ b/imcui/third_party/dust3r/demo.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# dust3r gradio demo executable +# -------------------------------------------------------- +import os +import torch +import tempfile + +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.demo import get_args_parser, main_demo, set_print_with_timestamp + +import matplotlib.pyplot as pl +pl.ion() + +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + set_print_with_timestamp() + + if args.tmp_dir is not None: + tmp_path = args.tmp_dir + os.makedirs(tmp_path, exist_ok=True) + tempfile.tempdir = tmp_path + + if args.server_name is not None: + server_name = args.server_name + else: + server_name = '0.0.0.0' if args.local_network else '127.0.0.1' + + if args.weights is not None: + weights_path = args.weights + else: + weights_path = "naver/" + args.model_name + model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) + + # dust3r will write the 3D model inside tmpdirname + with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname: + if not args.silent: + print('Outputing stuff in', tmpdirname) + main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent) diff --git a/imcui/third_party/dust3r/docker/docker-compose-cpu.yml b/imcui/third_party/dust3r/docker/docker-compose-cpu.yml new file mode 100644 index 0000000000000000000000000000000000000000..2015fd771e8b6246d288c03a38f6fbb3f17dff20 --- /dev/null +++ b/imcui/third_party/dust3r/docker/docker-compose-cpu.yml @@ -0,0 +1,16 @@ +version: '3.8' +services: + dust3r-demo: + build: + context: ./files + dockerfile: cpu.Dockerfile + ports: + - "7860:7860" + volumes: + - ./files/checkpoints:/dust3r/checkpoints + environment: + - DEVICE=cpu + - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth} + cap_add: + - IPC_LOCK + - SYS_RESOURCE diff --git a/imcui/third_party/dust3r/docker/docker-compose-cuda.yml b/imcui/third_party/dust3r/docker/docker-compose-cuda.yml new file mode 100644 index 0000000000000000000000000000000000000000..85710af953d669fe618273de6ce3a062a7a84cca --- /dev/null +++ b/imcui/third_party/dust3r/docker/docker-compose-cuda.yml @@ -0,0 +1,23 @@ +version: '3.8' +services: + dust3r-demo: + build: + context: ./files + dockerfile: cuda.Dockerfile + ports: + - "7860:7860" + environment: + - DEVICE=cuda + - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth} + volumes: + - ./files/checkpoints:/dust3r/checkpoints + cap_add: + - IPC_LOCK + - SYS_RESOURCE + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] diff --git a/imcui/third_party/dust3r/dust3r/__init__.py b/imcui/third_party/dust3r/dust3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/dust3r/dust3r/cloud_opt/__init__.py b/imcui/third_party/dust3r/dust3r/cloud_opt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..faf5cd279a317c1efb9ba947682992c0949c1bdc --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/cloud_opt/__init__.py @@ -0,0 +1,33 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. 
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# global alignment optimization wrapper function +# -------------------------------------------------------- +from enum import Enum + +from .optimizer import PointCloudOptimizer +from .modular_optimizer import ModularPointCloudOptimizer +from .pair_viewer import PairViewer + + +class GlobalAlignerMode(Enum): + PointCloudOptimizer = "PointCloudOptimizer" + ModularPointCloudOptimizer = "ModularPointCloudOptimizer" + PairViewer = "PairViewer" + + +def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw): + # extract all inputs + view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()] + # build the optimizer + if mode == GlobalAlignerMode.PointCloudOptimizer: + net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) + elif mode == GlobalAlignerMode.ModularPointCloudOptimizer: + net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) + elif mode == GlobalAlignerMode.PairViewer: + net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device) + else: + raise NotImplementedError(f'Unknown mode {mode}') + + return net diff --git a/imcui/third_party/dust3r/dust3r/cloud_opt/base_opt.py b/imcui/third_party/dust3r/dust3r/cloud_opt/base_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..4d36e05bfca80509bced20add7c067987d538951 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/cloud_opt/base_opt.py @@ -0,0 +1,405 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Base class for the global alignement procedure +# -------------------------------------------------------- +from copy import deepcopy + +import numpy as np +import torch +import torch.nn as nn +import roma +from copy import deepcopy +import tqdm + +from dust3r.utils.geometry import inv, geotrf +from dust3r.utils.device import to_numpy +from dust3r.utils.image import rgb +from dust3r.viz import SceneViz, segment_sky, auto_cam_size +from dust3r.optim_factory import adjust_learning_rate_by_lr + +from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p, + cosine_schedule, linear_schedule, get_conf_trf) +import dust3r.cloud_opt.init_im_poses as init_fun + + +class BasePCOptimizer (nn.Module): + """ Optimize a global scene, given a list of pairwise observations. 
+ Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, **kwargs): + if len(args) == 1 and len(kwargs) == 0: + other = deepcopy(args[0]) + attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes + min_conf_thr conf_thr conf_i conf_j im_conf + base_scale norm_pw_scale POSE_DIM pw_poses + pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose'''.split() + self.__dict__.update({k: other[k] for k in attrs}) + else: + self._init_from_views(*args, **kwargs) + + def _init_from_views(self, view1, view2, pred1, pred2, + dist='l1', + conf='log', + min_conf_thr=3, + base_scale=0.5, + allow_pw_adaptors=False, + pw_break=20, + rand_pose=torch.randn, + iterationsCount=None, + verbose=True): + super().__init__() + if not isinstance(view1['idx'], list): + view1['idx'] = view1['idx'].tolist() + if not isinstance(view2['idx'], list): + view2['idx'] = view2['idx'].tolist() + self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])] + self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges} + self.dist = ALL_DISTS[dist] + self.verbose = verbose + + self.n_imgs = self._check_edges() + + # input data + pred1_pts = pred1['pts3d'] + pred2_pts = pred2['pts3d_in_other_view'] + self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)}) + self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)}) + self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts) + + # work in log-scale with conf + pred1_conf = pred1['conf'] + pred2_conf = pred2['conf'] + self.min_conf_thr = min_conf_thr + self.conf_trf = get_conf_trf(conf) + + self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)}) + self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)}) + self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf) + for i in range(len(self.im_conf)): + self.im_conf[i].requires_grad = False + + # pairwise pose parameters + self.base_scale = base_scale + self.norm_pw_scale = True + self.pw_break = pw_break + self.POSE_DIM = 7 + self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM))) # pairwise poses + self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2))) # slight xy/z adaptation + self.pw_adaptors.requires_grad_(allow_pw_adaptors) + self.has_im_poses = False + self.rand_pose = rand_pose + + # possibly store images for show_pointcloud + self.imgs = None + if 'img' in view1 and 'img' in view2: + imgs = [torch.zeros((3,)+hw) for hw in self.imshapes] + for v in range(len(self.edges)): + idx = view1['idx'][v] + imgs[idx] = view1['img'][v] + idx = view2['idx'][v] + imgs[idx] = view2['img'][v] + self.imgs = rgb(imgs) + + @property + def n_edges(self): + return len(self.edges) + + @property + def str_edges(self): + return [edge_str(i, j) for i, j in self.edges] + + @property + def imsizes(self): + return [(w, h) for h, w in self.imshapes] + + @property + def device(self): + return next(iter(self.parameters())).device + + def state_dict(self, trainable=True): + all_params = super().state_dict() + return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable} + + def load_state_dict(self, data): + return super().load_state_dict(self.state_dict(trainable=False) | data) + + def _check_edges(self): + indices = sorted({i for edge in self.edges for i in edge}) + assert indices == list(range(len(indices))), 'bad pair indices: missing values ' 
+ return len(indices) + + @torch.no_grad() + def _compute_img_conf(self, pred1_conf, pred2_conf): + im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes]) + for e, (i, j) in enumerate(self.edges): + im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e]) + im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e]) + return im_conf + + def get_adaptors(self): + adapt = self.pw_adaptors + adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1) # (scale_xy, scale_xy, scale_z) + if self.norm_pw_scale: # normalize so that the product == 1 + adapt = adapt - adapt.mean(dim=1, keepdim=True) + return (adapt / self.pw_break).exp() + + def _get_poses(self, poses): + # normalize rotation + Q = poses[:, :4] + T = signed_expm1(poses[:, 4:7]) + RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous() + return RT + + def _set_pose(self, poses, idx, R, T=None, scale=None, force=False): + # all poses == cam-to-world + pose = poses[idx] + if not (pose.requires_grad or force): + return pose + + if R.shape == (4, 4): + assert T is None + T = R[:3, 3] + R = R[:3, :3] + + if R is not None: + pose.data[0:4] = roma.rotmat_to_unitquat(R) + if T is not None: + pose.data[4:7] = signed_log1p(T / (scale or 1)) # translation is function of scale + + if scale is not None: + assert poses.shape[-1] in (8, 13) + pose.data[-1] = np.log(float(scale)) + return pose + + def get_pw_norm_scale_factor(self): + if self.norm_pw_scale: + # normalize scales so that things cannot go south + # we want that exp(scale) ~= self.base_scale + return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp() + else: + return 1 # don't norm scale for known poses + + def get_pw_scale(self): + scale = self.pw_poses[:, -1].exp() # (n_edges,) + scale = scale * self.get_pw_norm_scale_factor() + return scale + + def get_pw_poses(self): # cam to world + RT = self._get_poses(self.pw_poses) + scaled_RT = RT.clone() + scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1) # scale the rotation AND translation + return scaled_RT + + def get_masks(self): + return [(conf > self.min_conf_thr) for conf in self.im_conf] + + def depth_to_pts3d(self): + raise NotImplementedError() + + def get_pts3d(self, raw=False): + res = self.depth_to_pts3d() + if not raw: + res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def _set_focal(self, idx, focal, force=False): + raise NotImplementedError() + + def get_focals(self): + raise NotImplementedError() + + def get_known_focal_mask(self): + raise NotImplementedError() + + def get_principal_points(self): + raise NotImplementedError() + + def get_conf(self, mode=None): + trf = self.conf_trf if mode is None else get_conf_trf(mode) + return [trf(c) for c in self.im_conf] + + def get_im_poses(self): + raise NotImplementedError() + + def _set_depthmap(self, idx, depth, force=False): + raise NotImplementedError() + + def get_depthmaps(self, raw=False): + raise NotImplementedError() + + def clean_pointcloud(self, **kw): + cams = inv(self.get_im_poses()) + K = self.get_intrinsics() + depthmaps = self.get_depthmaps() + all_pts3d = self.get_pts3d() + + new_im_confs = clean_pointcloud(self.im_conf, K, cams, depthmaps, all_pts3d, **kw) + + for i, new_conf in enumerate(new_im_confs): + self.im_conf[i].data[:] = new_conf + return self + + def forward(self, ret_details=False): + pw_poses = self.get_pw_poses() # cam-to-world + pw_adapt = self.get_adaptors() + proj_pts3d = self.get_pts3d() + # pre-compute pixel weights + weight_i = {i_j: self.conf_trf(c) for i_j, c in 
self.conf_i.items()} + weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()} + + loss = 0 + if ret_details: + details = -torch.ones((self.n_imgs, self.n_imgs)) + + for e, (i, j) in enumerate(self.edges): + i_j = edge_str(i, j) + # distance in image i and j + aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j]) + aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j]) + li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean() + lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean() + loss = loss + li + lj + + if ret_details: + details[i, j] = li + lj + loss /= self.n_edges # average over all pairs + + if ret_details: + return loss, details + return loss + + @torch.cuda.amp.autocast(enabled=False) + def compute_global_alignment(self, init=None, niter_PnP=10, **kw): + if init is None: + pass + elif init == 'msp' or init == 'mst': + init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP) + elif init == 'known_poses': + init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr, + niter_PnP=niter_PnP) + else: + raise ValueError(f'bad value for {init=}') + + return global_alignment_loop(self, **kw) + + @torch.no_grad() + def mask_sky(self): + res = deepcopy(self) + for i in range(self.n_imgs): + sky = segment_sky(self.imgs[i]) + res.im_conf[i][sky] = 0 + return res + + def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw): + viz = SceneViz() + if self.imgs is None: + colors = np.random.randint(0, 256, size=(self.n_imgs, 3)) + colors = list(map(tuple, colors.tolist())) + for n in range(self.n_imgs): + viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n]) + else: + viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks()) + colors = np.random.randint(256, size=(self.n_imgs, 3)) + + # camera poses + im_poses = to_numpy(self.get_im_poses()) + if cam_size is None: + cam_size = auto_cam_size(im_poses) + viz.add_cameras(im_poses, self.get_focals(), colors=colors, + images=self.imgs, imsizes=self.imsizes, cam_size=cam_size) + if show_pw_cams: + pw_poses = self.get_pw_poses() + viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size) + + if show_pw_pts3d: + pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)] + viz.add_pointcloud(pts, (128, 0, 128)) + + viz.show(**kw) + return viz + + +def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6): + params = [p for p in net.parameters() if p.requires_grad] + if not params: + return net + + verbose = net.verbose + if verbose: + print('Global alignement - optimizing for:') + print([name for name, value in net.named_parameters() if value.requires_grad]) + + lr_base = lr + optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9)) + + loss = float('inf') + if verbose: + with tqdm.tqdm(total=niter) as bar: + while bar.n < bar.total: + loss, lr = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule) + bar.set_postfix_str(f'{lr=:g} loss={loss:g}') + bar.update() + else: + for n in range(niter): + loss, _ = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule) + return loss + + +def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule): + t = cur_iter / niter + if schedule == 'cosine': + lr = cosine_schedule(t, lr_base, lr_min) + elif schedule == 'linear': + lr = linear_schedule(t, lr_base, lr_min) + else: + raise ValueError(f'bad lr {schedule=}') + 
adjust_learning_rate_by_lr(optimizer, lr) + optimizer.zero_grad() + loss = net() + loss.backward() + optimizer.step() + + return float(loss), lr + + +@torch.no_grad() +def clean_pointcloud( im_confs, K, cams, depthmaps, all_pts3d, + tol=0.001, bad_conf=0, dbg=()): + """ Method: + 1) express all 3d points in each camera coordinate frame + 2) if they're in front of a depthmap --> then lower their confidence + """ + assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d) + assert 0 <= tol < 1 + res = [c.clone() for c in im_confs] + + # reshape appropriately + all_pts3d = [p.view(*c.shape,3) for p,c in zip(all_pts3d, im_confs)] + depthmaps = [d.view(*c.shape) for d,c in zip(depthmaps, im_confs)] + + for i, pts3d in enumerate(all_pts3d): + for j in range(len(all_pts3d)): + if i == j: continue + + # project 3dpts in other view + proj = geotrf(cams[j], pts3d) + proj_depth = proj[:,:,2] + u,v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1) + + # check which points are actually in the visible cone + H, W = im_confs[j].shape + msk_i = (proj_depth > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H) + msk_j = v[msk_i], u[msk_i] + + # find bad points = those in front but less confident + bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]) & (res[i][msk_i] < res[j][msk_j]) + + bad_msk_i = msk_i.clone() + bad_msk_i[msk_i] = bad_points + res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf) + + return res diff --git a/imcui/third_party/dust3r/dust3r/cloud_opt/commons.py b/imcui/third_party/dust3r/dust3r/cloud_opt/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..3be9f855a69ea18c82dcc8e5769e0149a59649bd --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/cloud_opt/commons.py @@ -0,0 +1,90 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# utility functions for global alignment +# -------------------------------------------------------- +import torch +import torch.nn as nn +import numpy as np + + +def edge_str(i, j): + return f'{i}_{j}' + + +def i_j_ij(ij): + return edge_str(*ij), ij + + +def edge_conf(conf_i, conf_j, edge): + return float(conf_i[edge].mean() * conf_j[edge].mean()) + + +def compute_edge_scores(edges, conf_i, conf_j): + return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges} + + +def NoGradParamDict(x): + assert isinstance(x, dict) + return nn.ParameterDict(x).requires_grad_(False) + + +def get_imshapes(edges, pred_i, pred_j): + n_imgs = max(max(e) for e in edges) + 1 + imshapes = [None] * n_imgs + for e, (i, j) in enumerate(edges): + shape_i = tuple(pred_i[e].shape[0:2]) + shape_j = tuple(pred_j[e].shape[0:2]) + if imshapes[i]: + assert imshapes[i] == shape_i, f'incorrect shape for image {i}' + if imshapes[j]: + assert imshapes[j] == shape_j, f'incorrect shape for image {j}' + imshapes[i] = shape_i + imshapes[j] = shape_j + return imshapes + + +def get_conf_trf(mode): + if mode == 'log': + def conf_trf(x): return x.log() + elif mode == 'sqrt': + def conf_trf(x): return x.sqrt() + elif mode == 'm1': + def conf_trf(x): return x-1 + elif mode in ('id', 'none'): + def conf_trf(x): return x + else: + raise ValueError(f'bad mode for {mode=}') + return conf_trf + + +def l2_dist(a, b, weight): + return ((a - b).square().sum(dim=-1) * weight) + + +def l1_dist(a, b, weight): + return ((a - b).norm(dim=-1) * weight) + + +ALL_DISTS = dict(l1=l1_dist, l2=l2_dist) + + +def signed_log1p(x): + sign = torch.sign(x) + return sign * torch.log1p(torch.abs(x)) + + +def signed_expm1(x): + sign = torch.sign(x) + return sign * torch.expm1(torch.abs(x)) + + +def cosine_schedule(t, lr_start, lr_end): + assert 0 <= t <= 1 + return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2 + + +def linear_schedule(t, lr_start, lr_end): + assert 0 <= t <= 1 + return lr_start + (lr_end - lr_start) * t diff --git a/imcui/third_party/dust3r/dust3r/cloud_opt/init_im_poses.py b/imcui/third_party/dust3r/dust3r/cloud_opt/init_im_poses.py new file mode 100644 index 0000000000000000000000000000000000000000..7887c5cde27115273601e704b81ca0b0301f3715 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/cloud_opt/init_im_poses.py @@ -0,0 +1,316 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# Initialization functions for global alignment +# -------------------------------------------------------- +from functools import cache + +import numpy as np +import scipy.sparse as sp +import torch +import cv2 +import roma +from tqdm import tqdm + +from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses +from dust3r.post_process import estimate_focal_knowing_depth +from dust3r.viz import to_numpy + +from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores + + +@torch.no_grad() +def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3): + device = self.device + + # indices of known poses + nkp, known_poses_msk, known_poses = get_known_poses(self) + assert nkp == self.n_imgs, 'not all poses are known' + + # get all focals + nkf, _, im_focals = get_known_focals(self) + assert nkf == self.n_imgs + im_pp = self.get_principal_points() + + best_depthmaps = {} + # init all pairwise poses + for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)): + i_j = edge_str(i, j) + + # find relative pose for this pair + P1 = torch.eye(4, device=device) + msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1) + _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()), + pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP) + + # align the two predicted camera with the two gt cameras + s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]]) + # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1 + # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3]) + self._set_pose(self.pw_poses, e, R, T, scale=s) + + # remember if this is a good depthmap + score = float(self.conf_i[i_j].mean()) + if score > best_depthmaps.get(i, (0,))[0]: + best_depthmaps[i] = score, i_j, s + + # init all image poses + for n in range(self.n_imgs): + assert known_poses_msk[n] + _, i_j, scale = best_depthmaps[n] + depth = self.pred_i[i_j][:, :, 2] + self._set_depthmap(n, depth * scale) + + +@torch.no_grad() +def init_minimum_spanning_tree(self, **kw): + """ Init all camera poses (image-wise and pairwise poses) given + an initial set of pairwise estimations. 
+ """ + device = self.device + pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges, + self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr, + device, has_im_poses=self.has_im_poses, verbose=self.verbose, + **kw) + + return init_from_pts3d(self, pts3d, im_focals, im_poses) + + +def init_from_pts3d(self, pts3d, im_focals, im_poses): + # init poses + nkp, known_poses_msk, known_poses = get_known_poses(self) + if nkp == 1: + raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose") + elif nkp > 1: + # global rigid SE3 alignment + s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk]) + trf = sRT_to_4x4(s, R, T, device=known_poses.device) + + # rotate everything + im_poses = trf @ im_poses + im_poses[:, :3, :3] /= s # undo scaling on the rotation part + for img_pts3d in pts3d: + img_pts3d[:] = geotrf(trf, img_pts3d) + + # set all pairwise poses + for e, (i, j) in enumerate(self.edges): + i_j = edge_str(i, j) + # compute transform that goes from cam to world + s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]) + self._set_pose(self.pw_poses, e, R, T, scale=s) + + # take into account the scale normalization + s_factor = self.get_pw_norm_scale_factor() + im_poses[:, :3, 3] *= s_factor # apply downscaling factor + for img_pts3d in pts3d: + img_pts3d *= s_factor + + # init all image poses + if self.has_im_poses: + for i in range(self.n_imgs): + cam2world = im_poses[i] + depth = geotrf(inv(cam2world), pts3d[i])[..., 2] + self._set_depthmap(i, depth) + self._set_pose(self.im_poses, i, cam2world) + if im_focals[i] is not None: + self._set_focal(i, im_focals[i]) + + if self.verbose: + print(' init loss =', float(self())) + + +def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr, + device, has_im_poses=True, niter_PnP=10, verbose=True): + n_imgs = len(imshapes) + sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j)) + msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo() + + # temp variable to store 3d points + pts3d = [None] * len(imshapes) + + todo = sorted(zip(-msp.data, msp.row, msp.col)) # sorted edges + im_poses = [None] * n_imgs + im_focals = [None] * n_imgs + + # init with strongest edge + score, i, j = todo.pop() + if verbose: + print(f' init edge ({i}*,{j}*) {score=}') + i_j = edge_str(i, j) + pts3d[i] = pred_i[i_j].clone() + pts3d[j] = pred_j[i_j].clone() + done = {i, j} + if has_im_poses: + im_poses[i] = torch.eye(4, device=device) + im_focals[i] = estimate_focal(pred_i[i_j]) + + # set initial pointcloud based on pairwise graph + msp_edges = [(i, j)] + while todo: + # each time, predict the next one + score, i, j = todo.pop() + + if im_focals[i] is None: + im_focals[i] = estimate_focal(pred_i[i_j]) + + if i in done: + if verbose: + print(f' init edge ({i},{j}*) {score=}') + assert j not in done + # align pred[i] with pts3d[i], and then set j accordingly + i_j = edge_str(i, j) + s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j]) + trf = sRT_to_4x4(s, R, T, device) + pts3d[j] = geotrf(trf, pred_j[i_j]) + done.add(j) + msp_edges.append((i, j)) + + if has_im_poses and im_poses[i] is None: + im_poses[i] = sRT_to_4x4(1, R, T, device) + + elif j in done: + if verbose: + print(f' init edge ({i}*,{j}) {score=}') + assert i not in done + i_j = edge_str(i, j) + s, R, T = rigid_points_registration(pred_j[i_j], 
pts3d[j], conf=conf_j[i_j]) + trf = sRT_to_4x4(s, R, T, device) + pts3d[i] = geotrf(trf, pred_i[i_j]) + done.add(i) + msp_edges.append((i, j)) + + if has_im_poses and im_poses[i] is None: + im_poses[i] = sRT_to_4x4(1, R, T, device) + else: + # let's try again later + todo.insert(0, (score, i, j)) + + if has_im_poses: + # complete all missing informations + pair_scores = list(sparse_graph.values()) # already negative scores: less is best + edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)] + for i, j in edges_from_best_to_worse.tolist(): + if im_focals[i] is None: + im_focals[i] = estimate_focal(pred_i[edge_str(i, j)]) + + for i in range(n_imgs): + if im_poses[i] is None: + msk = im_conf[i] > min_conf_thr + res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP) + if res: + im_focals[i], im_poses[i] = res + if im_poses[i] is None: + im_poses[i] = torch.eye(4, device=device) + im_poses = torch.stack(im_poses) + else: + im_poses = im_focals = None + + return pts3d, msp_edges, im_focals, im_poses + + +def dict_to_sparse_graph(dic): + n_imgs = max(max(e) for e in dic) + 1 + res = sp.dok_array((n_imgs, n_imgs)) + for edge, value in dic.items(): + res[edge] = value + return res + + +def rigid_points_registration(pts1, pts2, conf): + R, T, s = roma.rigid_points_registration( + pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True) + return s, R, T # return un-scaled (R, T) + + +def sRT_to_4x4(scale, R, T, device): + trf = torch.eye(4, device=device) + trf[:3, :3] = R * scale + trf[:3, 3] = T.ravel() # doesn't need scaling + return trf + + +def estimate_focal(pts3d_i, pp=None): + if pp is None: + H, W, THREE = pts3d_i.shape + assert THREE == 3 + pp = torch.tensor((W/2, H/2), device=pts3d_i.device) + focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel() + return float(focal) + + +@cache +def pixel_grid(H, W): + return np.mgrid[:W, :H].T.astype(np.float32) + + +def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10): + # extract camera poses and focals with RANSAC-PnP + if msk.sum() < 4: + return None # we need at least 4 points for PnP + pts3d, msk = map(to_numpy, (pts3d, msk)) + + H, W, THREE = pts3d.shape + assert THREE == 3 + pixels = pixel_grid(H, W) + + if focal is None: + S = max(W, H) + tentative_focals = np.geomspace(S/2, S*3, 21) + else: + tentative_focals = [focal] + + if pp is None: + pp = (W/2, H/2) + else: + pp = to_numpy(pp) + + best = 0, + for focal in tentative_focals: + K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) + + success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None, + iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP) + if not success: + continue + + score = len(inliers) + if success and score > best[0]: + best = score, R, T, focal + + if not best[0]: + return None + + _, R, T, best_focal = best + R = cv2.Rodrigues(R)[0] # world to cam + R, T = map(torch.from_numpy, (R, T)) + return best_focal, inv(sRT_to_4x4(1, R, T, device)) # cam to world + + +def get_known_poses(self): + if self.has_im_poses: + known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses]) + known_poses = self.get_im_poses() + return known_poses_msk.sum(), known_poses_msk, known_poses + else: + return 0, None, None + + +def get_known_focals(self): + if self.has_im_poses: + known_focal_msk = self.get_known_focal_mask() + known_focals = self.get_focals() + return 
known_focal_msk.sum(), known_focal_msk, known_focals + else: + return 0, None, None + + +def align_multiple_poses(src_poses, target_poses): + N = len(src_poses) + assert src_poses.shape == target_poses.shape == (N, 4, 4) + + def center_and_z(poses): + eps = get_med_dist_between_poses(poses) / 100 + return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2])) + R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True) + return s, R, T diff --git a/imcui/third_party/dust3r/dust3r/cloud_opt/modular_optimizer.py b/imcui/third_party/dust3r/dust3r/cloud_opt/modular_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d06464b40276684385c18b9195be1491c6f47f07 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/cloud_opt/modular_optimizer.py @@ -0,0 +1,145 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Slower implementation of the global alignment that allows to freeze partial poses/intrinsics +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import geotrf +from dust3r.utils.device import to_cpu, to_numpy +from dust3r.utils.geometry import depthmap_to_pts3d + + +class ModularPointCloudOptimizer (BasePCOptimizer): + """ Optimize a global scene, given a list of pairwise observations. + Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics) + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs): + super().__init__(*args, **kwargs) + self.has_im_poses = True # by definition of this class + self.focal_brake = focal_brake + + # adding thing to optimize + self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth) + self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses + default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes] + self.im_focals = nn.ParameterList(torch.FloatTensor([f, f] if fx_and_fy else [ + f]) for f in default_focals) # camera intrinsics + self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics + self.im_pp.requires_grad_(optimize_pp) + + def preset_pose(self, known_poses, pose_msk=None): # cam-to-world + if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2: + known_poses = [known_poses] + for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses): + if self.verbose: + print(f' (setting pose #{idx} = {pose[:3,3]})') + self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True)) + + # normalize scale if there's less than 1 known pose + n_known_poses = sum((p.requires_grad is False) for p in self.im_poses) + self.norm_pw_scale = (n_known_poses <= 1) + + def preset_intrinsics(self, known_intrinsics, msk=None): + if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2: + known_intrinsics = [known_intrinsics] + for K in known_intrinsics: + assert K.shape == (3, 3) + self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk) + self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk) + + def 
preset_focal(self, known_focals, msk=None): + for idx, focal in zip(self._get_msk_indices(msk), known_focals): + if self.verbose: + print(f' (setting focal #{idx} = {focal})') + self._no_grad(self._set_focal(idx, focal, force=True)) + + def preset_principal_point(self, known_pp, msk=None): + for idx, pp in zip(self._get_msk_indices(msk), known_pp): + if self.verbose: + print(f' (setting principal point #{idx} = {pp})') + self._no_grad(self._set_principal_point(idx, pp, force=True)) + + def _no_grad(self, tensor): + return tensor.requires_grad_(False) + + def _get_msk_indices(self, msk): + if msk is None: + return range(self.n_imgs) + elif isinstance(msk, int): + return [msk] + elif isinstance(msk, (tuple, list)): + return self._get_msk_indices(np.array(msk)) + elif msk.dtype in (bool, torch.bool, np.bool_): + assert len(msk) == self.n_imgs + return np.where(msk)[0] + elif np.issubdtype(msk.dtype, np.integer): + return msk + else: + raise ValueError(f'bad {msk=}') + + def _set_focal(self, idx, focal, force=False): + param = self.im_focals[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = self.focal_brake * np.log(focal) + return param + + def get_focals(self): + log_focals = torch.stack(list(self.im_focals), dim=0) + return (log_focals / self.focal_brake).exp() + + def _set_principal_point(self, idx, pp, force=False): + param = self.im_pp[idx] + H, W = self.imshapes[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10 + return param + + def get_principal_points(self): + return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)]) + + def get_intrinsics(self): + K = torch.zeros((self.n_imgs, 3, 3), device=self.device) + focals = self.get_focals().view(self.n_imgs, -1) + K[:, 0, 0] = focals[:, 0] + K[:, 1, 1] = focals[:, -1] + K[:, :2, 2] = self.get_principal_points() + K[:, 2, 2] = 1 + return K + + def get_im_poses(self): # cam to world + cam2world = self._get_poses(torch.stack(list(self.im_poses))) + return cam2world + + def _set_depthmap(self, idx, depth, force=False): + param = self.im_depthmaps[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def get_depthmaps(self): + return [d.exp() for d in self.im_depthmaps] + + def depth_to_pts3d(self): + # Get depths and projection params if not provided + focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps() + + # convert focal to (1,2,H,W) constant field + def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i]) + # get pointmaps in camera frame + rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])] + # project to world frame + return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)] + + def get_pts3d(self): + return self.depth_to_pts3d() diff --git a/imcui/third_party/dust3r/dust3r/cloud_opt/optimizer.py b/imcui/third_party/dust3r/dust3r/cloud_opt/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..42e48613e55faa4ede5a366d1c0bfc4d18ffae4f --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/cloud_opt/optimizer.py @@ -0,0 +1,248 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. 
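Editor's note: the `_get_msk_indices` helper above accepts several mask formats (None, a single index, a list/tuple, a boolean mask, or an index array) and normalizes them to image indices. Below is a standalone sketch of that dispatch for numpy inputs only, outside the optimizer class; `n_imgs=4` is a made-up stand-in for `self.n_imgs`, not part of the patch.

```python
import numpy as np

def get_msk_indices(msk, n_imgs=4):
    if msk is None:
        return range(n_imgs)                    # no mask: all images
    if isinstance(msk, int):
        return [msk]                            # a single image index
    if isinstance(msk, (tuple, list)):
        return get_msk_indices(np.array(msk), n_imgs)
    if msk.dtype in (bool, np.bool_):
        assert len(msk) == n_imgs
        return np.where(msk)[0]                 # boolean mask -> indices
    if np.issubdtype(msk.dtype, np.integer):
        return msk                              # already an index array
    raise ValueError(f'bad {msk=}')

print(list(get_msk_indices(None)))                        # [0, 1, 2, 3]
print(list(get_msk_indices(2)))                           # [2]
print(list(get_msk_indices([True, False, True, False])))  # [0, 2]
print(list(get_msk_indices([1, 3])))                      # [1, 3]
```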
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Main class for the implementation of the global alignment +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import xy_grid, geotrf +from dust3r.utils.device import to_cpu, to_numpy + + +class PointCloudOptimizer(BasePCOptimizer): + """ Optimize a global scene, given a list of pairwise observations. + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs): + super().__init__(*args, **kwargs) + + self.has_im_poses = True # by definition of this class + self.focal_break = focal_break + + # adding thing to optimize + self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth) + self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses + self.im_focals = nn.ParameterList(torch.FloatTensor( + [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes) # camera intrinsics + self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics + self.im_pp.requires_grad_(optimize_pp) + + self.imshape = self.imshapes[0] + im_areas = [h*w for h, w in self.imshapes] + self.max_area = max(im_areas) + + # adding thing to optimize + self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area) + self.im_poses = ParameterStack(self.im_poses, is_param=True) + self.im_focals = ParameterStack(self.im_focals, is_param=True) + self.im_pp = ParameterStack(self.im_pp, is_param=True) + self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes])) + self.register_buffer('_grid', ParameterStack( + [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area)) + + # pre-compute pixel weights + self.register_buffer('_weight_i', ParameterStack( + [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area)) + self.register_buffer('_weight_j', ParameterStack( + [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area)) + + # precompute aa + self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area)) + self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area)) + self.register_buffer('_ei', torch.tensor([i for i, j in self.edges])) + self.register_buffer('_ej', torch.tensor([j for i, j in self.edges])) + self.total_area_i = sum([im_areas[i] for i, j in self.edges]) + self.total_area_j = sum([im_areas[j] for i, j in self.edges]) + + def _check_all_imgs_are_selected(self, msk): + assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!' 
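Editor's note: both optimizers above store each focal as `focal_break * log(focal)` rather than as raw pixels, so the optimized quantity is smoother and roughly scale-free; `get_focals()` undoes the mapping with `exp(param / focal_break)`. A minimal standalone round-trip sketch (only the `focal_break = 20` default comes from the code; the focal value is made up):

```python
import math
import torch

focal_break = 20.0        # default used by PointCloudOptimizer(focal_break=20)
true_focal = 512.0        # hypothetical focal length, in pixels

# _set_focal stores focal_break * log(focal) as the raw parameter ...
raw_param = torch.tensor([focal_break * math.log(true_focal)])

# ... and get_focals() inverts it with exp(param / focal_break)
recovered = (raw_param / focal_break).exp()
assert torch.allclose(recovered, torch.tensor([true_focal]))
```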
+ + def preset_pose(self, known_poses, pose_msk=None): # cam-to-world + self._check_all_imgs_are_selected(pose_msk) + + if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2: + known_poses = [known_poses] + for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses): + if self.verbose: + print(f' (setting pose #{idx} = {pose[:3,3]})') + self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose))) + + # normalize scale if there's less than 1 known pose + n_known_poses = sum((p.requires_grad is False) for p in self.im_poses) + self.norm_pw_scale = (n_known_poses <= 1) + + self.im_poses.requires_grad_(False) + self.norm_pw_scale = False + + def preset_focal(self, known_focals, msk=None): + self._check_all_imgs_are_selected(msk) + + for idx, focal in zip(self._get_msk_indices(msk), known_focals): + if self.verbose: + print(f' (setting focal #{idx} = {focal})') + self._no_grad(self._set_focal(idx, focal)) + + self.im_focals.requires_grad_(False) + + def preset_principal_point(self, known_pp, msk=None): + self._check_all_imgs_are_selected(msk) + + for idx, pp in zip(self._get_msk_indices(msk), known_pp): + if self.verbose: + print(f' (setting principal point #{idx} = {pp})') + self._no_grad(self._set_principal_point(idx, pp)) + + self.im_pp.requires_grad_(False) + + def _get_msk_indices(self, msk): + if msk is None: + return range(self.n_imgs) + elif isinstance(msk, int): + return [msk] + elif isinstance(msk, (tuple, list)): + return self._get_msk_indices(np.array(msk)) + elif msk.dtype in (bool, torch.bool, np.bool_): + assert len(msk) == self.n_imgs + return np.where(msk)[0] + elif np.issubdtype(msk.dtype, np.integer): + return msk + else: + raise ValueError(f'bad {msk=}') + + def _no_grad(self, tensor): + assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs' + + def _set_focal(self, idx, focal, force=False): + param = self.im_focals[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = self.focal_break * np.log(focal) + return param + + def get_focals(self): + log_focals = torch.stack(list(self.im_focals), dim=0) + return (log_focals / self.focal_break).exp() + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.im_focals]) + + def _set_principal_point(self, idx, pp, force=False): + param = self.im_pp[idx] + H, W = self.imshapes[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10 + return param + + def get_principal_points(self): + return self._pp + 10 * self.im_pp + + def get_intrinsics(self): + K = torch.zeros((self.n_imgs, 3, 3), device=self.device) + focals = self.get_focals().flatten() + K[:, 0, 0] = K[:, 1, 1] = focals + K[:, :2, 2] = self.get_principal_points() + K[:, 2, 2] = 1 + return K + + def get_im_poses(self): # cam to world + cam2world = self._get_poses(self.im_poses) + return cam2world + + def _set_depthmap(self, idx, depth, force=False): + depth = _ravel_hw(depth, self.max_area) + + param = self.im_depthmaps[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def get_depthmaps(self, raw=False): + res = self.im_depthmaps.exp() + if not raw: + res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def depth_to_pts3d(self): + # Get depths and projection params if not provided + 
focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps(raw=True) + + # get pointmaps in camera frame + rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp) + # project to world frame + return geotrf(im_poses, rel_ptmaps) + + def get_pts3d(self, raw=False): + res = self.depth_to_pts3d() + if not raw: + res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def forward(self): + pw_poses = self.get_pw_poses() # cam-to-world + pw_adapt = self.get_adaptors().unsqueeze(1) + proj_pts3d = self.get_pts3d(raw=True) + + # rotate pairwise prediction according to pw_poses + aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i) + aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j) + + # compute the less + li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i + lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j + + return li + lj + + +def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp): + pp = pp.unsqueeze(1) + focal = focal.unsqueeze(1) + assert focal.shape == (len(depth), 1, 1) + assert pp.shape == (len(depth), 1, 2) + assert pixel_grid.shape == depth.shape + (2,) + depth = depth.unsqueeze(-1) + return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1) + + +def ParameterStack(params, keys=None, is_param=None, fill=0): + if keys is not None: + params = [params[k] for k in keys] + + if fill > 0: + params = [_ravel_hw(p, fill) for p in params] + + requires_grad = params[0].requires_grad + assert all(p.requires_grad == requires_grad for p in params) + + params = torch.stack(list(params)).float().detach() + if is_param or requires_grad: + params = nn.Parameter(params) + params.requires_grad_(requires_grad) + return params + + +def _ravel_hw(tensor, fill=0): + # ravel H,W + tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]) + + if len(tensor) < fill: + tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:]))) + return tensor + + +def acceptable_focal_range(H, W, minf=0.5, maxf=3.5): + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515 + return minf*focal_base, maxf*focal_base + + +def apply_mask(img, msk): + img = img.copy() + img[msk] = 0 + return img diff --git a/imcui/third_party/dust3r/dust3r/cloud_opt/pair_viewer.py b/imcui/third_party/dust3r/dust3r/cloud_opt/pair_viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..62ae3b9a5fbca8b96711de051d9d6597830bd488 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/cloud_opt/pair_viewer.py @@ -0,0 +1,127 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dummy optimizer for visualizing pairs +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn +import cv2 + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates +from dust3r.cloud_opt.commons import edge_str +from dust3r.post_process import estimate_focal_knowing_depth + + +class PairViewer (BasePCOptimizer): + """ + This a Dummy Optimizer. 
+ To use only when the goal is to visualize the results for a pair of images (with is_symmetrized) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.is_symmetrized and self.n_edges == 2 + self.has_im_poses = True + + # compute all parameters directly from raw input + self.focals = [] + self.pp = [] + rel_poses = [] + confs = [] + for i in range(self.n_imgs): + conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean()) + if self.verbose: + print(f' - {conf=:.3} for edge {i}-{1-i}') + confs.append(conf) + + H, W = self.imshapes[i] + pts3d = self.pred_i[edge_str(i, 1-i)] + pp = torch.tensor((W/2, H/2)) + focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld')) + self.focals.append(focal) + self.pp.append(pp) + + # estimate the pose of pts1 in image 2 + pixels = np.mgrid[:W, :H].T.astype(np.float32) + pts3d = self.pred_j[edge_str(1-i, i)].numpy() + assert pts3d.shape[:2] == (H, W) + msk = self.get_masks()[i].numpy() + K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) + + try: + res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None, + iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP) + success, R, T, inliers = res + assert success + + R = cv2.Rodrigues(R)[0] # world to cam + pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world + except: + pose = np.eye(4) + rel_poses.append(torch.from_numpy(pose.astype(np.float32))) + + # let's use the pair with the most confidence + if confs[0] > confs[1]: + # ptcloud is expressed in camera1 + self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1 + self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]] + else: + # ptcloud is expressed in camera2 + self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2 + self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]] + + self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False) + self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False) + self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False) + self.depth = nn.ParameterList(self.depth) + for p in self.parameters(): + p.requires_grad = False + + def _set_depthmap(self, idx, depth, force=False): + if self.verbose: + print('_set_depthmap is ignored in PairViewer') + return + + def get_depthmaps(self, raw=False): + depth = [d.to(self.device) for d in self.depth] + return depth + + def _set_focal(self, idx, focal, force=False): + self.focals[idx] = focal + + def get_focals(self): + return self.focals + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.focals]) + + def get_principal_points(self): + return self.pp + + def get_intrinsics(self): + focals = self.get_focals() + pps = self.get_principal_points() + K = torch.zeros((len(focals), 3, 3), device=self.device) + for i in range(len(focals)): + K[i, 0, 0] = K[i, 1, 1] = focals[i] + K[i, :2, 2] = pps[i] + K[i, 2, 2] = 1 + return K + + def get_im_poses(self): + return self.im_poses + + def depth_to_pts3d(self): + pts3d = [] + for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()): + pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(), + intrinsics.cpu().numpy(), + im_pose.cpu().numpy()) + pts3d.append(torch.from_numpy(pts).to(device=self.device)) + return pts3d + + def forward(self): + return float('nan') diff --git 
a/imcui/third_party/dust3r/dust3r/datasets/__init__.py b/imcui/third_party/dust3r/dust3r/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2123d09ec2840ab5ee9ca43057c35f93233bde89 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/__init__.py @@ -0,0 +1,50 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +from .utils.transforms import * +from .base.batched_sampler import BatchedRandomSampler # noqa +from .arkitscenes import ARKitScenes # noqa +from .blendedmvs import BlendedMVS # noqa +from .co3d import Co3d # noqa +from .habitat import Habitat # noqa +from .megadepth import MegaDepth # noqa +from .scannetpp import ScanNetpp # noqa +from .staticthings3d import StaticThings3D # noqa +from .waymo import Waymo # noqa +from .wildrgbd import WildRGBD # noqa + + +def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True): + import torch + from croco.utils.misc import get_world_size, get_rank + + # pytorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + world_size = get_world_size() + rank = get_rank() + + try: + sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, + rank=rank, drop_last=drop_last) + except (AttributeError, NotImplementedError): + # not avail for this dataset + if torch.distributed.is_initialized(): + sampler = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last + ) + elif shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=drop_last, + ) + + return data_loader diff --git a/imcui/third_party/dust3r/dust3r/datasets/arkitscenes.py b/imcui/third_party/dust3r/dust3r/datasets/arkitscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..4fad51acdc18b82cd6a4d227de0dac3b25783e33 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/arkitscenes.py @@ -0,0 +1,102 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
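Editor's note: `get_data_loader` above first asks the dataset for its own `make_sampler` and only falls back to the standard PyTorch samplers when that is unavailable. A standalone sketch of that fallback, using a plain `TensorDataset` (which has no `make_sampler`) as a stand-in for a dust3r dataset:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(12).float())
shuffle = True

try:
    # dust3r datasets expose make_sampler(); TensorDataset does not
    sampler = dataset.make_sampler(batch_size=4, shuffle=shuffle)
except (AttributeError, NotImplementedError):
    # same single-process fallback as get_data_loader
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)

loader = DataLoader(dataset, sampler=sampler, batch_size=4, drop_last=True)
for (batch,) in loader:
    print(batch.shape)    # torch.Size([4])
```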
+# +# -------------------------------------------------------- +# Dataloader for preprocessed arkitscenes +# dataset at https://github.com/apple/ARKitScenes - Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license +# See datasets_preprocess/preprocess_arkitscenes.py +# -------------------------------------------------------- +import os.path as osp +import cv2 +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class ARKitScenes(BaseStereoViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + if split == "train": + self.split = "Training" + elif split == "test": + self.split = "Test" + else: + raise ValueError("") + + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + with np.load(osp.join(self.ROOT, split, 'all_metadata.npz')) as data: + self.scenes = data['scenes'] + self.sceneids = data['sceneids'] + self.images = data['images'] + self.intrinsics = data['intrinsics'].astype(np.float32) + self.trajectories = data['trajectories'].astype(np.float32) + self.pairs = data['pairs'][:, :2].astype(int) + + def __len__(self): + return len(self.pairs) + + def _get_views(self, idx, resolution, rng): + + image_idx1, image_idx2 = self.pairs[idx] + + views = [] + for view_idx in [image_idx1, image_idx2]: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(scene_dir, 'vga_wide', basename.replace('.png', '.jpg'))) + # Load depthmap + depthmap = imread_cv2(osp.join(scene_dir, 'lowres_depth', basename), cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) + + views.append(dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset='arkitscenes', + label=self.scenes[scene_id] + '_' + basename, + instance=f'{str(idx)}_{str(view_idx)}', + )) + + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = ARKitScenes(split='train', ROOT="data/arkitscenes_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git 
a/imcui/third_party/dust3r/dust3r/datasets/base/__init__.py b/imcui/third_party/dust3r/dust3r/datasets/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/base/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/dust3r/dust3r/datasets/base/base_stereo_view_dataset.py b/imcui/third_party/dust3r/dust3r/datasets/base/base_stereo_view_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcac01da8c27a57a7601a09c7e75754d12871e3 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/base/base_stereo_view_dataset.py @@ -0,0 +1,220 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# base class for implementing datasets +# -------------------------------------------------------- +import PIL +import numpy as np +import torch + +from dust3r.datasets.base.easy_dataset import EasyDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates +import dust3r.datasets.utils.cropping as cropping + + +class BaseStereoViewDataset (EasyDataset): + """ Define all basic options. + + Usage: + class MyDataset (BaseStereoViewDataset): + def _get_views(self, idx, rng): + # overload here + views = [] + views.append(dict(img=, ...)) + return views + """ + + def __init__(self, *, # only keyword arguments + split=None, + resolution=None, # square_size or (width, height) or list of [(width,height), ...] 
+ transform=ImgNorm, + aug_crop=False, + seed=None): + self.num_views = 2 + self.split = split + self._set_resolutions(resolution) + + self.transform = transform + if isinstance(transform, str): + transform = eval(transform) + + self.aug_crop = aug_crop + self.seed = seed + + def __len__(self): + return len(self.scenes) + + def get_stats(self): + return f"{len(self)} pairs" + + def __repr__(self): + resolutions_str = '[' + ';'.join(f'{w}x{h}' for w, h in self._resolutions) + ']' + return f"""{type(self).__name__}({self.get_stats()}, + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '') + + def _get_views(self, idx, resolution, rng): + raise NotImplementedError() + + def __getitem__(self, idx): + if isinstance(idx, tuple): + # the idx is specifying the aspect-ratio + idx, ar_idx = idx + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + + # set-up the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, '_rng'): + seed = torch.initial_seed() # this is different for each dataloader process + self._rng = np.random.default_rng(seed=seed) + + # over-loaded code + resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + views = self._get_views(idx, resolution, self._rng) + assert len(views) == self.num_views + + # check data-types + for v, view in enumerate(views): + assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view['idx'] = (idx, ar_idx, v) + + # encode the image + width, height = view['img'].size + view['true_shape'] = np.int32((height, width)) + view['img'] = self.transform(view['img']) + + assert 'camera_intrinsics' in view + if 'camera_pose' not in view: + view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}' + assert 'pts3d' not in view + assert 'valid_mask' not in view + assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}' + pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + view['pts3d'] = pts3d + view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1) + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view['camera_intrinsics'] + + # last thing done! 
+ for view in views: + # transpose to make sure all views are the same size + transpose_to_landscape(view) + # this allows to check whether the RNG is is the same state each time + view['rng'] = int.from_bytes(self._rng.bytes(4), 'big') + return views + + def _set_resolutions(self, resolutions): + assert resolutions is not None, 'undefined resolution' + + if not isinstance(resolutions, list): + resolutions = [resolutions] + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int' + assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int' + assert width >= height + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None): + """ This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # downscale with lanczos interpolation so that image.size == resolution + # cropping centered on the principal point + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + # assert min_margin_x > W/5, f'Bad principal point in view={info}' + # assert min_margin_y > H/5, f'Bad principal point in view={info}' + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) + + # transpose the resolution if necessary + W, H = image.size # new size + assert resolution[0] >= resolution[1] + if H > 1.1 * W: + # image is portrait mode + resolution = resolution[::-1] + elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]: + # image is square, so we chose (portrait, landscape) randomly + if rng.integers(2): + resolution = resolution[::-1] + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + if self.aug_crop > 1: + target_resolution += rng.integers(0, self.aug_crop) + image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution) + + # actual cropping (if necessary) with bilinear interpolation + intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5) + crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution) + image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) + + return image, depthmap, intrinsics2 + + +def is_good_type(key, v): + """ returns (is_good, err_msg) + """ + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x + db = sel(view['dataset']) + label = sel(view['label']) + instance = sel(view['instance']) + return f"{db}/{label}/{instance}" + + +def transpose_to_landscape(view): + height, width = view['true_shape'] + + if width < height: + # rectify portrait 
to landscape + assert view['img'].shape == (3, height, width) + view['img'] = view['img'].swapaxes(1, 2) + + assert view['valid_mask'].shape == (height, width) + view['valid_mask'] = view['valid_mask'].swapaxes(0, 1) + + assert view['depthmap'].shape == (height, width) + view['depthmap'] = view['depthmap'].swapaxes(0, 1) + + assert view['pts3d'].shape == (height, width, 3) + view['pts3d'] = view['pts3d'].swapaxes(0, 1) + + # transpose x and y pixels + view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]] diff --git a/imcui/third_party/dust3r/dust3r/datasets/base/batched_sampler.py b/imcui/third_party/dust3r/dust3r/datasets/base/batched_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..85f58a65d41bb8101159e032d5b0aac26a7cf1a1 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/base/batched_sampler.py @@ -0,0 +1,74 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Random sampling under a constraint +# -------------------------------------------------------- +import numpy as np +import torch + + +class BatchedRandomSampler: + """ Random sampling under a constraint: each sample in the batch has the same feature, + which is chosen randomly from a known pool of 'features' for each batch. + + For instance, the 'feature' could be the image aspect-ratio. + + The index returned is a tuple (sample_idx, feat_idx). + This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. + """ + + def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True): + self.batch_size = batch_size + self.pool_size = pool_size + + self.len_dataset = N = len(dataset) + self.total_size = round_by(N, batch_size*world_size) if drop_last else N + assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode' + + # distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + + def __len__(self): + return self.total_size // self.world_size + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + # prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used' + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # random feat_idxs (same across each batch) + n_batches = (self.total_size+self.batch_size-1) // self.batch_size + feat_idxs = rng.integers(self.pool_size, size=n_batches) + feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) + feat_idxs = feat_idxs.ravel()[:self.total_size] + + # put them together + idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) + + # Distributed sampler: we select a subset of batches + # make sure the slice for each node is aligned with batch_size + size_per_proc = self.batch_size * ((self.total_size + self.world_size * + self.batch_size-1) // (self.world_size * self.batch_size)) + idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc] + + yield from (tuple(idx) for idx in idxs) + + +def round_by(total, multiple, up=False): + if up: + total = total + multiple-1 + return (total//multiple) * multiple diff --git 
a/imcui/third_party/dust3r/dust3r/datasets/base/easy_dataset.py b/imcui/third_party/dust3r/dust3r/datasets/base/easy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4939a88f02715a1f80be943ddb6d808e1be84db7 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/base/easy_dataset.py @@ -0,0 +1,157 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# A dataset base class that you can easily resize and combine. +# -------------------------------------------------------- +import numpy as np +from dust3r.datasets.base.batched_sampler import BatchedRandomSampler + + +class EasyDataset: + """ a dataset that you can easily resize and combine. + Examples: + --------- + 2 * dataset ==> duplicate each element 2x + + 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) + + dataset1 + dataset2 ==> concatenate datasets + """ + + def __add__(self, other): + return CatDataset([self, other]) + + def __rmul__(self, factor): + return MulDataset(factor, self) + + def __rmatmul__(self, factor): + return ResizedDataset(factor, self) + + def set_epoch(self, epoch): + pass # nothing to do by default + + def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True): + if not (shuffle): + raise NotImplementedError() # cannot deal yet + num_of_aspect_ratios = len(self._resolutions) + return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last) + + +class MulDataset (EasyDataset): + """ Artifically augmenting the size of a dataset. + """ + multiplicator: int + + def __init__(self, multiplicator, dataset): + assert isinstance(multiplicator, int) and multiplicator > 0 + self.multiplicator = multiplicator + self.dataset = dataset + + def __len__(self): + return self.multiplicator * len(self.dataset) + + def __repr__(self): + return f'{self.multiplicator}*{repr(self.dataset)}' + + def __getitem__(self, idx): + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[idx // self.multiplicator, other] + else: + return self.dataset[idx // self.multiplicator] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class ResizedDataset (EasyDataset): + """ Artifically changing the size of a dataset. 
+ """ + new_size: int + + def __init__(self, new_size, dataset): + assert isinstance(new_size, int) and new_size > 0 + self.new_size = new_size + self.dataset = dataset + + def __len__(self): + return self.new_size + + def __repr__(self): + size_str = str(self.new_size) + for i in range((len(size_str)-1) // 3): + sep = -4*i-3 + size_str = size_str[:sep] + '_' + size_str[sep:] + return f'{size_str} @ {repr(self.dataset)}' + + def set_epoch(self, epoch): + # this random shuffle only depends on the epoch + rng = np.random.default_rng(seed=epoch+777) + + # shuffle all indices + perm = rng.permutation(len(self.dataset)) + + # rotary extension until target size is met + shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset))) + self._idxs_mapping = shuffled_idxs[:self.new_size] + + assert len(self._idxs_mapping) == self.new_size + + def __getitem__(self, idx): + assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()' + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[self._idxs_mapping[idx], other] + else: + return self.dataset[self._idxs_mapping[idx]] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class CatDataset (EasyDataset): + """ Concatenation of several datasets + """ + + def __init__(self, datasets): + for dataset in datasets: + assert isinstance(dataset, EasyDataset) + self.datasets = datasets + self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) + + def __len__(self): + return self._cum_sizes[-1] + + def __repr__(self): + # remove uselessly long transform + return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets) + + def set_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_epoch(epoch) + + def __getitem__(self, idx): + other = None + if isinstance(idx, tuple): + idx, other = idx + + if not (0 <= idx < len(self)): + raise IndexError() + + db_idx = np.searchsorted(self._cum_sizes, idx, 'right') + dataset = self.datasets[db_idx] + new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) + + if other is not None: + new_idx = (new_idx, other) + return dataset[new_idx] + + @property + def _resolutions(self): + resolutions = self.datasets[0]._resolutions + for dataset in self.datasets[1:]: + assert tuple(dataset._resolutions) == tuple(resolutions) + return resolutions diff --git a/imcui/third_party/dust3r/dust3r/datasets/blendedmvs.py b/imcui/third_party/dust3r/dust3r/datasets/blendedmvs.py new file mode 100644 index 0000000000000000000000000000000000000000..93e68c28620cc47a7b1743834e45f82d576126d0 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/blendedmvs.py @@ -0,0 +1,104 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# Dataloader for preprocessed BlendedMVS +# dataset at https://github.com/YoYo000/BlendedMVS +# See datasets_preprocess/preprocess_blendedmvs.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class BlendedMVS (BaseStereoViewDataset): + """ Dataset of outdoor street scenes, 5 images each time + """ + + def __init__(self, *args, ROOT, split=None, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self._load_data(split) + + def _load_data(self, split): + pairs = np.load(osp.join(self.ROOT, 'blendedmvs_pairs.npy')) + if split is None: + selection = slice(None) + if split == 'train': + # select 90% of all scenes + selection = (pairs['seq_low'] % 10) > 0 + if split == 'val': + # select 10% of all scenes + selection = (pairs['seq_low'] % 10) == 0 + self.pairs = pairs[selection] + + # list of all scenes + self.scenes = np.unique(self.pairs['seq_low']) # low is unique enough + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs from {len(self.scenes)} scenes' + + def _get_views(self, pair_idx, resolution, rng): + seqh, seql, img1, img2, score = self.pairs[pair_idx] + + seq = f"{seqh:08x}{seql:016x}" + seq_path = osp.join(self.ROOT, seq) + + views = [] + + for view_index in [img1, img2]: + impath = f"{view_index:08n}" + image = imread_cv2(osp.join(seq_path, impath + ".jpg")) + depthmap = imread_cv2(osp.join(seq_path, impath + ".exr")) + camera_params = np.load(osp.join(seq_path, impath + ".npz")) + + intrinsics = np.float32(camera_params['intrinsics']) + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = camera_params['R_cam2world'] + camera_pose[:3, 3] = camera_params['t_cam2world'] + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath)) + + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='BlendedMVS', + label=osp.relpath(seq_path, self.ROOT), + instance=impath)) + + return views + + +if __name__ == '__main__': + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = BlendedMVS(split='train', ROOT="data/blendedmvs_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/dust3r/dust3r/datasets/co3d.py b/imcui/third_party/dust3r/dust3r/datasets/co3d.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea5c8555d34b776e7a48396dcd0eecece713e34 --- /dev/null +++ 
b/imcui/third_party/dust3r/dust3r/datasets/co3d.py @@ -0,0 +1,165 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed Co3d_v2 +# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International +# See datasets_preprocess/preprocess_co3d.py +# -------------------------------------------------------- +import os.path as osp +import json +import itertools +from collections import deque + +import cv2 +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class Co3d(BaseStereoViewDataset): + def __init__(self, mask_bg=True, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert mask_bg in (True, False, 'rand') + self.mask_bg = mask_bg + self.dataset_label = 'Co3d_v2' + + # load all scenes + with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f: + self.scenes = json.load(f) + self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0} + self.scenes = {(k, k2): v2 for k, v in self.scenes.items() + for k2, v2 in v.items()} + self.scene_list = list(self.scenes.keys()) + + # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees) + # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees + self.combinations = [(i, j) + for i, j in itertools.combinations(range(100), 2) + if 0 < abs(i - j) <= 30 and abs(i - j) % 5 == 0] + + self.invalidate = {scene: {} for scene in self.scene_list} + + def __len__(self): + return len(self.scene_list) * len(self.combinations) + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.npz') + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg') + + def _get_depthpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'depths', f'frame{view_idx:06n}.jpg.geometric.png') + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png') + + def _read_depthmap(self, depthpath, input_metadata): + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth']) + return depthmap + + def _get_views(self, idx, resolution, rng): + # choose a scene + obj, instance = self.scene_list[idx // len(self.combinations)] + image_pool = self.scenes[obj, instance] + im1_idx, im2_idx = self.combinations[idx % len(self.combinations)] + + # add a bit of randomness + last = len(image_pool) - 1 + + if resolution not in self.invalidate[obj, instance]: # flag invalid images + self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))] + + # decide now if we mask the bg + mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2)) + + views = [] + imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]] + imgs_idxs = deque(imgs_idxs) + while len(imgs_idxs) > 0: # some images (few) have zero depth + im_idx = imgs_idxs.pop() + + if self.invalidate[obj, instance][resolution][im_idx]: + # search for a valid image + random_direction = 2 * rng.choice(2) - 1 + for offset in 
range(1, len(image_pool)): + tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool) + if not self.invalidate[obj, instance][resolution][tentative_im_idx]: + im_idx = tentative_im_idx + break + + view_idx = image_pool[im_idx] + + impath = self._get_impath(obj, instance, view_idx) + depthpath = self._get_depthpath(obj, instance, view_idx) + + # load camera params + metadata_path = self._get_metadatapath(obj, instance, view_idx) + input_metadata = np.load(metadata_path) + camera_pose = input_metadata['camera_pose'].astype(np.float32) + intrinsics = input_metadata['camera_intrinsics'].astype(np.float32) + + # load image and depth + rgb_image = imread_cv2(impath) + depthmap = self._read_depthmap(depthpath, input_metadata) + + if mask_bg: + # load object mask + maskpath = self._get_maskpath(obj, instance, view_idx) + maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32) + maskmap = (maskmap / 255.0) > 0.1 + + # update the depthmap with mask + depthmap *= maskmap + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath) + + num_valid = (depthmap > 0.0).sum() + if num_valid == 0: + # problem, invalidate image and retry + self.invalidate[obj, instance][resolution][im_idx] = True + imgs_idxs.append(im_idx) + continue + + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=osp.join(obj, instance), + instance=osp.split(impath)[1], + )) + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/dust3r/dust3r/datasets/habitat.py b/imcui/third_party/dust3r/dust3r/datasets/habitat.py new file mode 100644 index 0000000000000000000000000000000000000000..11ce8a0ffb2134387d5fb794df89834db3ea8c9f --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/habitat.py @@ -0,0 +1,107 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
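Editor's note: the `Co3d` pair selection above keeps view pairs that are 5 to 30 frame indices apart, in steps of 5 (roughly 18 to 108 degrees if the 100 frames cover a full turn). A quick standalone check of what that rule produces:

```python
import itertools

combinations = [(i, j)
                for i, j in itertools.combinations(range(100), 2)
                if 0 < abs(i - j) <= 30 and abs(i - j) % 5 == 0]

print(len(combinations))                           # 495 pairs per (object, instance)
print(sorted({j - i for i, j in combinations}))    # [5, 10, 15, 20, 25, 30]
```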
+# +# -------------------------------------------------------- +# Dataloader for preprocessed habitat +# dataset at https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md +# See datasets_preprocess/habitat for more details +# -------------------------------------------------------- +import os.path as osp +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # noqa +import cv2 # noqa +import numpy as np +from PIL import Image +import json + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset + + +class Habitat(BaseStereoViewDataset): + def __init__(self, size, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert self.split is not None + # loading list of scenes + with open(osp.join(self.ROOT, f'Habitat_{size}_scenes_{self.split}.txt')) as f: + self.scenes = f.read().splitlines() + self.instances = list(range(1, 5)) + + def filter_scene(self, label, instance=None): + if instance: + subscene, instance = instance.split('_') + label += '/' + subscene + self.instances = [int(instance) - 1] + valid = np.bool_([scene.startswith(label) for scene in self.scenes]) + assert sum(valid), 'no scene was selected for {label=} {instance=}' + self.scenes = [scene for i, scene in enumerate(self.scenes) if valid[i]] + + def _get_views(self, idx, resolution, rng): + scene = self.scenes[idx] + data_path, key = osp.split(osp.join(self.ROOT, scene)) + views = [] + two_random_views = [0, rng.choice(self.instances)] # view 0 is connected with all other views + for view_index in two_random_views: + # load the view (and use the next one if this one's broken) + for ii in range(view_index, view_index + 5): + image, depthmap, intrinsics, camera_pose = self._load_one_view(data_path, key, ii % 5, resolution, rng) + if np.isfinite(camera_pose).all(): + break + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='Habitat', + label=osp.relpath(data_path, self.ROOT), + instance=f"{key}_{view_index}")) + return views + + def _load_one_view(self, data_path, key, view_index, resolution, rng): + view_index += 1 # file indices starts at 1 + impath = osp.join(data_path, f"{key}_{view_index}.jpeg") + image = Image.open(impath) + + depthmap_filename = osp.join(data_path, f"{key}_{view_index}_depth.exr") + depthmap = cv2.imread(depthmap_filename, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH) + + camera_params_filename = osp.join(data_path, f"{key}_{view_index}_camera_params.json") + with open(camera_params_filename, 'r') as f: + camera_params = json.load(f) + + intrinsics = np.float32(camera_params['camera_intrinsics']) + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = camera_params['R_cam2world'] + camera_pose[:3, 3] = camera_params['t_cam2world'] + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=impath) + return image, depthmap, intrinsics, camera_pose + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = Habitat(1_000_000, split='train', ROOT="data/habitat_processed", + resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = 
max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/dust3r/dust3r/datasets/megadepth.py b/imcui/third_party/dust3r/dust3r/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..8131498b76d855e5293fe79b3686fc42bf87eea8 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/megadepth.py @@ -0,0 +1,123 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed MegaDepth +# dataset at https://www.cs.cornell.edu/projects/megadepth/ +# See datasets_preprocess/preprocess_megadepth.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class MegaDepth(BaseStereoViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data(self.split) + + if self.split is None: + pass + elif self.split == 'train': + self.select_scene(('0015', '0022'), opposite=True) + elif self.split == 'val': + self.select_scene(('0015', '0022')) + else: + raise ValueError(f'bad {self.split=}') + + def _load_data(self, split): + with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data: + self.all_scenes = data['scenes'] + self.all_images = data['images'] + self.pairs = data['pairs'] + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs from {len(self.all_scenes)} scenes' + + def select_scene(self, scene, *instances, opposite=False): + scenes = (scene,) if isinstance(scene, str) else tuple(scene) + scene_id = [s.startswith(scenes) for s in self.all_scenes] + assert any(scene_id), 'no scene found' + + valid = np.in1d(self.pairs['scene_id'], np.nonzero(scene_id)[0]) + if instances: + image_id = [i.startswith(instances) for i in self.all_images] + image_id = np.nonzero(image_id)[0] + assert len(image_id), 'no instance found' + # both together? 
+ if len(instances) == 2: + valid &= np.in1d(self.pairs['im1_id'], image_id) & np.in1d(self.pairs['im2_id'], image_id) + else: + valid &= np.in1d(self.pairs['im1_id'], image_id) | np.in1d(self.pairs['im2_id'], image_id) + + if opposite: + valid = ~valid + assert valid.any() + self.pairs = self.pairs[valid] + + def _get_views(self, pair_idx, resolution, rng): + scene_id, im1_id, im2_id, score = self.pairs[pair_idx] + + scene, subscene = self.all_scenes[scene_id].split() + seq_path = osp.join(self.ROOT, scene, subscene) + + views = [] + + for im_id in [im1_id, im2_id]: + img = self.all_images[im_id] + try: + image = imread_cv2(osp.join(seq_path, img + '.jpg')) + depthmap = imread_cv2(osp.join(seq_path, img + ".exr")) + camera_params = np.load(osp.join(seq_path, img + ".npz")) + except Exception as e: + raise OSError(f'cannot load {img}, got exception {e}') + + intrinsics = np.float32(camera_params['intrinsics']) + camera_pose = np.float32(camera_params['cam2world']) + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, img)) + + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='MegaDepth', + label=osp.relpath(seq_path, self.ROOT), + instance=img)) + + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = MegaDepth(split='train', ROOT="data/megadepth_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/dust3r/dust3r/datasets/scannetpp.py b/imcui/third_party/dust3r/dust3r/datasets/scannetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..520deedd0eb8cba8663af941731d89e0b2e71a80 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/scannetpp.py @@ -0,0 +1,96 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
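The MegaDepth select_scene logic above keeps only the pairs whose scene_id belongs to the selected scenes (or drops them when opposite=True). A minimal sketch of the same np.in1d filtering on a hypothetical structured pairs array:

```python
import numpy as np

# hypothetical structured pairs array mirroring the fields used above
pairs = np.array([(0, 10, 11, 0.9), (1, 20, 21, 0.8), (2, 30, 31, 0.7)],
                 dtype=[('scene_id', 'i4'), ('im1_id', 'i4'), ('im2_id', 'i4'), ('score', 'f4')])
selected_scene_ids = np.array([0, 2])

valid = np.in1d(pairs['scene_id'], selected_scene_ids)
print(pairs[valid]['scene_id'])   # [0 2]  (what select_scene keeps)
print(pairs[~valid]['scene_id'])  # [1]    (what opposite=True keeps)
```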
+# +# -------------------------------------------------------- +# Dataloader for preprocessed scannet++ +# dataset at https://github.com/scannetpp/scannetpp - non-commercial research and educational purposes +# https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf +# See datasets_preprocess/preprocess_scannetpp.py +# -------------------------------------------------------- +import os.path as osp +import cv2 +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class ScanNetpp(BaseStereoViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert self.split == 'train' + self.loaded_data = self._load_data() + + def _load_data(self): + with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data: + self.scenes = data['scenes'] + self.sceneids = data['sceneids'] + self.images = data['images'] + self.intrinsics = data['intrinsics'].astype(np.float32) + self.trajectories = data['trajectories'].astype(np.float32) + self.pairs = data['pairs'][:, :2].astype(int) + + def __len__(self): + return len(self.pairs) + + def _get_views(self, idx, resolution, rng): + + image_idx1, image_idx2 = self.pairs[idx] + + views = [] + for view_idx in [image_idx1, image_idx2]: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(scene_dir, 'images', basename + '.jpg')) + # Load depthmap + depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename + '.png'), cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) + + views.append(dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset='ScanNet++', + label=self.scenes[scene_id] + '_' + basename, + instance=f'{str(idx)}_{str(view_idx)}', + )) + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = ScanNetpp(split='train', ROOT="data/scannetpp_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx*255, (1 - idx)*255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/dust3r/dust3r/datasets/staticthings3d.py b/imcui/third_party/dust3r/dust3r/datasets/staticthings3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e7f70f0ee7bf8c8ab6bb1702aa2481f3d16df413 --- /dev/null +++ 
b/imcui/third_party/dust3r/dust3r/datasets/staticthings3d.py @@ -0,0 +1,96 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed StaticThings3D +# dataset at https://github.com/lmb-freiburg/robustmvd/ +# See datasets_preprocess/preprocess_staticthings3d.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class StaticThings3D (BaseStereoViewDataset): + """ Dataset of indoor scenes, 5 images each time + """ + def __init__(self, ROOT, *args, mask_bg='rand', **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + + assert mask_bg in (True, False, 'rand') + self.mask_bg = mask_bg + + # loading all pairs + assert self.split is None + self.pairs = np.load(osp.join(ROOT, 'staticthings_pairs.npy')) + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs' + + def _get_views(self, pair_idx, resolution, rng): + scene, seq, cam1, im1, cam2, im2 = self.pairs[pair_idx] + seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}') + + views = [] + + mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2)) + + CAM = {b'l':'left', b'r':'right'} + for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]: + num = f"{idx:04n}" + img = num+"_clean.jpg" if rng.choice(2) else num+"_final.jpg" + image = imread_cv2(osp.join(self.ROOT, seq_path, cam, img)) + depthmap = imread_cv2(osp.join(self.ROOT, seq_path, cam, num+".exr")) + camera_params = np.load(osp.join(self.ROOT, seq_path, cam, num+".npz")) + + intrinsics = camera_params['intrinsics'] + camera_pose = camera_params['cam2world'] + + if mask_bg: + depthmap[depthmap > 200] = 0 + + image, depthmap, intrinsics = self._crop_resize_if_necessary(image, depthmap, intrinsics, resolution, rng, info=(seq_path,cam,img)) + + views.append(dict( + img = image, + depthmap = depthmap, + camera_pose = camera_pose, # cam2world + camera_intrinsics = intrinsics, + dataset = 'StaticThings3D', + label = seq_path, + instance = cam+'_'+img)) + + return views + + +if __name__ == '__main__': + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = StaticThings3D(ROOT="data/staticthings3d_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx*255, (1 - idx)*255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/dust3r/dust3r/datasets/utils/__init__.py b/imcui/third_party/dust3r/dust3r/datasets/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null 
+++ b/imcui/third_party/dust3r/dust3r/datasets/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/dust3r/dust3r/datasets/utils/cropping.py b/imcui/third_party/dust3r/dust3r/datasets/utils/cropping.py new file mode 100644 index 0000000000000000000000000000000000000000..07a331847cb8df997b3012790f5a96f69f21464d --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/utils/cropping.py @@ -0,0 +1,124 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# croppping utilities +# -------------------------------------------------------- +import PIL.Image +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa +import numpy as np # noqa +from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + + +class ImageList: + """ Convenience class to aply the same operation to a whole set of images. + """ + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + return len(self.images) + + def to_pil(self): + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes) + return sizes[0] + + def resize(self, *args, **kwargs): + return ImageList(self._dispatch('resize', *args, **kwargs)) + + def crop(self, *args, **kwargs): + return ImageList(self._dispatch('crop', *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True): + """ Jointly rescale a (image, depthmap) + so that (out_width, out_height) >= output_res + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W,H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + # can also use this with masks instead of depthmaps + assert tuple(depthmap.shape[:2]) == image.size[::-1] + + # define output resolution + assert output_resolution.shape == (2,) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + return (image.to_pil(), depthmap, camera_intrinsics) + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + # first rescale the image so that it contains the crop + image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic) + if depthmap is not None: + depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final, + fy=scale_final, interpolation=cv2.INTER_NEAREST) + + # no offset here; simple rescaling + camera_intrinsics = camera_matrix_of_crop( + camera_intrinsics, input_resolution, output_resolution, scaling=scale_final) + + return image.to_pil(), depthmap, camera_intrinsics + + +def camera_matrix_of_crop(input_camera_matrix, 
input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None): + # Margins to offset the origin + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + # Generate new camera parameters + output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): + """ + Return a crop of the input view. + """ + image = ImageList(image) + l, t, r, b = crop_bbox + + image = image.crop((l, t, r, b)) + depthmap = depthmap[t:b, l:r] + + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= l + camera_intrinsics[1, 2] -= t + + return image.to_pil(), depthmap, camera_intrinsics + + +def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution): + out_width, out_height = output_resolution + l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) + crop_bbox = (l, t, l + out_width, t + out_height) + return crop_bbox diff --git a/imcui/third_party/dust3r/dust3r/datasets/utils/transforms.py b/imcui/third_party/dust3r/dust3r/datasets/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..eb34f2f01d3f8f829ba71a7e03e181bf18f72c25 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/utils/transforms.py @@ -0,0 +1,11 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUST3R default transforms +# -------------------------------------------------------- +import torchvision.transforms as tvf +from dust3r.utils.image import ImgNorm + +# define the standard image transforms +ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) diff --git a/imcui/third_party/dust3r/dust3r/datasets/waymo.py b/imcui/third_party/dust3r/dust3r/datasets/waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a135152cd8973532405b491450c22942dcd6ca --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/waymo.py @@ -0,0 +1,93 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
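The cropping utilities above keep the pinhole intrinsics consistent with the image operations: rescaling multiplies the first two rows of K by the scale factor, and cropping subtracts the crop origin from the principal point. A simplified sketch of that bookkeeping on a plain OpenCV-style 3x3 matrix (it ignores the colmap/opencv half-pixel conversion that the helpers above go through):

```python
import numpy as np

def rescale_K(K: np.ndarray, scale: float) -> np.ndarray:
    K = K.copy()
    K[:2, :] *= scale        # fx, fy, cx, cy all scale with the image
    return K

def crop_K(K: np.ndarray, left: int, top: int) -> np.ndarray:
    K = K.copy()
    K[0, 2] -= left          # principal point shifts with the crop origin
    K[1, 2] -= top
    return K

K = np.array([[500., 0., 320.],
              [0., 500., 240.],
              [0., 0., 1.]])
print(crop_K(rescale_K(K, 0.5), left=32, top=16))
```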
+# +# -------------------------------------------------------- +# Dataloader for preprocessed WayMo +# dataset at https://github.com/waymo-research/waymo-open-dataset +# See datasets_preprocess/preprocess_waymo.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class Waymo (BaseStereoViewDataset): + """ Dataset of outdoor street scenes, 5 images each time + """ + + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self._load_data() + + def _load_data(self): + with np.load(osp.join(self.ROOT, 'waymo_pairs.npz')) as data: + self.scenes = data['scenes'] + self.frames = data['frames'] + self.inv_frames = {frame: i for i, frame in enumerate(data['frames'])} + self.pairs = data['pairs'] # (array of (scene_id, img1_id, img2_id) + assert self.pairs[:, 0].max() == len(self.scenes) - 1 + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs from {len(self.scenes)} scenes' + + def _get_views(self, pair_idx, resolution, rng): + seq, img1, img2 = self.pairs[pair_idx] + seq_path = osp.join(self.ROOT, self.scenes[seq]) + + views = [] + + for view_index in [img1, img2]: + impath = self.frames[view_index] + image = imread_cv2(osp.join(seq_path, impath + ".jpg")) + depthmap = imread_cv2(osp.join(seq_path, impath + ".exr")) + camera_params = np.load(osp.join(seq_path, impath + ".npz")) + + intrinsics = np.float32(camera_params['intrinsics']) + camera_pose = np.float32(camera_params['cam2world']) + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath)) + + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='Waymo', + label=osp.relpath(seq_path, self.ROOT), + instance=impath)) + + return views + + +if __name__ == '__main__': + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = Waymo(split='train', ROOT="data/megadepth_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/dust3r/dust3r/datasets/wildrgbd.py b/imcui/third_party/dust3r/dust3r/datasets/wildrgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..c41dd0b78402bf8ff1e62c6a50de338aa916e0af --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/datasets/wildrgbd.py @@ -0,0 +1,67 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
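For reference, the waymo_pairs.npz arrays loaded above are indexed as follows: each row of pairs is (scene_id, img1_id, img2_id), where scene_id points into scenes and the image ids point into frames. A toy sketch with made-up names:

```python
import numpy as np

scenes = np.array(['segment-0001', 'segment-0002'])               # hypothetical scene names
frames = np.array(['cam0_000000', 'cam0_000010', 'cam0_000020'])  # hypothetical frame names
pairs = np.array([[0, 0, 1], [1, 1, 2]])                          # (scene_id, img1_id, img2_id)

for scene_id, im1, im2 in pairs:
    print(scenes[scene_id], '->', frames[im1], frames[im2])
```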
+# +# -------------------------------------------------------- +# Dataloader for preprocessed WildRGB-D +# dataset at https://github.com/wildrgbd/wildrgbd/ +# See datasets_preprocess/preprocess_wildrgbd.py +# -------------------------------------------------------- +import os.path as osp + +import cv2 +import numpy as np + +from dust3r.datasets.co3d import Co3d +from dust3r.utils.image import imread_cv2 + + +class WildRGBD(Co3d): + def __init__(self, mask_bg=True, *args, ROOT, **kwargs): + super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs) + self.dataset_label = 'WildRGBD' + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'metadata', f'{view_idx:0>5d}.npz') + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'rgb', f'{view_idx:0>5d}.jpg') + + def _get_depthpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'depth', f'{view_idx:0>5d}.png') + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'masks', f'{view_idx:0>5d}.png') + + def _read_depthmap(self, depthpath, input_metadata): + # We store depths in the depth scale of 1000. + # That is, when we load depth image and divide by 1000, we could get depth in meters. + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000.0 + return depthmap + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = WildRGBD(split='train', ROOT="data/wildrgbd_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/dust3r/dust3r/demo.py b/imcui/third_party/dust3r/dust3r/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..c491be097b71ec38ea981dadf4f456d6e9829d48 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/demo.py @@ -0,0 +1,283 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
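The WildRGBD loader above stores depth as 16-bit PNGs in millimeters, and _read_depthmap divides by 1000 to get meters. A minimal sketch of that conversion with a hypothetical array standing in for the decoded PNG:

```python
import numpy as np

def depth_mm_to_m(depth_u16: np.ndarray) -> np.ndarray:
    """Convert a uint16 millimeter depth image to float32 meters; 0 stays 0 (invalid)."""
    return depth_u16.astype(np.float32) / 1000.0

raw = np.array([[1500, 0], [2500, 730]], dtype=np.uint16)  # toy depth values in mm
print(depth_mm_to_m(raw))  # [[1.5  0.  ] [2.5  0.73]]
```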
+# +# -------------------------------------------------------- +# gradio demo +# -------------------------------------------------------- +import argparse +import math +import builtins +import datetime +import gradio +import os +import torch +import numpy as np +import functools +import trimesh +import copy +from scipy.spatial.transform import Rotation + +from dust3r.inference import inference +from dust3r.image_pairs import make_pairs +from dust3r.utils.image import load_images, rgb +from dust3r.utils.device import to_numpy +from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode + +import matplotlib.pyplot as pl + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser_url = parser.add_mutually_exclusive_group() + parser_url.add_argument("--local_network", action='store_true', default=False, + help="make app accessible on local network: address will be set to 0.0.0.0") + parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1") + parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size") + parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). " + "If None, will search for an available port starting at 7860."), + default=None) + parser_weights = parser.add_mutually_exclusive_group(required=True) + parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None) + parser_weights.add_argument("--model_name", type=str, help="name of the model weights", + choices=["DUSt3R_ViTLarge_BaseDecoder_512_dpt", + "DUSt3R_ViTLarge_BaseDecoder_512_linear", + "DUSt3R_ViTLarge_BaseDecoder_224_linear"]) + parser.add_argument("--device", type=str, default='cuda', help="pytorch device") + parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir") + parser.add_argument("--silent", action='store_true', default=False, + help="silence logs") + return parser + + +def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"): + builtin_print = builtins.print + + def print_with_timestamp(*args, **kwargs): + now = datetime.datetime.now() + formatted_date_time = now.strftime(time_format) + + builtin_print(f'[{formatted_date_time}] ', end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print_with_timestamp + + +def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05, + cam_color=None, as_pointcloud=False, + transparent_cams=False, silent=False): + assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals) + pts3d = to_numpy(pts3d) + imgs = to_numpy(imgs) + focals = to_numpy(focals) + cams2world = to_numpy(cams2world) + + scene = trimesh.Scene() + + # full pointcloud + if as_pointcloud: + pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]) + col = np.concatenate([p[m] for p, m in zip(imgs, mask)]) + pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3)) + scene.add_geometry(pct) + else: + meshes = [] + for i in range(len(imgs)): + meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i])) + mesh = trimesh.Trimesh(**cat_meshes(meshes)) + scene.add_geometry(mesh) + + # add each camera + for i, pose_c2w in enumerate(cams2world): + if isinstance(cam_color, list): + camera_edge_color = cam_color[i] + else: + camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)] + add_scene_cam(scene, 
pose_c2w, camera_edge_color, + None if transparent_cams else imgs[i], focals[i], + imsize=imgs[i].shape[1::-1], screen_width=cam_size) + + rot = np.eye(4) + rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix() + scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot)) + outfile = os.path.join(outdir, 'scene.glb') + if not silent: + print('(exporting 3D scene to', outfile, ')') + scene.export(file_obj=outfile) + return outfile + + +def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False, + clean_depth=False, transparent_cams=False, cam_size=0.05): + """ + extract 3D_model (glb file) from a reconstructed scene + """ + if scene is None: + return None + # post processes + if clean_depth: + scene = scene.clean_pointcloud() + if mask_sky: + scene = scene.mask_sky() + + # get optimized values from scene + rgbimg = scene.imgs + focals = scene.get_focals().cpu() + cams2world = scene.get_im_poses().cpu() + # 3D pointcloud from depthmap, poses and intrinsics + pts3d = to_numpy(scene.get_pts3d()) + scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr))) + msk = to_numpy(scene.get_masks()) + return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud, + transparent_cams=transparent_cams, cam_size=cam_size, silent=silent) + + +def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr, + as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, + scenegraph_type, winsize, refid): + """ + from a list of images, run dust3r inference, global aligner. + then run get_3D_model_from_scene + """ + imgs = load_images(filelist, size=image_size, verbose=not silent) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + if scenegraph_type == "swin": + scenegraph_type = scenegraph_type + "-" + str(winsize) + elif scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True) + output = inference(pairs, model, device, batch_size=1, verbose=not silent) + + mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent) + lr = 0.01 + + if mode == GlobalAlignerMode.PointCloudOptimizer: + loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr) + + outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size) + + # also return rgb, depth and confidence imgs + # depth is normalized with the max value for all images + # we apply the jet colormap on the confidence maps + rgbimg = scene.imgs + depths = to_numpy(scene.get_depthmaps()) + confs = to_numpy([c for c in scene.im_conf]) + cmap = pl.get_cmap('jet') + depths_max = max([d.max() for d in depths]) + depths = [d / depths_max for d in depths] + confs_max = max([d.max() for d in confs]) + confs = [cmap(d / confs_max) for d in confs] + + imgs = [] + for i in range(len(rgbimg)): + imgs.append(rgbimg[i]) + imgs.append(rgb(depths[i])) + imgs.append(rgb(confs[i])) + + return scene, outfile, imgs + + +def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type): + num_files = len(inputfiles) if inputfiles is not None else 1 + max_winsize = max(1, math.ceil((num_files - 1) / 2)) + if scenegraph_type == "swin": + 
winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=True) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files - 1, step=1, visible=False) + elif scenegraph_type == "oneref": + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files - 1, step=1, visible=True) + else: + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files - 1, step=1, visible=False) + return winsize, refid + + +def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False): + recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size) + model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent) + with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="DUSt3R Demo") as demo: + # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference + scene = gradio.State(None) + gradio.HTML('
<h2 style="text-align: center;">DUSt3R Demo</h2>
') + with gradio.Column(): + inputfiles = gradio.File(file_count="multiple") + with gradio.Row(): + schedule = gradio.Dropdown(["linear", "cosine"], + value='linear', label="schedule", info="For global alignment!") + niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000, + label="num_iterations", info="For global alignment!") + scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"), + ("swin: sliding window", "swin"), + ("oneref: match one image with all", "oneref")], + value='complete', label="Scenegraph", + info="Define how to make pairs", + interactive=True) + winsize = gradio.Slider(label="Scene Graph: Window Size", value=1, + minimum=1, maximum=1, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False) + + run_btn = gradio.Button("Run") + + with gradio.Row(): + # adjust the confidence threshold + min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1) + # adjust the camera size in the output pointcloud + cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001) + with gradio.Row(): + as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud") + # two post process implemented + mask_sky = gradio.Checkbox(value=False, label="Mask sky") + clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps") + transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras") + + outmodel = gradio.Model3D() + outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%") + + # events + scenegraph_type.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + inputfiles.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + run_btn.click(fn=recon_fun, + inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud, + mask_sky, clean_depth, transparent_cams, cam_size, + scenegraph_type, winsize, refid], + outputs=[scene, outmodel, outgallery]) + min_conf_thr.release(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + cam_size.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + as_pointcloud.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + mask_sky.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + clean_depth.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + transparent_cams.change(model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + demo.launch(share=False, server_name=server_name, server_port=server_port) diff --git a/imcui/third_party/dust3r/dust3r/heads/__init__.py b/imcui/third_party/dust3r/dust3r/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53d0aa5610cae95f34f96bdb3ff9e835a2d6208e --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/heads/__init__.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver 
Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# head factory +# -------------------------------------------------------- +from .linear_head import LinearPts3d +from .dpt_head import create_dpt_head + + +def head_factory(head_type, output_mode, net, has_conf=False): + """" build a prediction head for the decoder + """ + if head_type == 'linear' and output_mode == 'pts3d': + return LinearPts3d(net, has_conf) + elif head_type == 'dpt' and output_mode == 'pts3d': + return create_dpt_head(net, has_conf=has_conf) + else: + raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") diff --git a/imcui/third_party/dust3r/dust3r/heads/dpt_head.py b/imcui/third_party/dust3r/dust3r/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b7bdc9ff587eef3ec8978a22f63659fbf3c277d6 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/heads/dpt_head.py @@ -0,0 +1,115 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# dpt head implementation for DUST3R +# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; +# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True +# the forward function also takes as input a dictionnary img_info with key "height" and "width" +# for PixelwiseTask, the output will be of dimension B x num_channels x H x W +# -------------------------------------------------------- +from einops import rearrange +from typing import List +import torch +import torch.nn as nn +from dust3r.heads.postprocess import postprocess +import dust3r.utils.path_to_croco # noqa: F401 +from models.dpt_block import DPTOutputAdapter # noqa + + +class DPTOutputAdapter_fix(DPTOutputAdapter): + """ + Adapt croco's DPTOutputAdapter implementation for dust3r: + remove duplicated weigths, and fix forward for dust3r + """ + + def init(self, dim_tokens_enc=768): + super().init(dim_tokens_enc) + # these are duplicated weights + del self.act_1_postprocess + del self.act_2_postprocess + del self.act_3_postprocess + del self.act_4_postprocess + + def forward(self, encoder_tokens: List[torch.Tensor], image_size=None): + assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' + # H, W = input_info['image_size'] + image_size = self.image_size if image_size is None else image_size + H, W = image_size + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. 
+ layers = [self.adapt_tokens(l) for l in layers] + + # Reshape tokens to spatial representation + layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]] + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Output head + out = self.head(path_1) + + return out + + +class PixelwiseTaskWithDPT(nn.Module): + """ DPT module for dust3r, can return 3D points + confidence for all pixels""" + + def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None, + output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_layers = True # backbone needs to return all layers + self.postprocess = postprocess + self.depth_mode = depth_mode + self.conf_mode = conf_mode + + assert n_cls_token == 0, "Not implemented" + dpt_args = dict(output_width_ratio=output_width_ratio, + num_channels=num_channels, + **kwargs) + if hooks_idx is not None: + dpt_args.update(hooks=hooks_idx) + self.dpt = DPTOutputAdapter_fix(**dpt_args) + dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens} + self.dpt.init(**dpt_init_args) + + def forward(self, x, img_info): + out = self.dpt(x, image_size=(img_info[0], img_info[1])) + if self.postprocess: + out = self.postprocess(out, self.depth_mode, self.conf_mode) + return out + + +def create_dpt_head(net, has_conf=False): + """ + return PixelwiseTaskWithDPT for given net params + """ + assert net.dec_depth > 9 + l2 = net.dec_depth + feature_dim = 256 + last_dim = feature_dim//2 + out_nchan = 3 + ed = net.enc_embed_dim + dd = net.dec_embed_dim + return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf, + feature_dim=feature_dim, + last_dim=last_dim, + hooks_idx=[0, l2*2//4, l2*3//4, l2], + dim_tokens=[ed, dd, dd, dd], + postprocess=postprocess, + depth_mode=net.depth_mode, + conf_mode=net.conf_mode, + head_type='regression') diff --git a/imcui/third_party/dust3r/dust3r/heads/linear_head.py b/imcui/third_party/dust3r/dust3r/heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6b697f29eaa6f43fad0a3e27a8d9b8f1a602a833 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/heads/linear_head.py @@ -0,0 +1,41 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
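create_dpt_head above taps the decoder at four depths, hooks_idx = [0, l2*2//4, l2*3//4, l2] with l2 = net.dec_depth, using the encoder embedding dim for the first hook and the decoder embedding dim for the others. A quick sketch of what those indices evaluate to for an assumed 12-block decoder:

```python
# assumed decoder depth (the helper above requires dec_depth > 9)
dec_depth = 12
l2 = dec_depth
hooks_idx = [0, l2 * 2 // 4, l2 * 3 // 4, l2]
print(hooks_idx)  # [0, 6, 9, 12]
```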
+# +# -------------------------------------------------------- +# linear head implementation for DUST3R +# -------------------------------------------------------- +import torch.nn as nn +import torch.nn.functional as F +from dust3r.heads.postprocess import postprocess + + +class LinearPts3d (nn.Module): + """ + Linear head for dust3r + Each token outputs: - 16x16 3D points (+ confidence) + """ + + def __init__(self, net, has_conf=False): + super().__init__() + self.patch_size = net.patch_embed.patch_size[0] + self.depth_mode = net.depth_mode + self.conf_mode = net.conf_mode + self.has_conf = has_conf + + self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2) + + def setup(self, croconet): + pass + + def forward(self, decout, img_shape): + H, W = img_shape + tokens = decout[-1] + B, S, D = tokens.shape + + # extract 3D points + feat = self.proj(tokens) # B,S,D + feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) + feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W + + # permute + norm depth + return postprocess(feat, self.depth_mode, self.conf_mode) diff --git a/imcui/third_party/dust3r/dust3r/heads/postprocess.py b/imcui/third_party/dust3r/dust3r/heads/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..cd68a90d89b8dcd7d8a4b4ea06ef8b17eb5da093 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/heads/postprocess.py @@ -0,0 +1,58 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# post process function for all heads: extract 3D points/confidence from output +# -------------------------------------------------------- +import torch + + +def postprocess(out, depth_mode, conf_mode): + """ + extract 3D points/confidence from prediction head output + """ + fmap = out.permute(0, 2, 3, 1) # B,H,W,3 + res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) + + if conf_mode is not None: + res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) + return res + + +def reg_dense_depth(xyz, mode): + """ + extract 3D points from prediction head output + """ + mode, vmin, vmax = mode + + no_bounds = (vmin == -float('inf')) and (vmax == float('inf')) + assert no_bounds + + if mode == 'linear': + if no_bounds: + return xyz # [-inf, +inf] + return xyz.clip(min=vmin, max=vmax) + + # distance to origin + d = xyz.norm(dim=-1, keepdim=True) + xyz = xyz / d.clip(min=1e-8) + + if mode == 'square': + return xyz * d.square() + + if mode == 'exp': + return xyz * torch.expm1(d) + + raise ValueError(f'bad {mode=}') + + +def reg_dense_conf(x, mode): + """ + extract confidence from prediction head output + """ + mode, vmin, vmax = mode + if mode == 'exp': + return vmin + x.exp().clip(max=vmax-vmin) + if mode == 'sigmoid': + return (vmax - vmin) * torch.sigmoid(x) + vmin + raise ValueError(f'bad {mode=}') diff --git a/imcui/third_party/dust3r/dust3r/image_pairs.py b/imcui/third_party/dust3r/dust3r/image_pairs.py new file mode 100644 index 0000000000000000000000000000000000000000..ebcf902b4d07b83fe83ffceba3f45ca0d74dfcf7 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/image_pairs.py @@ -0,0 +1,104 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
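In LinearPts3d above, each decoder token predicts a whole patch: the linear projection emits (3 + has_conf) * patch_size**2 values per token, which are reshaped to a coarse grid and expanded to full resolution with F.pixel_shuffle. A shape-only sketch under assumed sizes (patch size 16, 224x224 input, no confidence channel):

```python
import torch
import torch.nn.functional as F

B, P, H, W = 2, 16, 224, 224         # assumed batch size, patch size, image size
S = (H // P) * (W // P)              # number of tokens = 14 * 14 = 196
feat = torch.randn(B, S, 3 * P * P)  # per-token predictions from the linear projection

feat = feat.transpose(-1, -2).view(B, -1, H // P, W // P)  # (B, 3*P*P, 14, 14)
feat = F.pixel_shuffle(feat, P)                            # (B, 3, 224, 224)
print(feat.shape)                                          # torch.Size([2, 3, 224, 224])
```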
+# +# -------------------------------------------------------- +# utilities needed to load image pairs +# -------------------------------------------------------- +import numpy as np +import torch + + +def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True): + pairs = [] + if scene_graph == 'complete': # complete graph + for i in range(len(imgs)): + for j in range(i): + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith('swin'): + iscyclic = not scene_graph.endswith('noncyclic') + try: + winsize = int(scene_graph.split('-')[1]) + except Exception as e: + winsize = 3 + pairsid = set() + for i in range(len(imgs)): + for j in range(1, winsize + 1): + idx = (i + j) + if iscyclic: + idx = idx % len(imgs) # explicit loop closure + if idx >= len(imgs): + continue + pairsid.add((i, idx) if i < idx else (idx, i)) + for i, j in pairsid: + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith('logwin'): + iscyclic = not scene_graph.endswith('noncyclic') + try: + winsize = int(scene_graph.split('-')[1]) + except Exception as e: + winsize = 3 + offsets = [2**i for i in range(winsize)] + pairsid = set() + for i in range(len(imgs)): + ixs_l = [i - off for off in offsets] + ixs_r = [i + off for off in offsets] + for j in ixs_l + ixs_r: + if iscyclic: + j = j % len(imgs) # Explicit loop closure + if j < 0 or j >= len(imgs) or j == i: + continue + pairsid.add((i, j) if i < j else (j, i)) + for i, j in pairsid: + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith('oneref'): + refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0 + for j in range(len(imgs)): + if j != refid: + pairs.append((imgs[refid], imgs[j])) + if symmetrize: + pairs += [(img2, img1) for img1, img2 in pairs] + + # now, remove edges + if isinstance(prefilter, str) and prefilter.startswith('seq'): + pairs = filter_pairs_seq(pairs, int(prefilter[3:])) + + if isinstance(prefilter, str) and prefilter.startswith('cyc'): + pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True) + + return pairs + + +def sel(x, kept): + if isinstance(x, dict): + return {k: sel(v, kept) for k, v in x.items()} + if isinstance(x, (torch.Tensor, np.ndarray)): + return x[kept] + if isinstance(x, (tuple, list)): + return type(x)([x[k] for k in kept]) + + +def _filter_edges_seq(edges, seq_dis_thr, cyclic=False): + # number of images + n = max(max(e) for e in edges) + 1 + + kept = [] + for e, (i, j) in enumerate(edges): + dis = abs(i - j) + if cyclic: + dis = min(dis, abs(i + n - j), abs(i - n - j)) + if dis <= seq_dis_thr: + kept.append(e) + return kept + + +def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False): + edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs] + kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) + return [pairs[i] for i in kept] + + +def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False): + edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])] + kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) + print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges') + return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept) diff --git a/imcui/third_party/dust3r/dust3r/inference.py b/imcui/third_party/dust3r/dust3r/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..90540486b077add90ca50f62a5072e082cb2f2d7 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/inference.py @@ -0,0 +1,150 @@ +# Copyright (C) 2024-present Naver 
Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilities needed for the inference +# -------------------------------------------------------- +import tqdm +import torch +from dust3r.utils.device import to_cpu, collate_with_cat +from dust3r.utils.misc import invalid_to_nans +from dust3r.utils.geometry import depthmap_to_pts3d, geotrf + + +def _interleave_imgs(img1, img2): + res = {} + for key, value1 in img1.items(): + value2 = img2[key] + if isinstance(value1, torch.Tensor): + value = torch.stack((value1, value2), dim=1).flatten(0, 1) + else: + value = [x for pair in zip(value1, value2) for x in pair] + res[key] = value + return res + + +def make_batch_symmetric(batch): + view1, view2 = batch + view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1)) + return view1, view2 + + +def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None): + view1, view2 = batch + ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng']) + for view in batch: + for name in view.keys(): # pseudo_focal + if name in ignore_keys: + continue + view[name] = view[name].to(device, non_blocking=True) + + if symmetrize_batch: + view1, view2 = make_batch_symmetric(batch) + + with torch.cuda.amp.autocast(enabled=bool(use_amp)): + pred1, pred2 = model(view1, view2) + + # loss is supposed to be symmetric + with torch.cuda.amp.autocast(enabled=False): + loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None + + result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss) + return result[ret] if ret else result + + +@torch.no_grad() +def inference(pairs, model, device, batch_size=8, verbose=True): + if verbose: + print(f'>> Inference with model on {len(pairs)} image pairs') + result = [] + + # first, check if all images have the same size + multiple_shapes = not (check_if_same_size(pairs)) + if multiple_shapes: # force bs=1 + batch_size = 1 + + for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose): + res = loss_of_one_batch(collate_with_cat(pairs[i:i + batch_size]), model, None, device) + result.append(to_cpu(res)) + + result = collate_with_cat(result, lists=multiple_shapes) + + return result + + +def check_if_same_size(pairs): + shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs] + shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs] + return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2) + + +def get_pred_pts3d(gt, pred, use_pose=False): + if 'depth' in pred and 'pseudo_focal' in pred: + try: + pp = gt['camera_intrinsics'][..., :2, 2] + except KeyError: + pp = None + pts3d = depthmap_to_pts3d(**pred, pp=pp) + + elif 'pts3d' in pred: + # pts3d from my camera + pts3d = pred['pts3d'] + + elif 'pts3d_in_other_view' in pred: + # pts3d from the other camera, already transformed + assert use_pose is True + return pred['pts3d_in_other_view'] # return! 
+ + if use_pose: + camera_pose = pred.get('camera_pose') + assert camera_pose is not None + pts3d = geotrf(camera_pose, pts3d) + + return pts3d + + +def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None): + assert gt_pts1.ndim == pr_pts1.ndim == 4 + assert gt_pts1.shape == pr_pts1.shape + if gt_pts2 is not None: + assert gt_pts2.ndim == pr_pts2.ndim == 4 + assert gt_pts2.shape == pr_pts2.shape + + # concat the pointcloud + nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2) + nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None + + pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2) + pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None + + all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1 + all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1 + + dot_gt_pr = (all_pr * all_gt).sum(dim=-1) + dot_gt_gt = all_gt.square().sum(dim=-1) + + if fit_mode.startswith('avg'): + # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1) + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + elif fit_mode.startswith('median'): + scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values + elif fit_mode.startswith('weiszfeld'): + # init scaling with l2 closed form + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip_(min=1e-8).reciprocal() + # update the scaling with the new weights + scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1) + else: + raise ValueError(f'bad {fit_mode=}') + + if fit_mode.endswith('stop_grad'): + scaling = scaling.detach() + + scaling = scaling.clip(min=1e-3) + # assert scaling.isfinite().all(), bb() + return scaling diff --git a/imcui/third_party/dust3r/dust3r/losses.py b/imcui/third_party/dust3r/dust3r/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8febff1a2dd674e759bcf83d023099a59cc934 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/losses.py @@ -0,0 +1,299 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
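find_opt_scaling above estimates a scale aligning ground-truth and predicted point clouds; the 'avg' mode is the closed-form least-squares solution s = <gt, pr> / <gt, gt>, which the Weiszfeld variant then re-weights by inverse distance. A minimal numpy sketch of the closed-form step on hypothetical points:

```python
import numpy as np

# hypothetical ground-truth and predicted 3D points (N, 3), already masked to valid pixels
gt = np.random.rand(100, 3)
true_scale = 2.5
pr = true_scale * gt + 0.01 * np.random.randn(100, 3)

# closed-form least-squares scale such that pr ~ s * gt (the 'avg' branch above)
s = (gt * pr).sum() / (gt * gt).sum()
print(s)  # close to 2.5
```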
+# +# -------------------------------------------------------- +# Implementation of DUSt3R training losses +# -------------------------------------------------------- +from copy import copy, deepcopy +import torch +import torch.nn as nn + +from dust3r.inference import get_pred_pts3d, find_opt_scaling +from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud +from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale + + +def Sum(*losses_and_masks): + loss, mask = losses_and_masks[0] + if loss.ndim > 0: + # we are actually returning the loss for every pixels + return losses_and_masks + else: + # we are returning the global loss + for loss2, mask2 in losses_and_masks[1:]: + loss = loss + loss2 + return loss + + +class BaseCriterion(nn.Module): + def __init__(self, reduction='mean'): + super().__init__() + self.reduction = reduction + + +class LLoss (BaseCriterion): + """ L-norm loss + """ + + def forward(self, a, b): + assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}' + dist = self.distance(a, b) + assert dist.ndim == a.ndim - 1 # one dimension less + if self.reduction == 'none': + return dist + if self.reduction == 'sum': + return dist.sum() + if self.reduction == 'mean': + return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) + raise ValueError(f'bad {self.reduction=} mode') + + def distance(self, a, b): + raise NotImplementedError() + + +class L21Loss (LLoss): + """ Euclidean distance between 3d points """ + + def distance(self, a, b): + return torch.norm(a - b, dim=-1) # normalized L2 distance + + +L21 = L21Loss() + + +class Criterion (nn.Module): + def __init__(self, criterion=None): + super().__init__() + assert isinstance(criterion, BaseCriterion), f'{criterion} is not a proper criterion!' 
+ self.criterion = copy(criterion) + + def get_name(self): + return f'{type(self).__name__}({self.criterion})' + + def with_reduction(self, mode='none'): + res = loss = deepcopy(self) + while loss is not None: + assert isinstance(loss, Criterion) + loss.criterion.reduction = mode # make it return the loss for each sample + loss = loss._loss2 # we assume loss is a Multiloss + return res + + +class MultiLoss (nn.Module): + """ Easily combinable losses (also keep track of individual loss values): + loss = MyLoss1() + 0.1*MyLoss2() + Usage: + Inherit from this class and override get_name() and compute_loss() + """ + + def __init__(self): + super().__init__() + self._alpha = 1 + self._loss2 = None + + def compute_loss(self, *args, **kwargs): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + def __mul__(self, alpha): + assert isinstance(alpha, (int, float)) + res = copy(self) + res._alpha = alpha + return res + __rmul__ = __mul__ # same + + def __add__(self, loss2): + assert isinstance(loss2, MultiLoss) + res = cur = copy(self) + # find the end of the chain + while cur._loss2 is not None: + cur = cur._loss2 + cur._loss2 = loss2 + return res + + def __repr__(self): + name = self.get_name() + if self._alpha != 1: + name = f'{self._alpha:g}*{name}' + if self._loss2: + name = f'{name} + {self._loss2}' + return name + + def forward(self, *args, **kwargs): + loss = self.compute_loss(*args, **kwargs) + if isinstance(loss, tuple): + loss, details = loss + elif loss.ndim == 0: + details = {self.get_name(): float(loss)} + else: + details = {} + loss = loss * self._alpha + + if self._loss2: + loss2, details2 = self._loss2(*args, **kwargs) + loss = loss + loss2 + details |= details2 + + return loss, details + + +class Regr3D (Criterion, MultiLoss): + """ Ensure that all 3D points are correct. + Asymmetric loss: view1 is supposed to be the anchor. + + P1 = RT1 @ D1 + P2 = RT2 @ D2 + loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1) + loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2) + = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2) + """ + + def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False): + super().__init__(criterion) + self.norm_mode = norm_mode + self.gt_scale = gt_scale + + def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None): + # everything is normalized w.r.t. 
camera of view1 + in_camera1 = inv(gt1['camera_pose']) + gt_pts1 = geotrf(in_camera1, gt1['pts3d']) # B,H,W,3 + gt_pts2 = geotrf(in_camera1, gt2['pts3d']) # B,H,W,3 + + valid1 = gt1['valid_mask'].clone() + valid2 = gt2['valid_mask'].clone() + + if dist_clip is not None: + # points that are too far-away == invalid + dis1 = gt_pts1.norm(dim=-1) # (B, H, W) + dis2 = gt_pts2.norm(dim=-1) # (B, H, W) + valid1 = valid1 & (dis1 <= dist_clip) + valid2 = valid2 & (dis2 <= dist_clip) + + pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False) + pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True) + + # normalize 3d points + if self.norm_mode: + pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2) + if self.norm_mode and not self.gt_scale: + gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2) + + return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, {} + + def compute_loss(self, gt1, gt2, pred1, pred2, **kw): + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \ + self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw) + # loss on img1 side + l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1]) + # loss on gt2 side + l2 = self.criterion(pred_pts2[mask2], gt_pts2[mask2]) + self_name = type(self).__name__ + details = {self_name + '_pts3d_1': float(l1.mean()), self_name + '_pts3d_2': float(l2.mean())} + return Sum((l1, mask1), (l2, mask2)), (details | monitoring) + + +class ConfLoss (MultiLoss): + """ Weighted regression by learned confidence. + Assuming the input pixel_loss is a pixel-level regression loss. + + Principle: + high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10) + low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10) + + alpha: hyperparameter + """ + + def __init__(self, pixel_loss, alpha=1): + super().__init__() + assert alpha > 0 + self.alpha = alpha + self.pixel_loss = pixel_loss.with_reduction('none') + + def get_name(self): + return f'ConfLoss({self.pixel_loss})' + + def get_conf_log(self, x): + return x, torch.log(x) + + def compute_loss(self, gt1, gt2, pred1, pred2, **kw): + # compute per-pixel loss + ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw) + if loss1.numel() == 0: + print('NO VALID POINTS in img1', force=True) + if loss2.numel() == 0: + print('NO VALID POINTS in img2', force=True) + + # weight by confidence + conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1]) + conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2]) + conf_loss1 = loss1 * conf1 - self.alpha * log_conf1 + conf_loss2 = loss2 * conf2 - self.alpha * log_conf2 + + # average + nan protection (in case of no valid pixels at all) + conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0 + conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0 + + return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details) + + +class Regr3D_ShiftInv (Regr3D): + """ Same than Regr3D but invariant to depth shift. 
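+    Concretely, the median depth over both views is computed separately for the
+    ground truth and for the prediction (see get_all_pts3d below) and subtracted
+    from the corresponding point clouds, so a constant z-offset in the prediction
+    does not change the loss.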
+ """ + + def get_all_pts3d(self, gt1, gt2, pred1, pred2): + # compute unnormalized points + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \ + super().get_all_pts3d(gt1, gt2, pred1, pred2) + + # compute median depth + gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2] + pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2] + gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None] + pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None] + + # subtract the median depth + gt_z1 -= gt_shift_z + gt_z2 -= gt_shift_z + pred_z1 -= pred_shift_z + pred_z2 -= pred_shift_z + + # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach()) + return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring + + +class Regr3D_ScaleInv (Regr3D): + """ Same than Regr3D but invariant to depth shift. + if gt_scale == True: enforce the prediction to take the same scale than GT + """ + + def get_all_pts3d(self, gt1, gt2, pred1, pred2): + # compute depth-normalized points + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = super().get_all_pts3d(gt1, gt2, pred1, pred2) + + # measure scene scale + _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2) + _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2) + + # prevent predictions to be in a ridiculous range + pred_scale = pred_scale.clip(min=1e-3, max=1e3) + + # subtract the median depth + if self.gt_scale: + pred_pts1 *= gt_scale / pred_scale + pred_pts2 *= gt_scale / pred_scale + # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean()) + else: + gt_pts1 /= gt_scale + gt_pts2 /= gt_scale + pred_pts1 /= pred_scale + pred_pts2 /= pred_scale + # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()) + + return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring + + +class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv): + # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv + pass diff --git a/imcui/third_party/dust3r/dust3r/model.py b/imcui/third_party/dust3r/dust3r/model.py new file mode 100644 index 0000000000000000000000000000000000000000..41c3a4f78eb5fbafdeb7ab8523468de320886c64 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/model.py @@ -0,0 +1,210 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUSt3R model class +# -------------------------------------------------------- +from copy import deepcopy +import torch +import os +from packaging import version +import huggingface_hub + +from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, interleave, transpose_to_landscape +from .heads import head_factory +from dust3r.patch_embed import get_patch_embed + +import dust3r.utils.path_to_croco # noqa: F401 +from models.croco import CroCoNet # noqa + +inf = float('inf') + +hf_version_number = huggingface_hub.__version__ +assert version.parse(hf_version_number) >= version.parse("0.22.0"), ("Outdated huggingface_hub version, " + "please reinstall requirements.txt") + + +def load_model(model_path, device, verbose=True): + if verbose: + print('... 
loading model from', model_path) + ckpt = torch.load(model_path, map_location='cpu') + args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") + if 'landscape_only' not in args: + args = args[:-1] + ', landscape_only=False)' + else: + args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False') + assert "landscape_only=False" in args + if verbose: + print(f"instantiating : {args}") + net = eval(args) + s = net.load_state_dict(ckpt['model'], strict=False) + if verbose: + print(s) + return net.to(device) + + +class AsymmetricCroCo3DStereo ( + CroCoNet, + huggingface_hub.PyTorchModelHubMixin, + library_name="dust3r", + repo_url="https://github.com/naver/dust3r", + tags=["image-to-3d"], +): + """ Two siamese encoders, followed by two decoders. + The goal is to output 3d points directly, both images in view1's frame + (hence the asymmetry). + """ + + def __init__(self, + output_mode='pts3d', + head_type='linear', + depth_mode=('exp', -inf, inf), + conf_mode=('exp', 1, inf), + freeze='none', + landscape_only=True, + patch_embed_cls='PatchEmbedDust3R', # PatchEmbedDust3R or ManyAR_PatchEmbed + **croco_kwargs): + self.patch_embed_cls = patch_embed_cls + self.croco_args = fill_default_args(croco_kwargs, super().__init__) + super().__init__(**croco_kwargs) + + # dust3r specific initialization + self.dec_blocks2 = deepcopy(self.dec_blocks) + self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs) + self.set_freeze(freeze) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kw): + if os.path.isfile(pretrained_model_name_or_path): + return load_model(pretrained_model_name_or_path, device='cpu') + else: + try: + model = super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw) + except TypeError as e: + raise Exception(f'tried to load {pretrained_model_name_or_path} from huggingface, but failed') + return model + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim) + + def load_state_dict(self, ckpt, **kw): + # duplicate all weights for the second decoder if not present + new_ckpt = dict(ckpt) + if not any(k.startswith('dec_blocks2') for k in ckpt): + for key, value in ckpt.items(): + if key.startswith('dec_blocks'): + new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value + return super().load_state_dict(new_ckpt, **kw) + + def set_freeze(self, freeze): # this is for use by downstream models + self.freeze = freeze + to_be_frozen = { + 'none': [], + 'mask': [self.mask_token], + 'encoder': [self.mask_token, self.patch_embed, self.enc_blocks], + } + freeze_all_params(to_be_frozen[freeze]) + + def _set_prediction_head(self, *args, **kwargs): + """ No prediction head """ + return + + def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, + **kw): + assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \ + f'{img_size=} must be multiple of {patch_size=}' + self.output_mode = output_mode + self.head_type = head_type + self.depth_mode = depth_mode + self.conf_mode = conf_mode + # allocate heads + self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) + self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) + # magic wrapper + self.head1 = transpose_to_landscape(self.downstream_head1, 
activate=landscape_only) + self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only) + + def _encode_image(self, image, true_shape): + # embed the image into patches (x has size B x Npatches x C) + x, pos = self.patch_embed(image, true_shape=true_shape) + + # add positional embedding without cls token + assert self.enc_pos_embed is None + + # now apply the transformer encoder and normalization + for blk in self.enc_blocks: + x = blk(x, pos) + + x = self.enc_norm(x) + return x, pos, None + + def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2): + if img1.shape[-2:] == img2.shape[-2:]: + out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0), + torch.cat((true_shape1, true_shape2), dim=0)) + out, out2 = out.chunk(2, dim=0) + pos, pos2 = pos.chunk(2, dim=0) + else: + out, pos, _ = self._encode_image(img1, true_shape1) + out2, pos2, _ = self._encode_image(img2, true_shape2) + return out, out2, pos, pos2 + + def _encode_symmetrized(self, view1, view2): + img1 = view1['img'] + img2 = view2['img'] + B = img1.shape[0] + # Recover true_shape when available, otherwise assume that the img shape is the true one + shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1)) + shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1)) + # warning! maybe the images have different portrait/landscape orientations + + if is_symmetrized(view1, view2): + # computing half of forward pass!' + feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2]) + feat1, feat2 = interleave(feat1, feat2) + pos1, pos2 = interleave(pos1, pos2) + else: + feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2) + + return (shape1, shape2), (feat1, feat2), (pos1, pos2) + + def _decoder(self, f1, pos1, f2, pos2): + final_output = [(f1, f2)] # before projection + + # project to decoder dim + f1 = self.decoder_embed(f1) + f2 = self.decoder_embed(f2) + + final_output.append((f1, f2)) + for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2): + # img1 side + f1, _ = blk1(*final_output[-1][::+1], pos1, pos2) + # img2 side + f2, _ = blk2(*final_output[-1][::-1], pos2, pos1) + # store the result + final_output.append((f1, f2)) + + # normalize last output + del final_output[1] # duplicate with final_output[0] + final_output[-1] = tuple(map(self.dec_norm, final_output[-1])) + return zip(*final_output) + + def _downstream_head(self, head_num, decout, img_shape): + B, S, D = decout[-1].shape + # img_shape = tuple(map(int, img_shape)) + head = getattr(self, f'head{head_num}') + return head(decout, img_shape) + + def forward(self, view1, view2): + # encode the two images --> B,S,D + (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2) + + # combine all ref images into object-centric representation + dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2) + + with torch.cuda.amp.autocast(enabled=False): + res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1) + res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2) + + res2['pts3d_in_other_view'] = res2.pop('pts3d') # predict view2's pts3d in view1's frame + return res1, res2 diff --git a/imcui/third_party/dust3r/dust3r/optim_factory.py b/imcui/third_party/dust3r/dust3r/optim_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9c16e0e0fda3fd03c3def61abc1f354f75c584 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/optim_factory.py @@ -0,0 
+1,14 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# optimization functions +# -------------------------------------------------------- + + +def adjust_learning_rate_by_lr(optimizer, lr): + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr diff --git a/imcui/third_party/dust3r/dust3r/patch_embed.py b/imcui/third_party/dust3r/dust3r/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..07bb184bccb9d16657581576779904065d2dc857 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/patch_embed.py @@ -0,0 +1,70 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# PatchEmbed implementation for DUST3R, +# in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio +# -------------------------------------------------------- +import torch +import dust3r.utils.path_to_croco # noqa: F401 +from models.blocks import PatchEmbed # noqa + + +def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim): + assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed'] + patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim) + return patch_embed + + +class PatchEmbedDust3R(PatchEmbed): + def forward(self, x, **kw): + B, C, H, W = x.shape + assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + +class ManyAR_PatchEmbed (PatchEmbed): + """ Handle images with non-square aspect ratio. + All images in the same batch have the same aspect ratio. + true_shape = [(height, width) ...] indicates the actual shape of each image. + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + self.embed_dim = embed_dim + super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten) + + def forward(self, img, true_shape): + B, C, H, W = img.shape + assert W >= H, f'img should be in landscape mode, but got {W=} {H=}' + assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 
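+        # Below (illustrative summary): the batch is stored in landscape
+        # orientation (W >= H, asserted above). Images whose true_shape is
+        # portrait (width < height) are transposed back to portrait before the
+        # conv projection and receive a swapped (W, H) positional grid; the
+        # token count n_tokens = H * W is the same either way.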
+ assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}" + + # size expressed in tokens + W //= self.patch_size[0] + H //= self.patch_size[1] + n_tokens = H * W + + height, width = true_shape.T + is_landscape = (width >= height) + is_portrait = ~is_landscape + + # allocate result + x = img.new_zeros((B, n_tokens, self.embed_dim)) + pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64) + + # linear projection, transposed if necessary + x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float() + x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float() + + pos[is_landscape] = self.position_getter(1, H, W, pos.device) + pos[is_portrait] = self.position_getter(1, W, H, pos.device) + + x = self.norm(x) + return x, pos diff --git a/imcui/third_party/dust3r/dust3r/post_process.py b/imcui/third_party/dust3r/dust3r/post_process.py new file mode 100644 index 0000000000000000000000000000000000000000..550a9b41025ad003228ef16f97d045fc238746e4 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/post_process.py @@ -0,0 +1,60 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilities for interpreting the DUST3R output +# -------------------------------------------------------- +import numpy as np +import torch +from dust3r.utils.geometry import xy_grid + + +def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf): + """ Reprojection method, for when the absolute depth is known: + 1) estimate the camera focal using a robust estimator + 2) reproject points onto true rays, minimizing a certain error + """ + B, H, W, THREE = pts3d.shape + assert THREE == 3 + + # centered pixel grid + pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2 + pts3d = pts3d.flatten(1, 2) # (B, HW, 3) + + if focal_mode == 'median': + with torch.no_grad(): + # direct estimation of focal + u, v = pixels.unbind(dim=-1) + x, y, z = pts3d.unbind(dim=-1) + fx_votes = (u * z) / x + fy_votes = (v * z) / y + + # assume square pixels, hence same focal for X and Y + f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1) + focal = torch.nanmedian(f_votes, dim=-1).values + + elif focal_mode == 'weiszfeld': + # init focal with l2 closed form + # we try to find focal = argmin Sum | pixel - focal * (x,y)/z| + xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1) + + dot_xy_px = (xy_over_z * pixels).sum(dim=-1) + dot_xy_xy = xy_over_z.square().sum(dim=-1) + + focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1) + + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip(min=1e-8).reciprocal() + # update the scaling with the new weights + focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1) + else: + raise ValueError(f'bad {focal_mode=}') + + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515 + focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base) + # print(focal) + return focal diff --git a/imcui/third_party/dust3r/dust3r/training.py b/imcui/third_party/dust3r/dust3r/training.py new file mode 100644 index 
0000000000000000000000000000000000000000..53af9764ebb03a0083c22294298ed674e9164edc --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/training.py @@ -0,0 +1,377 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# training code for DUSt3R +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import argparse +import datetime +import json +import numpy as np +import os +import sys +import time +import math +from collections import defaultdict +from pathlib import Path +from typing import Sized + +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + +from dust3r.model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model +from dust3r.datasets import get_data_loader # noqa +from dust3r.losses import * # noqa: F401, needed when loading the model +from dust3r.inference import loss_of_one_batch # noqa + +import dust3r.utils.path_to_croco # noqa: F401 +import croco.utils.misc as misc # noqa +from croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler # noqa + + +def get_args_parser(): + parser = argparse.ArgumentParser('DUST3R training', add_help=False) + # model and criterion + parser.add_argument('--model', default="AsymmetricCroCo3DStereo(patch_embed_cls='ManyAR_PatchEmbed')", + type=str, help="string containing the model to build") + parser.add_argument('--pretrained', default=None, help='path of a starting checkpoint') + parser.add_argument('--train_criterion', default="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)", + type=str, help="train criterion") + parser.add_argument('--test_criterion', default=None, type=str, help="test criterion") + + # dataset + parser.add_argument('--train_dataset', required=True, type=str, help="training set") + parser.add_argument('--test_dataset', default='[None]', type=str, help="testing set") + + # training + parser.add_argument('--seed', default=0, type=int, help="Random seed") + parser.add_argument('--batch_size', default=64, type=int, + help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus") + parser.add_argument('--accum_iter', default=1, type=int, + help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)") + parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler") + + parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)") + parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') + parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR', + help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--min_lr', type=float, default=0., metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0') + parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR') + + parser.add_argument('--amp', type=int, default=0, + choices=[0, 1], help="Use Automatic Mixed Precision for 
pretraining") + parser.add_argument("--disable_cudnn_benchmark", action='store_true', default=False, + help="set cudnn.benchmark = False") + # others + parser.add_argument('--num_workers', default=8, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + + parser.add_argument('--eval_freq', type=int, default=1, help='Test loss evaluation frequency') + parser.add_argument('--save_freq', default=1, type=int, + help='frequence (number of epochs) to save checkpoint in checkpoint-last.pth') + parser.add_argument('--keep_freq', default=20, type=int, + help='frequence (number of epochs) to save checkpoint in checkpoint-%d.pth') + parser.add_argument('--print_freq', default=20, type=int, + help='frequence (number of iterations) to print infos while training') + + # output dir + parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output") + return parser + + +def train(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + world_size = misc.get_world_size() + + print("output_dir: " + args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + # auto resume + last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth') + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # fix the seed + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = not args.disable_cudnn_benchmark + + # training dataset and loader + print('Building train dataset {:s}'.format(args.train_dataset)) + # dataset and loader + data_loader_train = build_dataset(args.train_dataset, args.batch_size, args.num_workers, test=False) + print('Building test dataset {:s}'.format(args.train_dataset)) + data_loader_test = {dataset.split('(')[0]: build_dataset(dataset, args.batch_size, args.num_workers, test=True) + for dataset in args.test_dataset.split('+')} + + # model + print('Loading model: {:s}'.format(args.model)) + model = eval(args.model) + print(f'>> Creating train criterion = {args.train_criterion}') + train_criterion = eval(args.train_criterion).to(device) + print(f'>> Creating test criterion = {args.test_criterion or args.train_criterion}') + test_criterion = eval(args.test_criterion or args.criterion).to(device) + + model.to(device) + model_without_ddp = model + print("Model = %s" % str(model_without_ddp)) + + if args.pretrained and not args.resume: + print('Loading pretrained: ', args.pretrained) + ckpt = torch.load(args.pretrained, map_location=device) + print(model.load_state_dict(ckpt['model'], strict=False)) + del ckpt # in case it occupies memory + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + 
model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True) + model_without_ddp = model.module + + # following timm: set wd as 0 for bias and norm layers + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) + print(optimizer) + loss_scaler = NativeScaler() + + def write_log_stats(epoch, train_stats, test_stats): + if misc.is_main_process(): + if log_writer is not None: + log_writer.flush() + + log_stats = dict(epoch=epoch, **{f'train_{k}': v for k, v in train_stats.items()}) + for test_name in data_loader_test: + if test_name not in test_stats: + continue + log_stats.update({test_name + '_' + k: v for k, v in test_stats[test_name].items()}) + + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + def save_model(epoch, fname, best_so_far): + misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, fname=fname, best_so_far=best_so_far) + + best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp, + optimizer=optimizer, loss_scaler=loss_scaler) + if best_so_far is None: + best_so_far = float('inf') + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + train_stats = test_stats = {} + for epoch in range(args.start_epoch, args.epochs + 1): + + # Save immediately the last checkpoint + if epoch > args.start_epoch: + if args.save_freq and epoch % args.save_freq == 0 or epoch == args.epochs: + save_model(epoch - 1, 'last', best_so_far) + + # Test on multiple datasets + new_best = False + if (epoch > 0 and args.eval_freq > 0 and epoch % args.eval_freq == 0): + test_stats = {} + for test_name, testset in data_loader_test.items(): + stats = test_one_epoch(model, test_criterion, testset, + device, epoch, log_writer=log_writer, args=args, prefix=test_name) + test_stats[test_name] = stats + + # Save best of all + if stats['loss_med'] < best_so_far: + best_so_far = stats['loss_med'] + new_best = True + + # Save more stuff + write_log_stats(epoch, train_stats, test_stats) + + if epoch > args.start_epoch: + if args.keep_freq and epoch % args.keep_freq == 0: + save_model(epoch - 1, str(epoch), best_so_far) + if new_best: + save_model(epoch - 1, 'best', best_so_far) + if epoch >= args.epochs: + break # exit after writing last test to disk + + # Train + train_stats = train_one_epoch( + model, train_criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + log_writer=log_writer, + args=args) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + save_final_model(args, args.epochs, model_without_ddp, best_so_far=best_so_far) + + +def save_final_model(args, epoch, model_without_ddp, best_so_far=None): + output_dir = Path(args.output_dir) + checkpoint_path = output_dir / 'checkpoint-final.pth' + to_save = { + 'args': args, + 'model': model_without_ddp if isinstance(model_without_ddp, dict) else model_without_ddp.cpu().state_dict(), + 'epoch': epoch + } + if best_so_far is not None: + to_save['best_so_far'] = best_so_far + print(f'>> Saving model to {checkpoint_path} ...') + misc.save_on_master(to_save, checkpoint_path) + + +def 
build_dataset(dataset, batch_size, num_workers, test=False): + split = ['Train', 'Test'][test] + print(f'Building {split} Data loader for dataset: ', dataset) + loader = get_data_loader(dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_mem=True, + shuffle=not (test), + drop_last=not (test)) + + print(f"{split} dataset length: ", len(loader)) + return loader + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Sized, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + args, + log_writer=None): + assert torch.backends.cuda.matmul.allow_tf32 == True + + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + accum_iter = args.accum_iter + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'): + data_loader.dataset.set_epoch(epoch) + if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'): + data_loader.sampler.set_epoch(epoch) + + optimizer.zero_grad() + + for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + epoch_f = epoch + data_iter_step / len(data_loader) + + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, epoch_f, args) + + loss_tuple = loss_of_one_batch(batch, model, criterion, device, + symmetrize_batch=True, + use_amp=bool(args.amp), ret='loss') + loss, loss_details = loss_tuple # criterion returns two values + loss_value = float(loss) + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value), force=True) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + del loss + del batch + + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(epoch=epoch_f) + metric_logger.update(lr=lr) + metric_logger.update(loss=loss_value, **loss_details) + + if (data_iter_step + 1) % accum_iter == 0 and ((data_iter_step + 1) % (accum_iter * args.print_freq)) == 0: + loss_value_reduce = misc.all_reduce_mean(loss_value) # MUST BE EXECUTED BY ALL NODES + if log_writer is None: + continue + """ We use epoch_1000x as the x-axis in tensorboard. + This calibrates different curves when batch size changes. 
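+            For example, a fractional epoch epoch_f = 3.25 is logged at
+            x = int(3.25 * 1000) = 3250, regardless of how many iterations
+            one epoch actually contains.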
+ """ + epoch_1000x = int(epoch_f * 1000) + log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('train_lr', lr, epoch_1000x) + log_writer.add_scalar('train_iter', epoch_1000x, epoch_1000x) + for name, val in loss_details.items(): + log_writer.add_scalar('train_' + name, val, epoch_1000x) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Sized, device: torch.device, epoch: int, + args, log_writer=None, prefix='test'): + + model.eval() + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9)) + header = 'Test Epoch: [{}]'.format(epoch) + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'): + data_loader.dataset.set_epoch(epoch) + if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'): + data_loader.sampler.set_epoch(epoch) + + for _, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + loss_tuple = loss_of_one_batch(batch, model, criterion, device, + symmetrize_batch=True, + use_amp=bool(args.amp), ret='loss') + loss_value, loss_details = loss_tuple # criterion returns two values + metric_logger.update(loss=float(loss_value), **loss_details) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + + aggs = [('avg', 'global_avg'), ('med', 'median')] + results = {f'{k}_{tag}': getattr(meter, attr) for k, meter in metric_logger.meters.items() for tag, attr in aggs} + + if log_writer is not None: + for name, val in results.items(): + log_writer.add_scalar(prefix + '_' + name, val, 1000 * epoch) + + return results diff --git a/imcui/third_party/dust3r/dust3r/utils/__init__.py b/imcui/third_party/dust3r/dust3r/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/dust3r/dust3r/utils/device.py b/imcui/third_party/dust3r/dust3r/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b6a74dac05a2e1ba3a2b2f0faa8cea08ece745 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/utils/device.py @@ -0,0 +1,76 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for DUSt3R +# -------------------------------------------------------- +import numpy as np +import torch + + +def todevice(batch, device, callback=None, non_blocking=False): + ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). + + batch: list, tuple, dict of tensors or other things + device: pytorch device or 'numpy' + callback: function that would be called on every sub-elements. 
+ ''' + if callback: + batch = callback(batch) + + if isinstance(batch, dict): + return {k: todevice(v, device) for k, v in batch.items()} + + if isinstance(batch, (tuple, list)): + return type(batch)(todevice(x, device) for x in batch) + + x = batch + if device == 'numpy': + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +to_device = todevice # alias + + +def to_numpy(x): return todevice(x, 'numpy') +def to_cpu(x): return todevice(x, 'cpu') +def to_cuda(x): return todevice(x, 'cuda') + + +def collate_with_cat(whatever, lists=False): + if isinstance(whatever, dict): + return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()} + + elif isinstance(whatever, (tuple, list)): + if len(whatever) == 0: + return whatever + elem = whatever[0] + T = type(whatever) + + if elem is None: + return None + if isinstance(elem, (bool, float, int, str)): + return whatever + if isinstance(elem, tuple): + return T(collate_with_cat(x, lists=lists) for x in zip(*whatever)) + if isinstance(elem, dict): + return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem} + + if isinstance(elem, torch.Tensor): + return listify(whatever) if lists else torch.cat(whatever) + if isinstance(elem, np.ndarray): + return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever]) + + # otherwise, we just chain lists + return sum(whatever, T()) + + +def listify(elems): + return [x for e in elems for x in e] diff --git a/imcui/third_party/dust3r/dust3r/utils/geometry.py b/imcui/third_party/dust3r/dust3r/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..ce365faf2acb97ffaafa1b80cb8ee0c28de0b6d6 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/utils/geometry.py @@ -0,0 +1,366 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# geometry utilitary functions +# -------------------------------------------------------- +import torch +import numpy as np +from scipy.spatial import cKDTree as KDTree + +from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans +from dust3r.utils.device import to_numpy + + +def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw): + """ Output a (H,W,2) array of int32 + with output[j,i,0] = i + origin[0] + output[j,i,1] = j + origin[1] + """ + if device is None: + # numpy + arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones + else: + # torch + arange = lambda *a, **kw: torch.arange(*a, device=device, **kw) + meshgrid, stack = torch.meshgrid, torch.stack + ones = lambda *a: torch.ones(*a, device=device) + + tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)] + grid = meshgrid(tw, th, indexing='xy') + if homogeneous: + grid = grid + (ones((H, W)),) + if unsqueeze is not None: + grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze)) + if cat_dim is not None: + grid = stack(grid, cat_dim) + return grid + + +def geotrf(Trf, pts, ncol=None, norm=False): + """ Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. 
number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + # optimized code + if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and + Trf.ndim == 3 and pts.ndim == 4): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d] + else: + raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}') + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match' + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def inv(mat): + """ Invert a torch or numpy matrix + """ + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f'bad matrix type = {type(mat)}') + + +def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_): + """ + Args: + - depthmap (BxHxW array): + - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W] + Returns: + pointmap of absolute coordinates (BxHxWx3 array) + """ + + if len(depth.shape) == 4: + B, H, W, n = depth.shape + else: + B, H, W = depth.shape + n = None + + if len(pseudo_focal.shape) == 3: # [B,H,W] + pseudo_focalx = pseudo_focaly = pseudo_focal + elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W] + pseudo_focalx = pseudo_focal[:, 0] + if pseudo_focal.shape[1] == 2: + pseudo_focaly = pseudo_focal[:, 1] + else: + pseudo_focaly = pseudo_focalx + else: + raise NotImplementedError("Error, unknown input focal shape format.") + + assert pseudo_focalx.shape == depth.shape[:3] + assert pseudo_focaly.shape == depth.shape[:3] + grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None] + + # set principal point + if pp is None: + grid_x = grid_x - (W - 1) / 2 + grid_y = grid_y - (H - 1) / 2 + else: + grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None] + grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None] + + if n is None: + pts3d = torch.empty((B, H, W, 3), device=depth.device) + pts3d[..., 0] = depth * grid_x / pseudo_focalx + pts3d[..., 1] = depth * grid_y / pseudo_focaly + pts3d[..., 2] = depth + else: + pts3d = torch.empty((B, H, W, 3, n), device=depth.device) + pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None] + pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None] + pts3d[..., 2, :] = depth + return pts3d + + +def 
depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + # Mask for valid coordinates + valid_mask = (depthmap > 0.0) + return X_cam, valid_mask + + +def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.""" + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + + X_world = X_cam # default + if camera_pose is not None: + # R_cam2world = np.float32(camera_params["R_cam2world"]) + # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze() + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates (invalid depth values) + X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + + return X_world, valid_mask + + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + return K + + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. 
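+    This is the inverse of colmap_to_opencv_intrinsics above: the principal
+    point is shifted by +0.5 pixel along both axes.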
+ Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K + + +def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None, ret_factor=False): + """ renorm pointmaps pts1, pts2 with norm_mode + """ + assert pts1.ndim >= 3 and pts1.shape[-1] == 3 + assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split('_') + + if norm_mode == 'avg': + # gather all points together (joint normalization) + nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3) + nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0) + all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + if dis_mode == 'dis': + pass # do nothing + elif dis_mode == 'log1p': + all_dis = torch.log1p(all_dis) + elif dis_mode == 'warp-log1p': + # actually warp input points before normalizing them + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + H1, W1 = pts1.shape[1:-1] + pts1 = pts1 * warp_factor[:, :W1 * H1].view(-1, H1, W1, 1) + if pts2 is not None: + H2, W2 = pts2.shape[1:-1] + pts2 = pts2 * warp_factor[:, W1 * H1:].view(-1, H2, W2, 1) + all_dis = log_dis # this is their true distance afterwards + else: + raise ValueError(f'bad {dis_mode=}') + + norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8) + else: + # gather all points together (joint normalization) + nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3) + nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None + all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + + if norm_mode == 'avg': + norm_factor = all_dis.nanmean(dim=1) + elif norm_mode == 'median': + norm_factor = all_dis.nanmedian(dim=1).values.detach() + elif norm_mode == 'sqrt': + norm_factor = all_dis.sqrt().nanmean(dim=1)**2 + else: + raise ValueError(f'bad {norm_mode=}') + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts1.ndim: + norm_factor.unsqueeze_(-1) + + res = pts1 / norm_factor + if pts2 is not None: + res = (res, pts2 / norm_factor) + if ret_factor: + res = res + (norm_factor,) + return res + + +@torch.no_grad() +def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5): + # set invalid points to NaN + _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1) + _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None + _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1 + + # compute median depth overall (ignoring nans) + if quantile == 0.5: + shift_z = torch.nanmedian(_z, dim=-1).values + else: + shift_z = torch.nanquantile(_z, quantile, dim=-1) + return shift_z # (B,) + + +@torch.no_grad() +def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True): + # set invalid points to NaN + _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3) + _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None + _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1 + + # compute median center + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + # compute 
median norm + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +def find_reciprocal_matches(P1, P2): + """ + returns 3 values: + 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match + 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1 + 3 - reciprocal_in_P2.sum(): the number of matches + """ + tree1 = KDTree(P1) + tree2 = KDTree(P2) + + _, nn1_in_P2 = tree2.query(P1, workers=8) + _, nn2_in_P1 = tree1.query(P2, workers=8) + + reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2))) + reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1))) + assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum() + return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum() + + +def get_med_dist_between_poses(poses): + from scipy.spatial.distance import pdist + return np.median(pdist([to_numpy(p[:3, 3]) for p in poses])) diff --git a/imcui/third_party/dust3r/dust3r/utils/image.py b/imcui/third_party/dust3r/dust3r/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..6312a346df919ae6a0424504d824ef813fea250f --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/utils/image.py @@ -0,0 +1,126 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions about images (loading/converting...) +# -------------------------------------------------------- +import os +import torch +import numpy as np +import PIL.Image +from PIL.ImageOps import exif_transpose +import torchvision.transforms as tvf +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa + +try: + from pillow_heif import register_heif_opener # noqa + register_heif_opener() + heif_support_enabled = True +except ImportError: + heif_support_enabled = False + +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + +def img_to_arr( img ): + if isinstance(img, str): + img = imread_cv2(img) + return img + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """ Open an image or a depthmap with opencv-python. 
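+    Paths ending in '.exr' / 'EXR' are read with cv2.IMREAD_ANYDEPTH (float
+    depth maps); other files use the given options, and 3-channel results are
+    converted from OpenCV's BGR ordering to RGB before being returned.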
+ """ + if path.endswith(('.exr', 'EXR')): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f'Could not load image={path} with {options=}') + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def rgb(ftensor, true_shape=None): + if isinstance(ftensor, list): + return [rgb(x, true_shape=true_shape) for x in ftensor] + if isinstance(ftensor, torch.Tensor): + ftensor = ftensor.detach().cpu().numpy() # H,W,3 + if ftensor.ndim == 3 and ftensor.shape[0] == 3: + ftensor = ftensor.transpose(1, 2, 0) + elif ftensor.ndim == 4 and ftensor.shape[1] == 3: + ftensor = ftensor.transpose(0, 2, 3, 1) + if true_shape is not None: + H, W = true_shape + ftensor = ftensor[:H, :W] + if ftensor.dtype == np.uint8: + img = np.float32(ftensor) / 255 + else: + img = (ftensor * 0.5) + 0.5 + return img.clip(min=0, max=1) + + +def _resize_pil_image(img, long_edge_size): + S = max(img.size) + if S > long_edge_size: + interp = PIL.Image.LANCZOS + elif S <= long_edge_size: + interp = PIL.Image.BICUBIC + new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size) + return img.resize(new_size, interp) + + +def load_images(folder_or_list, size, square_ok=False, verbose=True): + """ open and convert all images in a list or folder to proper input format for DUSt3R + """ + if isinstance(folder_or_list, str): + if verbose: + print(f'>> Loading images from {folder_or_list}') + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + if verbose: + print(f'>> Loading a list of {len(folder_or_list)} images') + root, folder_content = '', folder_or_list + + else: + raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})') + + supported_images_extensions = ['.jpg', '.jpeg', '.png'] + if heif_support_enabled: + supported_images_extensions += ['.heic', '.heif'] + supported_images_extensions = tuple(supported_images_extensions) + + imgs = [] + for path in folder_content: + if not path.lower().endswith(supported_images_extensions): + continue + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB') + W1, H1 = img.size + if size == 224: + # resize short side to 224 (then crop) + img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1))) + else: + # resize long side to 512 + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W//2, H//2 + if size == 224: + half = min(cx, cy) + img = img.crop((cx-half, cy-half, cx+half, cy+half)) + else: + halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8 + if not (square_ok) and W == H: + halfh = 3*halfw/4 + img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh)) + + W2, H2 = img.size + if verbose: + print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}') + imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32( + [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs)))) + + assert imgs, 'no images foud at '+root + if verbose: + print(f' (Found {len(imgs)} images)') + return imgs diff --git a/imcui/third_party/dust3r/dust3r/utils/misc.py b/imcui/third_party/dust3r/dust3r/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..88c4d2dab6d5c14021ed9ed6646c3159a3a4637b --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/utils/misc.py @@ -0,0 +1,121 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# utilitary functions for DUSt3R +# -------------------------------------------------------- +import torch + + +def fill_default_args(kwargs, func): + import inspect # a bit hacky but it works reliably + signature = inspect.signature(func) + + for k, v in signature.parameters.items(): + if v.default is inspect.Parameter.empty: + continue + kwargs.setdefault(k, v.default) + + return kwargs + + +def freeze_all_params(modules): + for module in modules: + try: + for n, param in module.named_parameters(): + param.requires_grad = False + except AttributeError: + # module is directly a parameter + module.requires_grad = False + + +def is_symmetrized(gt1, gt2): + x = gt1['instance'] + y = gt2['instance'] + if len(x) == len(y) and len(x) == 1: + return False # special case of batchsize 1 + ok = True + for i in range(0, len(x), 2): + ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i]) + return ok + + +def flip(tensor): + """ flip so that tensor[0::2] <=> tensor[1::2] """ + return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) + + +def interleave(tensor1, tensor2): + res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) + res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) + return res1, res2 + + +def transpose_to_landscape(head, activate=True): + """ Predict in the correct aspect-ratio, + then transpose the result in landscape + and stack everything back together. + """ + def wrapper_no(decout, true_shape): + B = len(true_shape) + assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical' + H, W = true_shape[0].cpu().tolist() + res = head(decout, (H, W)) + return res + + def wrapper_yes(decout, true_shape): + B = len(true_shape) + # by definition, the batch is in landscape mode so W >= H + H, W = int(true_shape.min()), int(true_shape.max()) + + height, width = true_shape.T + is_landscape = (width >= height) + is_portrait = ~is_landscape + + # true_shape = true_shape.cpu() + if is_landscape.all(): + return head(decout, (H, W)) + if is_portrait.all(): + return transposed(head(decout, (W, H))) + + # batch is a mix of both portraint & landscape + def selout(ar): return [d[ar] for d in decout] + l_result = head(selout(is_landscape), (H, W)) + p_result = transposed(head(selout(is_portrait), (W, H))) + + # allocate full result + result = {} + for k in l_result | p_result: + x = l_result[k].new(B, *l_result[k].shape[1:]) + x[is_landscape] = l_result[k] + x[is_portrait] = p_result[k] + result[k] = x + + return result + + return wrapper_yes if activate else wrapper_no + + +def transposed(dic): + return {k: v.swapaxes(1, 2) for k, v in dic.items()} + + +def invalid_to_nans(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = float('nan') + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr + + +def invalid_to_zeros(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = 0 + nnz = valid_mask.view(len(valid_mask), -1).sum(1) + else: + nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr, nnz diff --git a/imcui/third_party/dust3r/dust3r/utils/parallel.py b/imcui/third_party/dust3r/dust3r/utils/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..06ae7fefdb9d2298929f0cbc20dfbc57eb7d7f7b --- /dev/null +++ 
b/imcui/third_party/dust3r/dust3r/utils/parallel.py @@ -0,0 +1,79 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for multiprocessing +# -------------------------------------------------------- +from tqdm import tqdm +from multiprocessing.dummy import Pool as ThreadPool +from multiprocessing import cpu_count + + +def parallel_threads(function, args, workers=0, star_args=False, kw_args=False, front_num=1, Pool=ThreadPool, **tqdm_kw): + """ tqdm but with parallel execution. + + Will essentially return + res = [ function(arg) # default + function(*arg) # if star_args is True + function(**arg) # if kw_args is True + for arg in args] + + Note: + the first elements of args will not be parallelized. + This can be useful for debugging. + """ + while workers <= 0: + workers += cpu_count() + if workers == 1: + front_num = float('inf') + + # convert into an iterable + try: + n_args_parallel = len(args) - front_num + except TypeError: + n_args_parallel = None + args = iter(args) + + # sequential execution first + front = [] + while len(front) < front_num: + try: + a = next(args) + except StopIteration: + return front # end of the iterable + front.append(function(*a) if star_args else function(**a) if kw_args else function(a)) + + # then parallel execution + out = [] + with Pool(workers) as pool: + # Pass the elements of args into function + if star_args: + futures = pool.imap(starcall, [(function, a) for a in args]) + elif kw_args: + futures = pool.imap(starstarcall, [(function, a) for a in args]) + else: + futures = pool.imap(function, args) + # Print out the progress as tasks complete + for f in tqdm(futures, total=n_args_parallel, **tqdm_kw): + out.append(f) + return front + out + + +def parallel_processes(*args, **kwargs): + """ Same as parallel_threads, with processes + """ + import multiprocessing as mp + kwargs['Pool'] = mp.Pool + return parallel_threads(*args, **kwargs) + + +def starcall(args): + """ convenient wrapper for Process.Pool """ + function, args = args + return function(*args) + + +def starstarcall(args): + """ convenient wrapper for Process.Pool """ + function, args = args + return function(**args) diff --git a/imcui/third_party/dust3r/dust3r/utils/path_to_croco.py b/imcui/third_party/dust3r/dust3r/utils/path_to_croco.py new file mode 100644 index 0000000000000000000000000000000000000000..39226ce6bc0e1993ba98a22096de32cb6fa916b4 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/utils/path_to_croco.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
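# A small sketch of the `parallel_threads` helper from dust3r/utils/parallel.py above;
# the worker function and inputs are made up for illustration.
from dust3r.utils.parallel import parallel_threads

def slow_square(x):
    return x * x

# The first element runs sequentially (front_num=1, handy for debugging), the rest are
# mapped over a thread pool while tqdm reports progress; the order of results is preserved.
results = parallel_threads(slow_square, range(10), workers=4, front_num=1, desc='squares')
print(results)   # [0, 1, 4, 9, ..., 81]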
+# +# -------------------------------------------------------- +# CroCo submodule import +# -------------------------------------------------------- + +import sys +import os.path as path +HERE_PATH = path.normpath(path.dirname(__file__)) +CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco')) +CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models') +# check the presence of models directory in repo to be sure its cloned +if path.isdir(CROCO_MODELS_PATH): + # workaround for sibling import + sys.path.insert(0, CROCO_REPO_PATH) +else: + raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " + "Did you forget to run 'git submodule update --init --recursive' ?") diff --git a/imcui/third_party/dust3r/dust3r/viz.py b/imcui/third_party/dust3r/dust3r/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..9150e8b850d9f1e6bf9ddf6e865d34fc743e276a --- /dev/null +++ b/imcui/third_party/dust3r/dust3r/viz.py @@ -0,0 +1,381 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Visualization utilities using trimesh +# -------------------------------------------------------- +import PIL.Image +import numpy as np +from scipy.spatial.transform import Rotation +import torch + +from dust3r.utils.geometry import geotrf, get_med_dist_between_poses, depthmap_to_absolute_camera_coordinates +from dust3r.utils.device import to_numpy +from dust3r.utils.image import rgb, img_to_arr + +try: + import trimesh +except ImportError: + print('/!\\ module trimesh is not installed, cannot visualize results /!\\') + + + +def cat_3d(vecs): + if isinstance(vecs, (np.ndarray, torch.Tensor)): + vecs = [vecs] + return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)]) + + +def show_raw_pointcloud(pts3d, colors, point_size=2): + scene = trimesh.Scene() + + pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors)) + scene.add_geometry(pct) + + scene.show(line_settings={'point_size': point_size}) + + +def pts3d_to_trimesh(img, pts3d, valid=None): + H, W, THREE = img.shape + assert THREE == 3 + assert img.shape == pts3d.shape + + vertices = pts3d.reshape(-1, 3) + + # make squares: each pixel == 2 triangles + idx = np.arange(len(vertices)).reshape(H, W) + idx1 = idx[:-1, :-1].ravel() # top-left corner + idx2 = idx[:-1, +1:].ravel() # right-left corner + idx3 = idx[+1:, :-1].ravel() # bottom-left corner + idx4 = idx[+1:, +1:].ravel() # bottom-right corner + faces = np.concatenate(( + np.c_[idx1, idx2, idx3], + np.c_[idx3, idx2, idx1], # same triangle, but backward (cheap solution to cancel face culling) + np.c_[idx2, idx3, idx4], + np.c_[idx4, idx3, idx2], # same triangle, but backward (cheap solution to cancel face culling) + ), axis=0) + + # prepare triangle colors + face_colors = np.concatenate(( + img[:-1, :-1].reshape(-1, 3), + img[:-1, :-1].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3) + ), axis=0) + + # remove invalid faces + if valid is not None: + assert valid.shape == (H, W) + valid_idxs = valid.ravel() + valid_faces = valid_idxs[faces].all(axis=-1) + faces = faces[valid_faces] + face_colors = face_colors[valid_faces] + + assert len(faces) == len(face_colors) + return dict(vertices=vertices, face_colors=face_colors, faces=faces) + + +def cat_meshes(meshes): + vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes]) + n_vertices = 
np.cumsum([0]+[len(v) for v in vertices]) + for i in range(len(faces)): + faces[i][:] += n_vertices[i] + + vertices = np.concatenate(vertices) + colors = np.concatenate(colors) + faces = np.concatenate(faces) + return dict(vertices=vertices, face_colors=colors, faces=faces) + + +def show_duster_pairs(view1, view2, pred1, pred2): + import matplotlib.pyplot as pl + pl.ion() + + for e in range(len(view1['instance'])): + i = view1['idx'][e] + j = view2['idx'][e] + img1 = rgb(view1['img'][e]) + img2 = rgb(view2['img'][e]) + conf1 = pred1['conf'][e].squeeze() + conf2 = pred2['conf'][e].squeeze() + score = conf1.mean()*conf2.mean() + print(f">> Showing pair #{e} {i}-{j} {score=:g}") + pl.clf() + pl.subplot(221).imshow(img1) + pl.subplot(223).imshow(img2) + pl.subplot(222).imshow(conf1, vmin=1, vmax=30) + pl.subplot(224).imshow(conf2, vmin=1, vmax=30) + pts1 = pred1['pts3d'][e] + pts2 = pred2['pts3d_in_other_view'][e] + pl.subplots_adjust(0, 0, 1, 1, 0, 0) + if input('show pointcloud? (y/n) ') == 'y': + show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5) + + +def auto_cam_size(im_poses): + return 0.1 * get_med_dist_between_poses(im_poses) + + +class SceneViz: + def __init__(self): + self.scene = trimesh.Scene() + + def add_rgbd(self, image, depth, intrinsics=None, cam2world=None, zfar=np.inf, mask=None): + image = img_to_arr(image) + + # make up some intrinsics + if intrinsics is None: + H, W, THREE = image.shape + focal = max(H, W) + intrinsics = np.float32([[focal, 0, W/2], [0, focal, H/2], [0, 0, 1]]) + + # compute 3d points + pts3d = depthmap_to_pts3d(depth, intrinsics, cam2world=cam2world) + + return self.add_pointcloud(pts3d, image, mask=(depth 150) + mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180) + mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220) + + # Morphological operations + kernel = np.ones((5, 5), np.uint8) + mask2 = ndimage.binary_opening(mask, structure=kernel) + + # keep only largest CC + _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8) + cc_sizes = stats[1:, cv2.CC_STAT_AREA] + order = cc_sizes.argsort()[::-1] # bigger first + i = 0 + selection = [] + while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2: + selection.append(1 + order[i]) + i += 1 + mask3 = np.in1d(labels, selection).reshape(labels.shape) + + # Apply mask + return torch.from_numpy(mask3) diff --git a/imcui/third_party/dust3r/dust3r_visloc/__init__.py b/imcui/third_party/dust3r/dust3r_visloc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/dust3r/dust3r_visloc/datasets/__init__.py b/imcui/third_party/dust3r/dust3r_visloc/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..566926b1e248e4b64fc5182031af634435bb8601 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/datasets/__init__.py @@ -0,0 +1,6 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
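# A hedged sketch combining the mesh helpers from dust3r/viz.py above; `img`, `pts3d` and
# `valid` are toy stand-ins for per-view outputs (H x W x 3 colors/points and an H x W mask),
# and the trimesh call simply forwards the dict keys returned by `cat_meshes`.
import numpy as np
import trimesh
from dust3r.viz import pts3d_to_trimesh, cat_meshes

H, W = 4, 5
views = []
for _ in range(2):
    img = np.random.rand(H, W, 3).astype(np.float32)     # per-pixel colors in [0, 1]
    pts3d = np.random.rand(H, W, 3).astype(np.float32)   # per-pixel 3D points
    valid = np.ones((H, W), dtype=bool)
    views.append(pts3d_to_trimesh(img, pts3d, valid))

mesh = trimesh.Trimesh(**cat_meshes(views))              # one mesh with per-face colors
mesh.show()                                              # opens the trimesh viewer, if installed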
+from .sevenscenes import VislocSevenScenes +from .cambridge_landmarks import VislocCambridgeLandmarks +from .aachen_day_night import VislocAachenDayNight +from .inloc import VislocInLoc diff --git a/imcui/third_party/dust3r/dust3r_visloc/datasets/aachen_day_night.py b/imcui/third_party/dust3r/dust3r_visloc/datasets/aachen_day_night.py new file mode 100644 index 0000000000000000000000000000000000000000..159548e8b51a1b5872a2392cd9107ff96e40e801 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/datasets/aachen_day_night.py @@ -0,0 +1,24 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# AachenDayNight dataloader +# -------------------------------------------------------- +import os +from dust3r_visloc.datasets.base_colmap import BaseVislocColmapDataset + + +class VislocAachenDayNight(BaseVislocColmapDataset): + def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False): + assert subscene in [None, '', 'day', 'night', 'all'] + self.subscene = subscene + image_path = os.path.join(root, 'images') + map_path = os.path.join(root, 'mapping/colmap/reconstruction') + query_path = os.path.join(root, 'kapture', 'query') + pairsfile_path = os.path.join(root, 'pairsfile/query', pairsfile + '.txt') + super().__init__(image_path=image_path, map_path=map_path, + query_path=query_path, pairsfile_path=pairsfile_path, + topk=topk, cache_sfm=cache_sfm) + self.scenes = [filename for filename in self.scenes if filename in self.pairs] + if self.subscene == 'day' or self.subscene == 'night': + self.scenes = [filename for filename in self.scenes if self.subscene in filename] diff --git a/imcui/third_party/dust3r/dust3r_visloc/datasets/base_colmap.py b/imcui/third_party/dust3r/dust3r_visloc/datasets/base_colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..def1da61b5d3b416db5845c2016082348df944a6 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/datasets/base_colmap.py @@ -0,0 +1,282 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Base class for colmap / kapture +# -------------------------------------------------------- +import os +import numpy as np +from tqdm import tqdm +import collections +import pickle +import PIL.Image +import torch +from scipy.spatial.transform import Rotation +import torchvision.transforms as tvf + +from kapture.core import CameraType +from kapture.io.csv import kapture_from_dir +from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file + +from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d +from dust3r_visloc.datasets.base_dataset import BaseVislocDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import colmap_to_opencv_intrinsics + +KaptureSensor = collections.namedtuple('Sensor', 'sensor_params camera_params') + + +def kapture_to_opencv_intrinsics(sensor): + """ + Convert from Kapture to OpenCV parameters. + Warning: we assume that the camera and pixel coordinates follow Colmap conventions here. + Args: + sensor: Kapture sensor + """ + sensor_type = sensor.sensor_params[0] + if sensor_type == "SIMPLE_PINHOLE": + # Simple pinhole model. + # We still call OpenCV undistorsion however for code simplicity. 
+ w, h, f, cx, cy = sensor.camera_params + k1 = 0 + k2 = 0 + p1 = 0 + p2 = 0 + fx = fy = f + elif sensor_type == "PINHOLE": + w, h, fx, fy, cx, cy = sensor.camera_params + k1 = 0 + k2 = 0 + p1 = 0 + p2 = 0 + elif sensor_type == "SIMPLE_RADIAL": + w, h, f, cx, cy, k1 = sensor.camera_params + k2 = 0 + p1 = 0 + p2 = 0 + fx = fy = f + elif sensor_type == "RADIAL": + w, h, f, cx, cy, k1, k2 = sensor.camera_params + p1 = 0 + p2 = 0 + fx = fy = f + elif sensor_type == "OPENCV": + w, h, fx, fy, cx, cy, k1, k2, p1, p2 = sensor.camera_params + else: + raise NotImplementedError(f"Sensor type {sensor_type} is not supported yet.") + + cameraMatrix = np.asarray([[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]], dtype=np.float32) + + # We assume that Kapture data comes from Colmap: the origin is different. + cameraMatrix = colmap_to_opencv_intrinsics(cameraMatrix) + + distCoeffs = np.asarray([k1, k2, p1, p2], dtype=np.float32) + return cameraMatrix, distCoeffs, (w, h) + + +def K_from_colmap(elems): + sensor = KaptureSensor(elems, tuple(map(float, elems[1:]))) + cameraMatrix, distCoeffs, (w, h) = kapture_to_opencv_intrinsics(sensor) + res = dict(resolution=(w, h), + intrinsics=cameraMatrix, + distortion=distCoeffs) + return res + + +def pose_from_qwxyz_txyz(elems): + qw, qx, qy, qz, tx, ty, tz = map(float, elems) + pose = np.eye(4) + pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix() + pose[:3, 3] = (tx, ty, tz) + return np.linalg.inv(pose) # returns cam2world + + +class BaseVislocColmapDataset(BaseVislocDataset): + def __init__(self, image_path, map_path, query_path, pairsfile_path, topk=1, cache_sfm=False): + super().__init__() + self.topk = topk + self.num_views = self.topk + 1 + self.image_path = image_path + self.cache_sfm = cache_sfm + + self._load_sfm(map_path) + + kdata_query = kapture_from_dir(query_path) + assert kdata_query.records_camera is not None and kdata_query.trajectories is not None + + kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_query.records_camera.key_pairs()} + self.query_data = {'kdata': kdata_query, 'searchindex': kdata_query_searchindex} + + self.pairs = get_ordered_pairs_from_file(pairsfile_path) + self.scenes = kdata_query.records_camera.data_list() + + def _load_sfm(self, sfm_dir): + sfm_cache_path = os.path.join(sfm_dir, 'dust3r_cache.pkl') + if os.path.isfile(sfm_cache_path) and self.cache_sfm: + with open(sfm_cache_path, "rb") as f: + data = pickle.load(f) + self.img_infos = data['img_infos'] + self.points3D = data['points3D'] + return + + # load cameras + with open(os.path.join(sfm_dir, 'cameras.txt'), 'r') as f: + raw = f.read().splitlines()[3:] # skip header + + intrinsics = {} + for camera in tqdm(raw): + camera = camera.split(' ') + intrinsics[int(camera[0])] = K_from_colmap(camera[1:]) + + # load images + with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + self.img_infos = {} + for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2): + image = image.split(' ') + points = points.split(' ') + + img_name = image[-1] + current_points2D = {int(i): (float(x), float(y)) + for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'} + self.img_infos[img_name] = dict(intrinsics[int(image[-2])], + path=img_name, + camera_pose=pose_from_qwxyz_txyz(image[1: -2]), + sparse_pts2d=current_points2D) + + # load 3D points + with 
open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + self.points3D = {} + for point in tqdm(raw): + point = point.split() + self.points3D[int(point[0])] = tuple(map(float, point[1:4])) + + if self.cache_sfm: + to_save = \ + { + 'img_infos': self.img_infos, + 'points3D': self.points3D + } + with open(sfm_cache_path, "wb") as f: + pickle.dump(to_save, f) + + def __len__(self): + return len(self.scenes) + + def _get_view_query(self, imgname): + kdata, searchindex = map(self.query_data.get, ['kdata', 'searchindex']) + + timestamp, camera_id = searchindex[imgname] + + camera_params = kdata.sensors[camera_id].camera_params + if kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_PINHOLE: + W, H, f, cx, cy = camera_params + k1 = 0 + fx = fy = f + elif kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_RADIAL: + W, H, f, cx, cy, k1 = camera_params + fx = fy = f + else: + raise NotImplementedError('not implemented') + + W, H = int(W), int(H) + intrinsics = np.float32([(fx, 0, cx), + (0, fy, cy), + (0, 0, 1)]) + intrinsics = colmap_to_opencv_intrinsics(intrinsics) + distortion = [k1, 0, 0, 0] + + if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories: + cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id) + else: + cam_to_world = np.eye(4, dtype=np.float32) + + # Load RGB image + rgb_image = PIL.Image.open(os.path.join(self.image_path, imgname)).convert('RGB') + rgb_image.load() + resize_func, _, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) + rgb_tensor = resize_func(ImgNorm(rgb_image)) + + view = { + 'intrinsics': intrinsics, + 'distortion': distortion, + 'cam_to_world': cam_to_world, + 'rgb': rgb_image, + 'rgb_rescaled': rgb_tensor, + 'to_orig': to_orig, + 'idx': 0, + 'image_name': imgname + } + return view + + def _get_view_map(self, imgname, idx): + infos = self.img_infos[imgname] + + rgb_image = PIL.Image.open(os.path.join(self.image_path, infos['path'])).convert('RGB') + rgb_image.load() + W, H = rgb_image.size + intrinsics = infos['intrinsics'] + intrinsics = colmap_to_opencv_intrinsics(intrinsics) + distortion_coefs = infos['distortion'] + + pts2d = infos['sparse_pts2d'] + sparse_pos2d = np.float32(list(pts2d.values())).reshape((-1, 2)) # pts2d from colmap + sparse_pts3d = np.float32([self.points3D[i] for i in pts2d]).reshape((-1, 3)) + + # store full resolution 2D->3D + sparse_pos2d_cv2 = sparse_pos2d.copy() + sparse_pos2d_cv2[:, 0] -= 0.5 + sparse_pos2d_cv2[:, 1] -= 0.5 + sparse_pos2d_int = sparse_pos2d_cv2.round().astype(np.int64) + valid = (sparse_pos2d_int[:, 0] >= 0) & (sparse_pos2d_int[:, 0] < W) & ( + sparse_pos2d_int[:, 1] >= 0) & (sparse_pos2d_int[:, 1] < H) + sparse_pos2d_int = sparse_pos2d_int[valid] + # nan => invalid + pts3d = np.full((H, W, 3), np.nan, dtype=np.float32) + pts3d[sparse_pos2d_int[:, 1], sparse_pos2d_int[:, 0]] = sparse_pts3d[valid] + pts3d = torch.from_numpy(pts3d) + + cam_to_world = infos['camera_pose'] # cam2world + + # also store resized resolution 2D->3D + resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) + rgb_tensor = resize_func(ImgNorm(rgb_image)) + + HR, WR = rgb_tensor.shape[1:] + _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(sparse_pos2d_cv2, sparse_pts3d, to_resize, HR, WR) + pts3d_rescaled = torch.from_numpy(pts3d_rescaled) + valid_rescaled = torch.from_numpy(valid_rescaled) + + view = { + 'intrinsics': intrinsics, 
+ 'distortion': distortion_coefs, + 'cam_to_world': cam_to_world, + 'rgb': rgb_image, + "pts3d": pts3d, + "valid": pts3d.sum(dim=-1).isfinite(), + 'rgb_rescaled': rgb_tensor, + "pts3d_rescaled": pts3d_rescaled, + "valid_rescaled": valid_rescaled, + 'to_orig': to_orig, + 'idx': idx, + 'image_name': imgname + } + return view + + def __getitem__(self, idx): + assert self.maxdim is not None and self.patch_size is not None + query_image = self.scenes[idx] + map_images = [p[0] for p in self.pairs[query_image][:self.topk]] + views = [] + views.append(self._get_view_query(query_image)) + for idx, map_image in enumerate(map_images): + views.append(self._get_view_map(map_image, idx + 1)) + return views diff --git a/imcui/third_party/dust3r/dust3r_visloc/datasets/base_dataset.py b/imcui/third_party/dust3r/dust3r_visloc/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cda3774c5ab5b668be5eecf89681abc96df5fe17 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/datasets/base_dataset.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Base class +# -------------------------------------------------------- +class BaseVislocDataset: + def __init__(self): + pass + + def set_resolution(self, model): + self.maxdim = max(model.patch_embed.img_size) + self.patch_size = model.patch_embed.patch_size + + def __len__(self): + raise NotImplementedError() + + def __getitem__(self, idx): + raise NotImplementedError() \ No newline at end of file diff --git a/imcui/third_party/dust3r/dust3r_visloc/datasets/cambridge_landmarks.py b/imcui/third_party/dust3r/dust3r_visloc/datasets/cambridge_landmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..ca3e131941bf444d86a709d23e518e7b93d3d0f6 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/datasets/cambridge_landmarks.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Cambridge Landmarks dataloader +# -------------------------------------------------------- +import os +from dust3r_visloc.datasets.base_colmap import BaseVislocColmapDataset + + +class VislocCambridgeLandmarks (BaseVislocColmapDataset): + def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False): + image_path = os.path.join(root, subscene) + map_path = os.path.join(root, 'mapping', subscene, 'colmap/reconstruction') + query_path = os.path.join(root, 'kapture', subscene, 'query') + pairsfile_path = os.path.join(root, subscene, 'pairsfile/query', pairsfile + '.txt') + super().__init__(image_path=image_path, map_path=map_path, + query_path=query_path, pairsfile_path=pairsfile_path, + topk=topk, cache_sfm=cache_sfm) \ No newline at end of file diff --git a/imcui/third_party/dust3r/dust3r_visloc/datasets/inloc.py b/imcui/third_party/dust3r/dust3r_visloc/datasets/inloc.py new file mode 100644 index 0000000000000000000000000000000000000000..99ed11f554203d353d0559d0589f40ec1ffbf66e --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/datasets/inloc.py @@ -0,0 +1,167 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
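# A hedged sketch of how the COLMAP-backed visloc datasets above are meant to be driven;
# the root path, subscene and pairsfile names are placeholders. __getitem__ returns the
# query view first, followed by the top-k retrieved map views.
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r_visloc.datasets import VislocCambridgeLandmarks

model = AsymmetricCroCo3DStereo.from_pretrained("naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt")
dataset = VislocCambridgeLandmarks('/data/cambridge', 'GreatCourt', 'pairsfile', topk=1)
dataset.set_resolution(model)          # sets maxdim / patch_size from the model's patch embedding

views = dataset[0]
query_view, map_views = views[0], views[1:]
print(query_view['image_name'], query_view['intrinsics'].shape)   # (3, 3) pinhole intrinsics
print(map_views[0]['pts3d'].shape)     # H x W x 3 sparse 3D points, NaN where undefined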
+# +# -------------------------------------------------------- +# InLoc dataloader +# -------------------------------------------------------- +import os +import numpy as np +import torch +import PIL.Image +import scipy.io + +import kapture +from kapture.io.csv import kapture_from_dir +from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file + +from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d +from dust3r_visloc.datasets.base_dataset import BaseVislocDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import xy_grid, geotrf + + +def read_alignments(path_to_alignment): + aligns = {} + with open(path_to_alignment, "r") as fid: + while True: + line = fid.readline() + if not line: + break + if len(line) == 4: + trans_nr = line[:-1] + while line != 'After general icp:\n': + line = fid.readline() + line = fid.readline() + p = [] + for i in range(4): + elems = line.split(' ') + line = fid.readline() + for e in elems: + if len(e) != 0: + p.append(float(e)) + P = np.array(p).reshape(4, 4) + aligns[trans_nr] = P + return aligns + + +class VislocInLoc(BaseVislocDataset): + def __init__(self, root, pairsfile, topk=1): + super().__init__() + self.root = root + self.topk = topk + self.num_views = self.topk + 1 + self.maxdim = None + self.patch_size = None + + query_path = os.path.join(self.root, 'query') + kdata_query = kapture_from_dir(query_path) + assert kdata_query.records_camera is not None + kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_query.records_camera.key_pairs()} + self.query_data = {'path': query_path, 'kdata': kdata_query, 'searchindex': kdata_query_searchindex} + + map_path = os.path.join(self.root, 'mapping') + kdata_map = kapture_from_dir(map_path) + assert kdata_map.records_camera is not None and kdata_map.trajectories is not None + kdata_map_searchindex = {kdata_map.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_map.records_camera.key_pairs()} + self.map_data = {'path': map_path, 'kdata': kdata_map, 'searchindex': kdata_map_searchindex} + + try: + self.pairs = get_ordered_pairs_from_file(os.path.join(self.root, 'pairfiles/query', pairsfile + '.txt')) + except Exception as e: + # if using pairs from hloc + self.pairs = {} + with open(os.path.join(self.root, 'pairfiles/query', pairsfile + '.txt'), 'r') as fid: + lines = fid.readlines() + for line in lines: + splits = line.rstrip("\n\r").split(" ") + self.pairs.setdefault(splits[0].replace('query/', ''), []).append( + (splits[1].replace('database/cutouts/', ''), 1.0) + ) + + self.scenes = kdata_query.records_camera.data_list() + + self.aligns_DUC1 = read_alignments(os.path.join(self.root, 'mapping/DUC1_alignment/all_transformations.txt')) + self.aligns_DUC2 = read_alignments(os.path.join(self.root, 'mapping/DUC2_alignment/all_transformations.txt')) + + def __len__(self): + return len(self.scenes) + + def __getitem__(self, idx): + assert self.maxdim is not None and self.patch_size is not None + query_image = self.scenes[idx] + map_images = [p[0] for p in self.pairs[query_image][:self.topk]] + views = [] + dataarray = [(query_image, self.query_data, False)] + [(map_image, self.map_data, True) + for map_image in map_images] + for idx, (imgname, data, should_load_depth) in enumerate(dataarray): + imgpath, kdata, searchindex = map(data.get, ['path', 'kdata', 'searchindex']) + + 
timestamp, camera_id = searchindex[imgname] + + # for InLoc, SIMPLE_PINHOLE + camera_params = kdata.sensors[camera_id].camera_params + W, H, f, cx, cy = camera_params + distortion = [0, 0, 0, 0] + intrinsics = np.float32([(f, 0, cx), + (0, f, cy), + (0, 0, 1)]) + + if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories: + cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id) + else: + cam_to_world = np.eye(4, dtype=np.float32) + + # Load RGB image + rgb_image = PIL.Image.open(os.path.join(imgpath, 'sensors/records_data', imgname)).convert('RGB') + rgb_image.load() + + W, H = rgb_image.size + resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) + + rgb_tensor = resize_func(ImgNorm(rgb_image)) + + view = { + 'intrinsics': intrinsics, + 'distortion': distortion, + 'cam_to_world': cam_to_world, + 'rgb': rgb_image, + 'rgb_rescaled': rgb_tensor, + 'to_orig': to_orig, + 'idx': idx, + 'image_name': imgname + } + + # Load depthmap + if should_load_depth: + depthmap_filename = os.path.join(imgpath, 'sensors/records_data', imgname + '.mat') + depthmap = scipy.io.loadmat(depthmap_filename) + + pt3d_cut = depthmap['XYZcut'] + scene_id = imgname.replace('\\', '/').split('/')[1] + if imgname.startswith('DUC1'): + pts3d_full = geotrf(self.aligns_DUC1[scene_id], pt3d_cut) + else: + pts3d_full = geotrf(self.aligns_DUC2[scene_id], pt3d_cut) + + pts3d_valid = np.isfinite(pts3d_full.sum(axis=-1)) + + pts3d = pts3d_full[pts3d_valid] + pts2d_int = xy_grid(W, H)[pts3d_valid] + pts2d = pts2d_int.astype(np.float64) + + # nan => invalid + pts3d_full[~pts3d_valid] = np.nan + pts3d_full = torch.from_numpy(pts3d_full) + view['pts3d'] = pts3d_full + view["valid"] = pts3d_full.sum(dim=-1).isfinite() + + HR, WR = rgb_tensor.shape[1:] + _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(pts2d, pts3d, to_resize, HR, WR) + pts3d_rescaled = torch.from_numpy(pts3d_rescaled) + valid_rescaled = torch.from_numpy(valid_rescaled) + view['pts3d_rescaled'] = pts3d_rescaled + view["valid_rescaled"] = valid_rescaled + views.append(view) + return views diff --git a/imcui/third_party/dust3r/dust3r_visloc/datasets/sevenscenes.py b/imcui/third_party/dust3r/dust3r_visloc/datasets/sevenscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..c15e851d262f0d7ba7071c933d8fe8f0a6b1c49d --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/datasets/sevenscenes.py @@ -0,0 +1,123 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
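# A tiny illustration of the hloc-style pairs fallback parsed in VislocInLoc.__init__ above
# (the file names are made up). Each line is "<query image> <database image>", and the
# 'query/' and 'database/cutouts/' prefixes are stripped before building self.pairs.
pairs = {}
for line in ["query/iphone7/IMG_0001.JPG database/cutouts/DUC1/024/DUC_cutout_024_0_0.jpg",
             "query/iphone7/IMG_0001.JPG database/cutouts/DUC1/025/DUC_cutout_025_30_0.jpg"]:
    splits = line.rstrip("\n\r").split(" ")
    pairs.setdefault(splits[0].replace('query/', ''), []).append(
        (splits[1].replace('database/cutouts/', ''), 1.0))

print(pairs)
# {'iphone7/IMG_0001.JPG': [('DUC1/024/DUC_cutout_024_0_0.jpg', 1.0),
#                           ('DUC1/025/DUC_cutout_025_30_0.jpg', 1.0)]}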
+# +# -------------------------------------------------------- +# 7 Scenes dataloader +# -------------------------------------------------------- +import os +import numpy as np +import torch +import PIL.Image + +import kapture +from kapture.io.csv import kapture_from_dir +from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file +from kapture.io.records import depth_map_from_file + +from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d +from dust3r_visloc.datasets.base_dataset import BaseVislocDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, xy_grid, geotrf + + +class VislocSevenScenes(BaseVislocDataset): + def __init__(self, root, subscene, pairsfile, topk=1): + super().__init__() + self.root = root + self.subscene = subscene + self.topk = topk + self.num_views = self.topk + 1 + self.maxdim = None + self.patch_size = None + + query_path = os.path.join(self.root, subscene, 'query') + kdata_query = kapture_from_dir(query_path) + assert kdata_query.records_camera is not None and kdata_query.trajectories is not None and kdata_query.rigs is not None + kapture.rigs_remove_inplace(kdata_query.trajectories, kdata_query.rigs) + kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_query.records_camera.key_pairs()} + self.query_data = {'path': query_path, 'kdata': kdata_query, 'searchindex': kdata_query_searchindex} + + map_path = os.path.join(self.root, subscene, 'mapping') + kdata_map = kapture_from_dir(map_path) + assert kdata_map.records_camera is not None and kdata_map.trajectories is not None and kdata_map.rigs is not None + kapture.rigs_remove_inplace(kdata_map.trajectories, kdata_map.rigs) + kdata_map_searchindex = {kdata_map.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_map.records_camera.key_pairs()} + self.map_data = {'path': map_path, 'kdata': kdata_map, 'searchindex': kdata_map_searchindex} + + self.pairs = get_ordered_pairs_from_file(os.path.join(self.root, subscene, + 'pairfiles/query', + pairsfile + '.txt')) + self.scenes = kdata_query.records_camera.data_list() + + def __len__(self): + return len(self.scenes) + + def __getitem__(self, idx): + assert self.maxdim is not None and self.patch_size is not None + query_image = self.scenes[idx] + map_images = [p[0] for p in self.pairs[query_image][:self.topk]] + views = [] + dataarray = [(query_image, self.query_data, False)] + [(map_image, self.map_data, True) + for map_image in map_images] + for idx, (imgname, data, should_load_depth) in enumerate(dataarray): + imgpath, kdata, searchindex = map(data.get, ['path', 'kdata', 'searchindex']) + + timestamp, camera_id = searchindex[imgname] + + # for 7scenes, SIMPLE_PINHOLE + camera_params = kdata.sensors[camera_id].camera_params + W, H, f, cx, cy = camera_params + distortion = [0, 0, 0, 0] + intrinsics = np.float32([(f, 0, cx), + (0, f, cy), + (0, 0, 1)]) + + cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id) + + # Load RGB image + rgb_image = PIL.Image.open(os.path.join(imgpath, 'sensors/records_data', imgname)).convert('RGB') + rgb_image.load() + + W, H = rgb_image.size + resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) + + rgb_tensor = resize_func(ImgNorm(rgb_image)) + + view = { + 'intrinsics': intrinsics, + 'distortion': distortion, + 
'cam_to_world': cam_to_world, + 'rgb': rgb_image, + 'rgb_rescaled': rgb_tensor, + 'to_orig': to_orig, + 'idx': idx, + 'image_name': imgname + } + + # Load depthmap + if should_load_depth: + depthmap_filename = os.path.join(imgpath, 'sensors/records_data', + imgname.replace('color.png', 'depth.reg')) + depthmap = depth_map_from_file(depthmap_filename, (int(W), int(H))).astype(np.float32) + pts3d_full, pts3d_valid = depthmap_to_absolute_camera_coordinates(depthmap, intrinsics, cam_to_world) + + pts3d = pts3d_full[pts3d_valid] + pts2d_int = xy_grid(W, H)[pts3d_valid] + pts2d = pts2d_int.astype(np.float64) + + # nan => invalid + pts3d_full[~pts3d_valid] = np.nan + pts3d_full = torch.from_numpy(pts3d_full) + view['pts3d'] = pts3d_full + view["valid"] = pts3d_full.sum(dim=-1).isfinite() + + HR, WR = rgb_tensor.shape[1:] + _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(pts2d, pts3d, to_resize, HR, WR) + pts3d_rescaled = torch.from_numpy(pts3d_rescaled) + valid_rescaled = torch.from_numpy(valid_rescaled) + view['pts3d_rescaled'] = pts3d_rescaled + view["valid_rescaled"] = valid_rescaled + views.append(view) + return views diff --git a/imcui/third_party/dust3r/dust3r_visloc/datasets/utils.py b/imcui/third_party/dust3r/dust3r_visloc/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6053ae2e5ba6c0b0f5f014161b666623d6e0f3f5 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/datasets/utils.py @@ -0,0 +1,118 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# dataset utilities +# -------------------------------------------------------- +import numpy as np +import quaternion +import torchvision.transforms as tvf +from dust3r.utils.geometry import geotrf + + +def cam_to_world_from_kapture(kdata, timestamp, camera_id): + camera_to_world = kdata.trajectories[timestamp, camera_id].inverse() + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = quaternion.as_rotation_matrix(camera_to_world.r) + camera_pose[:3, 3] = camera_to_world.t_raw + return camera_pose + + +ratios_resolutions = { + 224: {1.0: [224, 224]}, + 512: {4 / 3: [512, 384], 32 / 21: [512, 336], 16 / 9: [512, 288], 2 / 1: [512, 256], 16 / 5: [512, 160]} +} + + +def get_HW_resolution(H, W, maxdim, patchsize=16): + assert maxdim in ratios_resolutions, "Error, maxdim can only be 224 or 512 for now. Other maxdims not implemented yet." + ratios_resolutions_maxdim = ratios_resolutions[maxdim] + mindims = set([min(res) for res in ratios_resolutions_maxdim.values()]) + ratio = W / H + ref_ratios = np.array([*(ratios_resolutions_maxdim.keys())]) + islandscape = (W >= H) + if islandscape: + diff = np.abs(ratio - ref_ratios) + else: + diff = np.abs(ratio - (1 / ref_ratios)) + selkey = ref_ratios[np.argmin(diff)] + res = ratios_resolutions_maxdim[selkey] + # check patchsize and make sure output resolution is a multiple of patchsize + if isinstance(patchsize, tuple): + assert len(patchsize) == 2 and isinstance(patchsize[0], int) and isinstance( + patchsize[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints." 
+ assert patchsize[0] == patchsize[1], "Error, non square patches not managed" + patchsize = patchsize[0] + assert max(res) == maxdim + assert min(res) in mindims + return res[::-1] if islandscape else res # return HW + + +def get_resize_function(maxdim, patch_size, H, W, is_mask=False): + if [max(H, W), min(H, W)] in ratios_resolutions[maxdim].values(): + return lambda x: x, np.eye(3), np.eye(3) + else: + target_HW = get_HW_resolution(H, W, maxdim=maxdim, patchsize=patch_size) + + ratio = W / H + target_ratio = target_HW[1] / target_HW[0] + to_orig_crop = np.eye(3) + to_rescaled_crop = np.eye(3) + if abs(ratio - target_ratio) < np.finfo(np.float32).eps: + crop_W = W + crop_H = H + elif ratio - target_ratio < 0: + crop_W = W + crop_H = int(W / target_ratio) + to_orig_crop[1, 2] = (H - crop_H) / 2.0 + to_rescaled_crop[1, 2] = -(H - crop_H) / 2.0 + else: + crop_W = int(H * target_ratio) + crop_H = H + to_orig_crop[0, 2] = (W - crop_W) / 2.0 + to_rescaled_crop[0, 2] = - (W - crop_W) / 2.0 + + crop_op = tvf.CenterCrop([crop_H, crop_W]) + + if is_mask: + resize_op = tvf.Resize(size=target_HW, interpolation=tvf.InterpolationMode.NEAREST_EXACT) + else: + resize_op = tvf.Resize(size=target_HW) + to_orig_resize = np.array([[crop_W / target_HW[1], 0, 0], + [0, crop_H / target_HW[0], 0], + [0, 0, 1]]) + to_rescaled_resize = np.array([[target_HW[1] / crop_W, 0, 0], + [0, target_HW[0] / crop_H, 0], + [0, 0, 1]]) + + op = tvf.Compose([crop_op, resize_op]) + + return op, to_rescaled_resize @ to_rescaled_crop, to_orig_crop @ to_orig_resize + + +def rescale_points3d(pts2d, pts3d, to_resize, HR, WR): + # rescale pts2d as floats + # to colmap, so that the image is in [0, D] -> [0, NewD] + pts2d = pts2d.copy() + pts2d[:, 0] += 0.5 + pts2d[:, 1] += 0.5 + + pts2d_rescaled = geotrf(to_resize, pts2d, norm=True) + + pts2d_rescaled_int = pts2d_rescaled.copy() + # convert back to cv2 before round [-0.5, 0.5] -> pixel 0 + pts2d_rescaled_int[:, 0] -= 0.5 + pts2d_rescaled_int[:, 1] -= 0.5 + pts2d_rescaled_int = pts2d_rescaled_int.round().astype(np.int64) + + # update valid (remove cropped regions) + valid_rescaled = (pts2d_rescaled_int[:, 0] >= 0) & (pts2d_rescaled_int[:, 0] < WR) & ( + pts2d_rescaled_int[:, 1] >= 0) & (pts2d_rescaled_int[:, 1] < HR) + + pts2d_rescaled_int = pts2d_rescaled_int[valid_rescaled] + + # rebuild pts3d from rescaled ps2d poses + pts3d_rescaled = np.full((HR, WR, 3), np.nan, dtype=np.float32) # pts3d in 512 x something + pts3d_rescaled[pts2d_rescaled_int[:, 1], pts2d_rescaled_int[:, 0]] = pts3d[valid_rescaled] + + return pts2d_rescaled, pts2d_rescaled_int, pts3d_rescaled, np.isfinite(pts3d_rescaled.sum(axis=-1)) diff --git a/imcui/third_party/dust3r/dust3r_visloc/evaluation.py b/imcui/third_party/dust3r/dust3r_visloc/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..027179f2b1007db558f57d3d67f48a6d7aa1ab9d --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/evaluation.py @@ -0,0 +1,65 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
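# A small check of the resizing helpers defined above in dust3r_visloc/datasets/utils.py,
# using a common 640x480 landscape input as the example.
from dust3r_visloc.datasets.utils import get_HW_resolution, get_resize_function

H, W = 480, 640                      # aspect ratio 4/3
print(get_HW_resolution(H, W, maxdim=512, patchsize=16))   # -> [384, 512], i.e. (H, W)

resize_func, to_resize, to_orig = get_resize_function(maxdim=512, patch_size=16, H=H, W=W)
# `resize_func` center-crops and resizes a (C, H, W) tensor to the target resolution, while
# `to_resize` / `to_orig` are 3x3 matrices mapping 2D points between the two resolutions.
print(to_resize.shape, to_orig.shape)                      # (3, 3) (3, 3)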
+# +# -------------------------------------------------------- +# evaluation utilities +# -------------------------------------------------------- +import numpy as np +import quaternion +import torch +import roma +import collections +import os + + +def aggregate_stats(info_str, pose_errors, angular_errors): + stats = collections.Counter() + median_pos_error = np.median(pose_errors) + median_angular_error = np.median(angular_errors) + out_str = f'{info_str}: {len(pose_errors)} images - {median_pos_error=}, {median_angular_error=}' + + for trl_thr, ang_thr in [(0.1, 1), (0.25, 2), (0.5, 5), (5, 10)]: + for pose_error, angular_error in zip(pose_errors, angular_errors): + correct_for_this_threshold = (pose_error < trl_thr) and (angular_error < ang_thr) + stats[trl_thr, ang_thr] += correct_for_this_threshold + stats = {f'acc@{key[0]:g}m,{key[1]}deg': 100 * val / len(pose_errors) for key, val in stats.items()} + for metric, perf in stats.items(): + out_str += f' - {metric:12s}={float(perf):.3f}' + return out_str + + +def get_pose_error(pr_camtoworld, gt_cam_to_world): + abs_transl_error = torch.linalg.norm(torch.tensor(pr_camtoworld[:3, 3]) - torch.tensor(gt_cam_to_world[:3, 3])) + abs_angular_error = roma.rotmat_geodesic_distance(torch.tensor(pr_camtoworld[:3, :3]), + torch.tensor(gt_cam_to_world[:3, :3])) * 180 / np.pi + return abs_transl_error, abs_angular_error + + +def export_results(output_dir, xp_label, query_names, poses_pred): + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) + + lines = "" + lines_ltvl = "" + for query_name, pr_querycam_to_world in zip(query_names, poses_pred): + if pr_querycam_to_world is None: + pr_world_to_querycam = np.eye(4) + else: + pr_world_to_querycam = np.linalg.inv(pr_querycam_to_world) + query_shortname = os.path.basename(query_name) + pr_world_to_querycam_q = quaternion.from_rotation_matrix(pr_world_to_querycam[:3, :3]) + pr_world_to_querycam_t = pr_world_to_querycam[:3, 3] + + line_pose = quaternion.as_float_array(pr_world_to_querycam_q).tolist() + \ + pr_world_to_querycam_t.flatten().tolist() + + line_content = [query_name] + line_pose + lines += ' '.join(str(v) for v in line_content) + '\n' + + line_content_ltvl = [query_shortname] + line_pose + lines_ltvl += ' '.join(str(v) for v in line_content_ltvl) + '\n' + + with open(os.path.join(output_dir, xp_label + '_results.txt'), 'wt') as f: + f.write(lines) + with open(os.path.join(output_dir, xp_label + '_ltvl.txt'), 'wt') as f: + f.write(lines_ltvl) diff --git a/imcui/third_party/dust3r/dust3r_visloc/localization.py b/imcui/third_party/dust3r/dust3r_visloc/localization.py new file mode 100644 index 0000000000000000000000000000000000000000..ac8ae198dc3479f12a976bab0bda692328880710 --- /dev/null +++ b/imcui/third_party/dust3r/dust3r_visloc/localization.py @@ -0,0 +1,140 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
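# A minimal sanity check of `get_pose_error` defined above; the two poses are synthetic
# cam-to-world matrices (identity vs. the same rotation with a 30 cm translation offset).
import numpy as np
from dust3r_visloc.evaluation import get_pose_error

gt_cam_to_world = np.eye(4)
pr_cam_to_world = np.eye(4)
pr_cam_to_world[:3, 3] = [0.3, 0.0, 0.0]

transl_err, ang_err = get_pose_error(pr_cam_to_world, gt_cam_to_world)
print(float(transl_err), float(ang_err))   # ~0.3 (meters) and ~0.0 (degrees)
# aggregate_stats() then reports median errors plus accuracy at the
# (0.1 m, 1 deg), (0.25 m, 2 deg), (0.5 m, 5 deg) and (5 m, 10 deg) thresholds.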
+# +# -------------------------------------------------------- +# main pnp code +# -------------------------------------------------------- +import numpy as np +import quaternion +import cv2 +from packaging import version + +from dust3r.utils.geometry import opencv_to_colmap_intrinsics + +try: + import poselib # noqa + HAS_POSELIB = True +except Exception as e: + HAS_POSELIB = False + +try: + import pycolmap # noqa + version_number = pycolmap.__version__ + if version.parse(version_number) < version.parse("0.5.0"): + HAS_PYCOLMAP = False + else: + HAS_PYCOLMAP = True +except Exception as e: + HAS_PYCOLMAP = False + +def run_pnp(pts2D, pts3D, K, distortion = None, mode='cv2', reprojectionError=5, img_size = None): + """ + use OPENCV model for distortion (4 values) + """ + assert mode in ['cv2', 'poselib', 'pycolmap'] + try: + if len(pts2D) > 4 and mode == "cv2": + confidence = 0.9999 + iterationsCount = 10_000 + if distortion is not None: + cv2_pts2ds = np.copy(pts2D) + cv2_pts2ds = cv2.undistortPoints(cv2_pts2ds, K, np.array(distortion), R=None, P=K) + pts2D = cv2_pts2ds.reshape((-1, 2)) + + success, r_pose, t_pose, _ = cv2.solvePnPRansac(pts3D, pts2D, K, None, flags=cv2.SOLVEPNP_SQPNP, + iterationsCount=iterationsCount, + reprojectionError=reprojectionError, + confidence=confidence) + if not success: + return False, None + r_pose = cv2.Rodrigues(r_pose)[0] # world2cam == world2cam2 + RT = np.r_[np.c_[r_pose, t_pose], [(0,0,0,1)]] # world2cam2 + return True, np.linalg.inv(RT) # cam2toworld + elif len(pts2D) > 4 and mode == "poselib": + assert HAS_POSELIB + confidence = 0.9999 + iterationsCount = 10_000 + # NOTE: `Camera` struct currently contains `width`/`height` fields, + # however these are not used anywhere in the code-base and are provided simply to be consistent with COLMAP. 
+ # so we put garbage in there + colmap_intrinsics = opencv_to_colmap_intrinsics(K) + fx = colmap_intrinsics[0, 0] + fy = colmap_intrinsics[1, 1] + cx = colmap_intrinsics[0, 2] + cy = colmap_intrinsics[1, 2] + width = img_size[0] if img_size is not None else int(cx*2) + height = img_size[1] if img_size is not None else int(cy*2) + + if distortion is None: + camera = {'model': 'PINHOLE', 'width': width, 'height': height, 'params': [fx, fy, cx, cy]} + else: + camera = {'model': 'OPENCV', 'width': width, 'height': height, + 'params': [fx, fy, cx, cy] + distortion} + + pts2D = np.copy(pts2D) + pts2D[:, 0] += 0.5 + pts2D[:, 1] += 0.5 + pose, _ = poselib.estimate_absolute_pose(pts2D, pts3D, camera, + {'max_reproj_error': reprojectionError, + 'max_iterations': iterationsCount, + 'success_prob': confidence}, {}) + if pose is None: + return False, None + RT = pose.Rt # (3x4) + RT = np.r_[RT, [(0,0,0,1)]] # world2cam + return True, np.linalg.inv(RT) # cam2toworld + elif len(pts2D) > 4 and mode == "pycolmap": + assert HAS_PYCOLMAP + assert img_size is not None + + pts2D = np.copy(pts2D) + pts2D[:, 0] += 0.5 + pts2D[:, 1] += 0.5 + colmap_intrinsics = opencv_to_colmap_intrinsics(K) + fx = colmap_intrinsics[0, 0] + fy = colmap_intrinsics[1, 1] + cx = colmap_intrinsics[0, 2] + cy = colmap_intrinsics[1, 2] + width = img_size[0] + height = img_size[1] + if distortion is None: + camera_dict = {'model': 'PINHOLE', 'width': width, 'height': height, 'params': [fx, fy, cx, cy]} + else: + camera_dict = {'model': 'OPENCV', 'width': width, 'height': height, + 'params': [fx, fy, cx, cy] + distortion} + + pycolmap_camera = pycolmap.Camera( + model=camera_dict['model'], width=camera_dict['width'], height=camera_dict['height'], + params=camera_dict['params']) + + pycolmap_estimation_options = dict(ransac=dict(max_error=reprojectionError, min_inlier_ratio=0.01, + min_num_trials=1000, max_num_trials=100000, + confidence=0.9999)) + pycolmap_refinement_options=dict(refine_focal_length=False, refine_extra_params=False) + ret = pycolmap.absolute_pose_estimation(pts2D, pts3D, pycolmap_camera, + estimation_options=pycolmap_estimation_options, + refinement_options=pycolmap_refinement_options) + if ret is None: + ret = {'success': False} + else: + ret['success'] = True + if callable(ret['cam_from_world'].matrix): + retmat = ret['cam_from_world'].matrix() + else: + retmat = ret['cam_from_world'].matrix + ret['qvec'] = quaternion.from_rotation_matrix(retmat[:3, :3]) + ret['tvec'] = retmat[:3, 3] + + if not (ret['success'] and ret['num_inliers'] > 0): + success = False + pose = None + else: + success = True + pr_world_to_querycam = np.r_[ret['cam_from_world'].matrix(), [(0,0,0,1)]] + pose = np.linalg.inv(pr_world_to_querycam) + return success, pose + else: + return False, None + except Exception as e: + print(f'error during pnp: {e}') + return False, None \ No newline at end of file diff --git a/imcui/third_party/dust3r/train.py b/imcui/third_party/dust3r/train.py new file mode 100644 index 0000000000000000000000000000000000000000..503e63572376c259e6b259850e19c3f6036aa535 --- /dev/null +++ b/imcui/third_party/dust3r/train.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
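# A self-contained sketch of `run_pnp` above in its default 'cv2' mode, using synthetic,
# noise-free 2D-3D correspondences so the recovered cam-to-world pose should be close to
# the identity; intrinsics and point ranges below are made up for illustration.
import numpy as np
from dust3r_visloc.localization import run_pnp

rng = np.random.default_rng(0)
K = np.array([[500.0, 0.0, 320.0],
              [0.0, 500.0, 240.0],
              [0.0, 0.0, 1.0]], dtype=np.float32)

pts3d = rng.uniform([-1, -1, 4], [1, 1, 8], size=(20, 3)).astype(np.float32)  # points in front of the camera
proj = (K @ pts3d.T).T                              # project with an identity camera pose
pts2d = (proj[:, :2] / proj[:, 2:]).astype(np.float32)

success, cam_to_world = run_pnp(pts2d, pts3d, K, distortion=None, mode='cv2')
print(success)
print(np.round(cam_to_world, 3))                    # approximately the 4x4 identity here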
+# +# -------------------------------------------------------- +# training executable for DUSt3R +# -------------------------------------------------------- +from dust3r.training import get_args_parser, train + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + train(args) diff --git a/imcui/third_party/dust3r/visloc.py b/imcui/third_party/dust3r/visloc.py new file mode 100644 index 0000000000000000000000000000000000000000..6411b3eaf96dea961f9524e887a12d92f2012c6b --- /dev/null +++ b/imcui/third_party/dust3r/visloc.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Simple visloc script +# -------------------------------------------------------- +import numpy as np +import random +import argparse +from tqdm import tqdm +import math + +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf + +from dust3r_visloc.datasets import * +from dust3r_visloc.localization import run_pnp +from dust3r_visloc.evaluation import get_pose_error, aggregate_stats, export_results + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True, help="visloc dataset to eval") + parser_weights = parser.add_mutually_exclusive_group(required=True) + parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None) + parser_weights.add_argument("--model_name", type=str, help="name of the model weights", + choices=["DUSt3R_ViTLarge_BaseDecoder_512_dpt", + "DUSt3R_ViTLarge_BaseDecoder_512_linear", + "DUSt3R_ViTLarge_BaseDecoder_224_linear"]) + parser.add_argument("--confidence_threshold", type=float, default=3.0, + help="confidence values higher than threshold are invalid") + parser.add_argument("--device", type=str, default='cuda', help="pytorch device") + parser.add_argument("--pnp_mode", type=str, default="cv2", choices=['cv2', 'poselib', 'pycolmap'], + help="pnp lib to use") + parser_reproj = parser.add_mutually_exclusive_group() + parser_reproj.add_argument("--reprojection_error", type=float, default=5.0, help="pnp reprojection error") + parser_reproj.add_argument("--reprojection_error_diag_ratio", type=float, default=None, + help="pnp reprojection error as a ratio of the diagonal of the image") + + parser.add_argument("--pnp_max_points", type=int, default=100_000, help="pnp maximum number of points kept") + parser.add_argument("--viz_matches", type=int, default=0, help="debug matches") + + parser.add_argument("--output_dir", type=str, default=None, help="output path") + parser.add_argument("--output_label", type=str, default='', help="prefix for results files") + return parser + + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + conf_thr = args.confidence_threshold + device = args.device + pnp_mode = args.pnp_mode + reprojection_error = args.reprojection_error + reprojection_error_diag_ratio = args.reprojection_error_diag_ratio + pnp_max_points = args.pnp_max_points + viz_matches = args.viz_matches + + if args.weights is not None: + weights_path = args.weights + else: + weights_path = "naver/" + args.model_name + model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) + + dataset = eval(args.dataset) + dataset.set_resolution(model) + 
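# A hypothetical invocation of this evaluation script, matching the argparse options defined
# above. Note that --dataset is passed through eval(), so it must be a valid constructor
# expression for one of the dust3r_visloc datasets (paths and pairsfile name are placeholders):
#
#   python visloc.py \
#       --dataset "VislocSevenScenes('/data/7scenes', 'chess', 'pairsfile', topk=1)" \
#       --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt \
#       --pnp_mode cv2 --reprojection_error 5.0 \
#       --output_dir ./results --output_label 7scenes_chess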
+ query_names = [] + poses_pred = [] + pose_errors = [] + angular_errors = [] + for idx in tqdm(range(len(dataset))): + views = dataset[(idx)] # 0 is the query + query_view = views[0] + map_views = views[1:] + query_names.append(query_view['image_name']) + + query_pts2d = [] + query_pts3d = [] + for map_view in map_views: + # prepare batch + imgs = [] + for idx, img in enumerate([query_view['rgb_rescaled'], map_view['rgb_rescaled']]): + imgs.append(dict(img=img.unsqueeze(0), true_shape=np.int32([img.shape[1:]]), + idx=idx, instance=str(idx))) + output = inference([tuple(imgs)], model, device, batch_size=1, verbose=False) + pred1, pred2 = output['pred1'], output['pred2'] + confidence_masks = [pred1['conf'].squeeze(0) >= conf_thr, + (pred2['conf'].squeeze(0) >= conf_thr) & map_view['valid_rescaled']] + pts3d = [pred1['pts3d'].squeeze(0), pred2['pts3d_in_other_view'].squeeze(0)] + + # find 2D-2D matches between the two images + pts2d_list, pts3d_list = [], [] + for i in range(2): + conf_i = confidence_masks[i].cpu().numpy() + true_shape_i = imgs[i]['true_shape'][0] + pts2d_list.append(xy_grid(true_shape_i[1], true_shape_i[0])[conf_i]) + pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i]) + + PQ, PM = pts3d_list[0], pts3d_list[1] + if len(PQ) == 0 or len(PM) == 0: + continue + reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(PQ, PM) + if viz_matches > 0: + print(f'found {num_matches} matches') + matches_im1 = pts2d_list[1][reciprocal_in_PM] + matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM] + valid_pts3d = map_view['pts3d_rescaled'][matches_im1[:, 1], matches_im1[:, 0]] + + # from cv2 to colmap + matches_im0 = matches_im0.astype(np.float64) + matches_im1 = matches_im1.astype(np.float64) + matches_im0[:, 0] += 0.5 + matches_im0[:, 1] += 0.5 + matches_im1[:, 0] += 0.5 + matches_im1[:, 1] += 0.5 + # rescale coordinates + matches_im0 = geotrf(query_view['to_orig'], matches_im0, norm=True) + matches_im1 = geotrf(query_view['to_orig'], matches_im1, norm=True) + # from colmap back to cv2 + matches_im0[:, 0] -= 0.5 + matches_im0[:, 1] -= 0.5 + matches_im1[:, 0] -= 0.5 + matches_im1[:, 1] -= 0.5 + + # visualize a few matches + if viz_matches > 0: + viz_imgs = [np.array(query_view['rgb']), np.array(map_view['rgb'])] + from matplotlib import pyplot as pl + n_viz = viz_matches + match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int) + viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz] + + H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2] + img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) + img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) + img = np.concatenate((img0, img1), axis=1) + pl.figure() + pl.imshow(img) + cmap = pl.get_cmap('jet') + for i in range(n_viz): + (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T + pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False) + pl.show(block=True) + + if len(valid_pts3d) == 0: + pass + else: + query_pts3d.append(valid_pts3d.cpu().numpy()) + query_pts2d.append(matches_im0) + + if len(query_pts2d) == 0: + success = False + pr_querycam_to_world = None + else: + query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32) + query_pts3d = np.concatenate(query_pts3d, axis=0) + if len(query_pts2d) > pnp_max_points: + idxs = random.sample(range(len(query_pts2d)), pnp_max_points) + query_pts3d = 
query_pts3d[idxs] + query_pts2d = query_pts2d[idxs] + + W, H = query_view['rgb'].size + if reprojection_error_diag_ratio is not None: + reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2) + else: + reprojection_error_img = reprojection_error + success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d, + query_view['intrinsics'], query_view['distortion'], + pnp_mode, reprojection_error_img, img_size=[W, H]) + + if not success: + abs_transl_error = float('inf') + abs_angular_error = float('inf') + else: + abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world']) + + pose_errors.append(abs_transl_error) + angular_errors.append(abs_angular_error) + poses_pred.append(pr_querycam_to_world) + + xp_label = f'tol_conf_{conf_thr}' + if args.output_label: + xp_label = args.output_label + '_' + xp_label + if reprojection_error_diag_ratio is not None: + xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}' + else: + xp_label = xp_label + f'_reproj_err_{reprojection_error}' + export_results(args.output_dir, xp_label, query_names, poses_pred) + out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors) + print(out_string) diff --git a/imcui/third_party/gim/analysis.py b/imcui/third_party/gim/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2ece6c2056c167d8f9d5ada9a68d03ee1a9f97 --- /dev/null +++ b/imcui/third_party/gim/analysis.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import os +import argparse + +import numpy as np + +from os.path import join +from datetime import datetime + +angular_thresholds = ['5.0°'] +dist_thresholds = ['0.1m'] +intt = lambda x: list(map(int, x)) +floatt = lambda x: list(map(float, x)) +strr = lambda x: list(map(lambda x:f'{x:.18f}', x)) + +datasets = [ + 'GL3D', + 'BlendedMVS', + 'ETH3DI', + 'ETH3DO', + 'KITTI', + 'RobotcarWeather', + 'RobotcarSeason', + 'RobotcarNight', + 'Multi-FoV', + 'SceneNetRGBD', + 'ICL-NUIM', + 'GTA-SfM', +] + + +def error_auc(errs0, errs1, thres, metric): + if isinstance(errs0, list): errs0 = np.array(errs0) + if isinstance(errs1, list): errs1 = np.array(errs1) + if any(np.isnan(errs0)): errs0[np.isnan(errs0)] = 180 + if any(np.isnan(errs1)): errs1[np.isnan(errs1)] = 180 + if any(np.isinf(errs0)): errs0[np.isinf(errs0)] = 180 + if any(np.isinf(errs1)): errs1[np.isinf(errs1)] = 180 + errors = np.max(np.stack([errs0, errs1]), axis=0) + errors = [0] + sorted(list(errors)) + recall = list(np.linspace(0, 1, len(errors))) + + aucs = [] + for thr in thres: + thr = float(thr[:-1]) + last_index = np.searchsorted(errors, thr) + y = recall[:last_index] + [recall[last_index-1]] + x = errors[:last_index] + [thr] + aucs.append(np.trapz(y, x) / thr) + + return {f'{metric}@ {t}': auc for t, auc in zip(thres, aucs)} + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--dir', type=str, default='.') + parser.add_argument('--wid', type=str, required=True) + parser.add_argument('--version', type=str, default=None) + parser.add_argument('--verbose', action='store_true') + parser.add_argument('--log', action='store_true') + parser.add_argument('--sceids', type=str, choices=datasets, nargs='+', + default=None, help=f'Test Datasets: {datasets}', ) + opt = parser.parse_args() + + dir = opt.dir + wid = opt.wid + version = opt.version + + _data = \ + { + x.rpartition('.txt')[0].split()[2]:x for x in + [ + d for d in os.listdir(dir) if not 
os.path.isdir(os.path.join(dir, d)) + ] if wid == x.rpartition('.txt')[0].split()[1] and version is not None and version == x.rpartition('.txt')[0].split()[-1] + } + _data = {k:_data[k] for k in datasets if k in _data.keys()} + + sceids = opt.sceids + sceids = sceids if sceids is not None else _data.keys() + results = {} + for sceid in sceids: + results[sceid] = {} + if not opt.verbose: print('{:^13} {}'.format(sceid, wid)) + + # read txt + with open(join(dir, _data[sceid]), 'r') as f: + data = f.readlines() + head = data[0].split() + content = [x.split() for x in data[1:]] + details = {k: [] for k in head[3:]} + + stacks = [] + for x in content: + ids = x[0] + if ids in stacks: continue + + for k, v in zip(head[3:], x[3:]): details[k].append(v) + stacks.append(ids) + + mAP = error_auc(floatt(details['R_errs']), floatt(details['t_errs']), angular_thresholds, 'auc') + for k, v in mAP.items(): results[sceid][k] = v + + # print head + output = '' + + num = 56+25*len(sceids) + output += '='*num + output += "\n" + + output += '{:<25}'.format(datetime.now().strftime("%Y-%m-%d, %H:%M:%S")) + output += '{:<15} '.format('Model') + output += '{:<14} '.format('Metric') + for sceid in sceids: output += '{:<25} '.format(sceid) + output += "\n" + + output += '-'*num + output += "\n" + + for k in list(results.values())[0].keys(): + output += '{:<25}'.format(datetime.now().strftime("%Y-%m-%d, %H:%M:%S")) if opt.log else '{:<25}'.format(' ') + output += '{:<15} '.format(wid) + output += '{:<14} '.format(k) + + for sceid in sceids: + output += '{:<25} '.format(results[sceid][k]) + output += "\n" + + output += '='*num + output += "\n" + output += "\n" + + if opt.verbose: + print(output) + + if opt.log: + path = 'ANALYSIS RESULTS.txt' + with open(path, 'a') as file: + file.write(output) diff --git a/imcui/third_party/gim/check.py b/imcui/third_party/gim/check.py new file mode 100644 index 0000000000000000000000000000000000000000..0289b35df45338c25717f5c24452ba2bc4104b5d --- /dev/null +++ b/imcui/third_party/gim/check.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import csv +from os import listdir +from os.path import join + +home = join('dump', 'zeb') + +# specified_key2 = "GL3D" +specified_keys = [ + 'GL3D', 'KITTI', 'ETH3DI', 'ETH3DO', 'GTASfM', 'ICLNUIM', 'MultiFoV', + 'SceneNet', 'BlendedMVS', 'RobotcarNight', 'RobotcarSeason', 'RobotcarWeather' +] + +for specified_key2 in specified_keys: + identifiers_dict = {} + + for filename in listdir(home): + if filename.endswith(".txt") and ']' in filename: + parts = filename[:-4].split() + if parts[2] == specified_key2: + with open(join(home, filename), 'r') as f: + reader = csv.reader(f, delimiter=' ') + file_identifiers = [row[0] for row in reader if row] + identifiers_dict[filename] = file_identifiers + + all_identical = True + reference_identifiers = None + if identifiers_dict: + reference_identifiers = list(identifiers_dict.values())[0] + for identifiers in identifiers_dict.values(): + if identifiers != reference_identifiers: + all_identical = False + break + + if all_identical: + print("Good ! all {} file identifiers is same".format(specified_key2)) + else: + print("Bad ! 
file {} have different identifiers".format(specified_key2)) + + if not all_identical: + for filename, identifiers in identifiers_dict.items(): + if identifiers != reference_identifiers: + print(f"File {filename} 's {specified_key2} identifiers is different with others") diff --git a/imcui/third_party/gim/datasets/augment.py b/imcui/third_party/gim/datasets/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf228e3aebe0e2ea238e6a0cf951472cf498616 --- /dev/null +++ b/imcui/third_party/gim/datasets/augment.py @@ -0,0 +1,53 @@ +import albumentations as A + + +class DarkAug(object): + """ + Extreme dark augmentation aiming at Aachen Day-Night + """ + + def __init__(self) -> None: + self.augmentor = A.Compose([ + A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), + A.Blur(p=0.1, blur_limit=(3, 9)), + A.MotionBlur(p=0.2, blur_limit=(3, 25)), + A.RandomGamma(p=0.1, gamma_limit=(15, 65)), + A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) + ], p=0.75) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + + +class MobileAug(object): + """ + Random augmentations aiming at images of mobile/handhold devices. + """ + + def __init__(self): + self.augmentor = A.Compose([ + A.MotionBlur(p=0.25), + A.ColorJitter(p=0.5), + A.RandomRain(p=0.1), # random occlusion + A.RandomSunFlare(p=0.1), + A.JpegCompression(p=0.25), + A.ISONoise(p=0.25) + ], p=1.0) + + def __call__(self, x): + return self.augmentor(image=x)['image'] + + +def build_augmentor(method=None, **kwargs): + if method == 'dark': + return DarkAug() + elif method == 'mobile': + return MobileAug() + elif method is None: + return None + else: + raise ValueError(f'Invalid augmentation method: {method}') + + +if __name__ == '__main__': + augmentor = build_augmentor('FDA') diff --git a/imcui/third_party/gim/datasets/blendedmvs/__init__.py b/imcui/third_party/gim/datasets/blendedmvs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e018cf1cb63a866808cdf420fb4f360020517f4 --- /dev/null +++ b/imcui/third_party/gim/datasets/blendedmvs/__init__.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +from os.path import join +from yacs.config import CfgNode as CN + +########################################## +#++++++++++++++++++++++++++++++++++++++++# +#+ +# +#+ BlendedMVS +# +#+ +# +#++++++++++++++++++++++++++++++++++++++++# +########################################## + +_CN = CN() + +_CN.DATASET = CN() + +DATA_ROOT = 'data/GL3D/' +NPZ_ROOT = DATA_ROOT + +_CN.NJOBS = 8 + +# TRAIN +_CN.DATASET.TRAIN = CN() +_CN.DATASET.TRAIN.PADDING = None +_CN.DATASET.TRAIN.DATA_ROOT = None +_CN.DATASET.TRAIN.NPZ_ROOT = None +_CN.DATASET.TRAIN.MAX_SAMPLES = None +_CN.DATASET.TRAIN.MIN_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.MAX_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.AUGMENTATION_TYPE = None +_CN.DATASET.TRAIN.LIST_PATH = None + +# VALID +_CN.DATASET.VALID = CN() +_CN.DATASET.VALID.PADDING = None +_CN.DATASET.VALID.DATA_ROOT = None +_CN.DATASET.VALID.NPZ_ROOT = None +_CN.DATASET.VALID.MAX_SAMPLES = None +_CN.DATASET.VALID.MIN_OVERLAP_SCORE = None +_CN.DATASET.VALID.MAX_OVERLAP_SCORE = None +_CN.DATASET.VALID.AUGMENTATION_TYPE = None +_CN.DATASET.VALID.LIST_PATH = None + +# TESTS +_CN.DATASET.TESTS = CN() +_CN.DATASET.TESTS.PADDING = False +_CN.DATASET.TESTS.DATA_ROOT = DATA_ROOT +_CN.DATASET.TESTS.NPZ_ROOT = NPZ_ROOT +_CN.DATASET.TESTS.MAX_SAMPLES = 64 +_CN.DATASET.TESTS.MIN_OVERLAP_SCORE = 0.0 +_CN.DATASET.TESTS.MAX_OVERLAP_SCORE = 0.5 
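The TESTS node being filled in here only declares the BlendedMVS zero-shot evaluation options; `MultiSceneDataModule._setup_dataset` in `datasets/data.py`, added later in this patch, is what actually reads them and forwards them to the dataset class. A minimal sketch of that hand-off, assuming the `cfg` object this `__init__.py` exports; `tests_kwargs` is a hypothetical helper and not part of the repository:

```python
# Illustrative only: how the DATASET.TESTS options declared above are consumed.
# `tests_kwargs` is a hypothetical helper, not part of this patch.
from yacs.config import CfgNode as CN


def tests_kwargs(cfg: CN) -> dict:
    """Flatten cfg.DATASET.TESTS into the keyword arguments that
    MultiSceneDataModule._setup_dataset forwards to each Dataset class."""
    t = cfg.DATASET.TESTS
    return dict(
        data_root=t.DATA_ROOT,
        npz_root=t.NPZ_ROOT,
        scene_list_path=t.LIST_PATH,
        padding=t.PADDING,
        min_overlap_score=t.MIN_OVERLAP_SCORE,
        max_overlap_score=t.MAX_OVERLAP_SCORE,
        max_samples=t.MAX_SAMPLES,
    )
```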
+_CN.DATASET.TESTS.AUGMENTATION_TYPE = None +_CN.DATASET.TESTS.LIST_PATH = 'datasets/_tests_/BlendedMVS.txt' + +cfg = _CN diff --git a/imcui/third_party/gim/datasets/data.py b/imcui/third_party/gim/datasets/data.py new file mode 100644 index 0000000000000000000000000000000000000000..537d2133b5be676785200ae0b82fdf4982d3f055 --- /dev/null +++ b/imcui/third_party/gim/datasets/data.py @@ -0,0 +1,216 @@ +import os +import torch +import pytorch_lightning as pl +from tqdm import tqdm +from joblib import Parallel, delayed +from torch.utils.data.dataset import Dataset +from torch.utils.data import DataLoader, ConcatDataset +from datasets.augment import build_augmentor +from tools.misc import tqdm_joblib + +from .gl3d.gl3d import GL3DDataset +from .gtasfm.gtasfm import GTASfMDataset +from .multifov.multifov import MultiFoVDataset +from .gl3d.gl3d import GL3DDataset as BlendedMVSDataset +from .iclnuim.iclnuim import ICLNUIMDataset +from .scenenet.scenenet import SceneNetDataset +from .eth3d.eth3d import ETH3DDataset +from .kitti.kitti import KITTIDataset +from .robotcar.robotcar import RobotcarDataset + +Benchmarks = dict( + GL3D = GL3DDataset, + GTASfM = GTASfMDataset, + MultiFoV = MultiFoVDataset, + BlendedMVS = BlendedMVSDataset, + ICLNUIM = ICLNUIMDataset, + SceneNet = SceneNetDataset, + ETH3DO = ETH3DDataset, + ETH3DI = ETH3DDataset, + KITTI = KITTIDataset, + RobotcarNight = RobotcarDataset, + RobotcarSeason = RobotcarDataset, + RobotcarWeather = RobotcarDataset, +) + + +class MultiSceneDataModule(pl.LightningDataModule): + """ + For distributed training, each training process is assgined + only a part of the training scenes to reduce memory overhead. + """ + + def __init__(self, args, dcfg): + """ + + Args: + args: (ArgumentParser) The only useful args is args.trains and args.valids + each one is a list, which contain like [PhotoTourism, MegaDepth,...] + We should traverse each item in args.trains and args.valids to build + self.train_datasets and self.valid_datasets + dcfg: (yacs) It contain all configs for each benchmark in args.trains and + args.valids + """ + super().__init__() + + self.args = args + self.dcfg = dcfg + self.train_loader_params = {'batch_size': args.batch_size, + 'shuffle': True, + 'num_workers': args.threads, + 'pin_memory': True, + 'drop_last': True} + self.valid_loader_params = {'batch_size': args.batch_size, + 'shuffle': False, + 'num_workers': args.threads, + 'pin_memory': True, + 'drop_last': False} + self.tests_loader_params = {'batch_size': args.batch_size, + 'shuffle': False, + 'num_workers': args.threads, + 'pin_memory': True, + 'drop_last': False} + + def setup(self, stage=None): + """ + Setup train/valid/test dataset. This method will be called by PL automatically. + Args: + stage (str): 'fit' in training phase, and 'test' in testing phase. 
+ """ + + self.gpus = self.trainer.gpus + self.gpuid = self.trainer.global_rank + + self.train_datasets = None + self.valid_datasets = None + self.tests_datasets = None + + # TRAIN + if stage == 'fit': + train_datasets = [] + for benchmark in self.args.trains: + dcfg = self.dcfg.get(benchmark, None) + assert dcfg is not None, "Training dcfg is None" + + datasets = self._setup_dataset( + benchmark=benchmark, + data_root=dcfg.DATASET.TRAIN.DATA_ROOT, + npz_root=dcfg.DATASET.TRAIN.NPZ_ROOT, + scene_list_path=dcfg.DATASET.TRAIN.LIST_PATH, + df=self.dcfg.DF, + padding=dcfg.DATASET.TRAIN.PADDING, + min_overlap_score=dcfg.DATASET.TRAIN.MIN_OVERLAP_SCORE, + max_overlap_score=dcfg.DATASET.TRAIN.MAX_OVERLAP_SCORE, + max_resize=self.args.img_size, + augment_fn=build_augmentor(dcfg.DATASET.TRAIN.AUGMENTATION_TYPE), + max_samples=dcfg.DATASET.TRAIN.MAX_SAMPLES, + mode='train', + njobs=dcfg.NJOBS, + cfg=dcfg.DATASET.TRAIN, + ) + train_datasets += datasets + self.train_datasets = ConcatDataset(train_datasets) + os.environ['TOTAL_TRAIN_SAMPLES'] = str(len(self.train_datasets)) + + # VALID + valid_datasets = [] + for benchmark in self.args.valids: + dcfg = self.dcfg.get(benchmark, None) + assert dcfg is not None, "Validing dcfg is None" + + datasets = self._setup_dataset( + benchmark=benchmark, + data_root=dcfg.DATASET.VALID.DATA_ROOT, + npz_root=dcfg.DATASET.VALID.NPZ_ROOT, + scene_list_path=dcfg.DATASET.VALID.LIST_PATH, + df=self.dcfg.DF, + padding=dcfg.DATASET.VALID.PADDING, + min_overlap_score=dcfg.DATASET.VALID.MIN_OVERLAP_SCORE, + max_overlap_score=dcfg.DATASET.VALID.MAX_OVERLAP_SCORE, + max_resize=self.args.img_size, + augment_fn=build_augmentor(dcfg.DATASET.VALID.AUGMENTATION_TYPE), + max_samples=dcfg.DATASET.VALID.MAX_SAMPLES, + mode='valid', + njobs=dcfg.NJOBS, + cfg=dcfg.DATASET.VALID, + ) + valid_datasets += datasets + self.valid_datasets = ConcatDataset(valid_datasets) + os.environ['TOTAL_VALID_SAMPLES'] = str(len(self.valid_datasets)) + + # TEST + if stage == 'test': + tests_datasets = [] + for benchmark in [self.args.tests]: + dcfg = self.dcfg.get(benchmark, None) + assert dcfg is not None, "Validing dcfg is None" + + datasets = self._setup_dataset( + benchmark=benchmark, + data_root=dcfg.DATASET.TESTS.DATA_ROOT, + npz_root=dcfg.DATASET.TESTS.NPZ_ROOT, + scene_list_path=dcfg.DATASET.TESTS.LIST_PATH, + df=self.dcfg.DF, + padding=dcfg.DATASET.TESTS.PADDING, + min_overlap_score=dcfg.DATASET.TESTS.MIN_OVERLAP_SCORE, + max_overlap_score=dcfg.DATASET.TESTS.MAX_OVERLAP_SCORE, + max_resize=self.args.img_size, + augment_fn=build_augmentor(dcfg.DATASET.TESTS.AUGMENTATION_TYPE), + max_samples=dcfg.DATASET.TESTS.MAX_SAMPLES, + mode='test', + njobs=dcfg.NJOBS, + cfg=dcfg.DATASET.TESTS, + ) + tests_datasets += datasets + self.tests_datasets = ConcatDataset(tests_datasets) + os.environ['TOTAL_TESTS_SAMPLES'] = str(len(self.tests_datasets)) + if self.gpuid == 0: print('TOTAL_TESTS_SAMPLES:', len(self.tests_datasets)) + + def _setup_dataset(self, benchmark, data_root, npz_root, scene_list_path, df, padding, + min_overlap_score, max_overlap_score, max_resize, augment_fn, + max_samples, mode, njobs, cfg): + + seq_names = [benchmark.lower()] + + with tqdm_joblib(tqdm(bar_format="{l_bar}{bar:3}{r_bar}", ncols=100, + desc=f'[GPU {self.gpuid}] load {mode} {benchmark:14} data', + total=len(seq_names), disable=int(self.gpuid) != 0)): + datasets = Parallel(n_jobs=njobs)( + delayed(lambda x: _build_dataset( + Benchmarks.get(benchmark), + root_dir=data_root, + npz_root=npz_root, + seq_name=x, + mode=mode, + 
min_overlap_score=min_overlap_score, + max_overlap_score=max_overlap_score, + max_resize=max_resize, + df=df, + padding=padding, + augment_fn=augment_fn, + max_samples=max_samples, + **cfg + ))(seqname) for seqname in seq_names) + return datasets + + def train_dataloader(self, *args, **kwargs): + return DataLoader(self.train_datasets, collate_fn=collate_fn, **self.train_loader_params) + + def valid_dataloader(self, *args, **kwargs): + return DataLoader(self.valid_datasets, collate_fn=collate_fn, **self.valid_loader_params) + + def val_dataloader(self, *args, **kwargs): + return self.valid_dataloader(*args, **kwargs) + + def test_dataloader(self, *args, **kwargs): + return DataLoader(self.tests_datasets, collate_fn=collate_fn, **self.tests_loader_params) + + +def collate_fn(batch): + batch = list(filter(lambda x: x is not None, batch)) + return torch.utils.data.dataloader.default_collate(batch) + + +def _build_dataset(dataset: Dataset, *args, **kwargs): + # noinspection PyCallingNonCallable + return dataset(*args, **kwargs) diff --git a/imcui/third_party/gim/datasets/dataset.py b/imcui/third_party/gim/datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..96bc1e5adc539fd9a10243160e6507a1572ba086 --- /dev/null +++ b/imcui/third_party/gim/datasets/dataset.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import torch + +from torch.utils.data import Dataset + + +class RGBDDataset(Dataset): + def __getitem__(self, idx): + + data = { + # image 0 + 'image0': None, + 'color0': None, + 'imsize0': None, + 'resize0': None, + + # image 1 + 'image1': None, + 'color1': None, + 'imsize1': None, + 'resize1': None, + + 'pseudo_labels': torch.zeros((100000, 4), dtype=torch.float), + 'gt': True, + 'zs': False, + + # image transform + 'T_0to1': None, + 'T_1to0': None, + 'K0': None, + 'K1': None, + # pair information + 'scale0': None, + 'scale1': None, + 'dataset_name': None, + 'scene_id': None, + 'pair_id': None, + 'pair_names': None, + 'covisible0': None, + 'covisible1': None, + # ETH3D dataset + 'K0_': torch.zeros(12, dtype=torch.float), + 'K1_': torch.zeros(12, dtype=torch.float), + # Hq + 'Hq_aug': torch.eye(3, dtype=torch.float), + 'Hq_ori': torch.eye(3, dtype=torch.float), + } + + return data diff --git a/imcui/third_party/gim/datasets/eth3d/__init__.py b/imcui/third_party/gim/datasets/eth3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..724771b381e9283ac5a48cd929fb111ed93fd0a2 --- /dev/null +++ b/imcui/third_party/gim/datasets/eth3d/__init__.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +from os.path import join +from yacs.config import CfgNode as CN + +########################################## +#++++++++++++++++++++++++++++++++++++++++# +#+ +# +#+ ETH3D +# +#+ +# +#++++++++++++++++++++++++++++++++++++++++# +########################################## + +_CN = CN() + +_CN.DATASET = CN() + +DATA_ROOT = 'data/ETH3D/' +NPZ_ROOT = DATA_ROOT + +_CN.NJOBS = 1 + +# TRAIN +_CN.DATASET.TRAIN = CN() +_CN.DATASET.TRAIN.PADDING = None +_CN.DATASET.TRAIN.DATA_ROOT = None +_CN.DATASET.TRAIN.NPZ_ROOT = None +_CN.DATASET.TRAIN.MAX_SAMPLES = None +_CN.DATASET.TRAIN.MIN_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.MAX_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.AUGMENTATION_TYPE = None +_CN.DATASET.TRAIN.LIST_PATH = None + +# VALID +_CN.DATASET.VALID = CN() +_CN.DATASET.VALID.PADDING = None +_CN.DATASET.VALID.DATA_ROOT = None +_CN.DATASET.VALID.NPZ_ROOT = None +_CN.DATASET.VALID.MAX_SAMPLES = None 
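ETH3D contributes two test splits, `ETH3DO` and `ETH3DI`, which share every option except the pair list; the end of this file produces them by keeping one `CfgNode` as `cfgO` and overriding `LIST_PATH` on a clone (`cfgI`), the same pattern the Robotcar config uses for its weather/season/night splits. A minimal sketch of that clone-and-override idiom; `make_variant` is a hypothetical helper and not part of the repository:

```python
# Illustrative only: the clone-and-override idiom used by cfgO / cfgI below.
# `make_variant` is a hypothetical helper, not part of this patch.
from yacs.config import CfgNode as CN


def make_variant(base: CN, list_path: str) -> CN:
    """Return a deep copy of `base` that differs only in the test pair list."""
    variant = base.clone()  # clone() deep-copies, so `base` stays untouched
    variant.DATASET.TESTS.LIST_PATH = list_path
    return variant
```

Because `clone()` copies the whole tree, mutating the variant never leaks back into the shared defaults.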
+_CN.DATASET.VALID.MIN_OVERLAP_SCORE = None +_CN.DATASET.VALID.MAX_OVERLAP_SCORE = None +_CN.DATASET.VALID.AUGMENTATION_TYPE = None +_CN.DATASET.VALID.LIST_PATH = None + +# TESTS +_CN.DATASET.TESTS = CN() +_CN.DATASET.TESTS.PADDING = True +_CN.DATASET.TESTS.DATA_ROOT = DATA_ROOT +_CN.DATASET.TESTS.NPZ_ROOT = NPZ_ROOT +_CN.DATASET.TESTS.MAX_SAMPLES = 10000 +_CN.DATASET.TESTS.MIN_OVERLAP_SCORE = 0.0 +_CN.DATASET.TESTS.MAX_OVERLAP_SCORE = 0.5 +_CN.DATASET.TESTS.AUGMENTATION_TYPE = None +_CN.DATASET.TESTS.LIST_PATH = 'datasets/_tests_/ETH3DO.txt' + +cfgO = _CN + +cfgI = cfgO.clone() + +cfgI.DATASET.TESTS.LIST_PATH = 'datasets/_tests_/ETH3DI.txt' diff --git a/imcui/third_party/gim/datasets/eth3d/eth3d.py b/imcui/third_party/gim/datasets/eth3d/eth3d.py new file mode 100644 index 0000000000000000000000000000000000000000..f30b7e3e19b58638791112849320d2a674354e93 --- /dev/null +++ b/imcui/third_party/gim/datasets/eth3d/eth3d.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import glob +import torch +import imagesize +import torch.nn.functional as F + + +from os.path import join + +from torch.utils.data import Dataset + +from .utils import read_images + + +class ETH3DDataset(Dataset): + def __init__(self, + root_dir, # data root dit + npz_root, # data info, like, overlap, image_path, depth_path + seq_name, # current sequence + mode, # train or val or test + min_overlap_score, + max_overlap_score, + max_resize, # max edge after resize + df, # general is 8 for ResNet w/ pre 3-layers + padding, # padding image for batch training + augment_fn, # augmentation function + max_samples, # max sample in current sequence + **kwargs): + super().__init__() + + self.root = join('zeb', seq_name) + + paths = glob.glob(join(self.root, '*.txt')) + + lines = [] + for path in paths: + with open(path, 'r') as file: + scene_id = path.rpartition('/')[-1].rpartition('.')[0].split('-')[0] + line = file.readline().strip().split() + lines.append([scene_id] + line) + + self.pairs = sorted(lines) + + self.scale = 1 / df + + self.df = df + self.max_resize = max_resize + self.padding = padding + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + pair = self.pairs[idx] + + scene_id = pair[0] + + img_name0 = pair[1].rpartition('.')[0] + img_name1 = pair[2].rpartition('.')[0] + + img_path0 = join(self.root, '{}-{}.png'.format(scene_id, img_name0)) + img_path1 = join(self.root, '{}-{}.png'.format(scene_id, img_name1)) + + width0, height0 = imagesize.get(img_path0) + width1, height1 = imagesize.get(img_path1) + + image0, color0, scale0, resize0, mask0 = read_images( + img_path0, self.max_resize, self.df, self.padding, None) + image1, color1, scale1, resize1, mask1 = read_images( + img_path1, self.max_resize, self.df, self.padding, None) + + K0 = torch.tensor(list(map(float, pair[5:14])), dtype=torch.float).reshape(3, 3) + K1 = torch.tensor(list(map(float, pair[14:23])), dtype=torch.float).reshape(3, 3) + + # read image size + imsize0 = torch.tensor([height0, width0], dtype=torch.long) + imsize1 = torch.tensor([height1, width1], dtype=torch.long) + resize0 = torch.tensor(resize0, dtype=torch.long) + resize1 = torch.tensor(resize1, dtype=torch.long) + + # read and compute relative poses + T_0to1 = torch.tensor(list(map(float, pair[23:])), dtype=torch.float).reshape(4, 4) + + data = { + # image 0 + 'image0': image0, # (1, 3, h, w) + 'color0': color0, # (1, h, w) + 'imsize0': imsize0, # (2) - 2:(h, w) + 'resize0': resize0, # (2) - 2:(h, w) + + # image 1 + 'image1': image1, + 'color1': 
color1, + 'imsize1': imsize1, # (2) - 2:[h, w] + 'resize1': resize1, # (2) - 2:(h, w) + + # image transform + 'T_0to1': T_0to1, # (4, 4) + 'K0': K0, # (3, 3) + 'K1': K1, + # pair information + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'ETH3D', + 'scene_id': scene_id, + 'pair_id': f'{idx}-{idx}', + 'pair_names': (img_name0+'.JPG', + img_name1+'.JPG'), + 'covisible0': float(pair[3]), + 'covisible1': float(pair[4]), + } + + if mask0 is not None: # img_padding is True + if self.scale: + # noinspection PyArgumentList + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + # noinspection PyUnboundLocalVariable + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/gim/datasets/eth3d/utils.py b/imcui/third_party/gim/datasets/eth3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a2defebb8b15bb5bb4260d613fbe00795bc4c381 --- /dev/null +++ b/imcui/third_party/gim/datasets/eth3d/utils.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import cv2 +import math +import torch + +import numpy as np + +from datasets.utils import imread_color, get_resized_wh + + +def World_to_Camera(image_pose): + qvec = image_pose[:4] + qvec = qvec / np.linalg.norm(qvec) + w, x, y, z = qvec + + R = np.array([ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ] + ]) + + t = image_pose[4:7] + + # World-to-Camera pose + current_pose = np.zeros([4, 4]) + current_pose[: 3, : 3] = R + current_pose[: 3, 3] = t + current_pose[3, 3] = 1 + return current_pose + + +def read_depth(filename): + # read 4-byte float from file + with open(filename, 'rb') as f: + depth = np.fromfile(f, dtype=np.float32) + return depth + + +def pad_bottom_right(inp, pad_size, ret_mask=False): + h = pad_size[0] + h = math.ceil(h / 8) * 8 + pad_size = (h, pad_size[1]) + # assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + mask = None + if inp.ndim == 2: + padded = np.zeros((pad_size[0], pad_size[1]), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + elif inp.ndim == 3: + padded = np.zeros((pad_size[0], pad_size[1], inp.shape[-1]), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + else: + raise NotImplementedError() + + if ret_mask: + mask = np.zeros((pad_size[0], pad_size[1]), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + + return padded, mask + + +def read_images(path, max_resize, df, padding, augment_fn=None, image=None): + """ + Args: + path: string + max_resize (int): max image size after resied + df (int, optional): image size division factor. + NOTE: this will change the final image size after img_resize + padding (bool): If set to 'True', zero-pad resized images to squared size. 
+ augment_fn (callable, optional): augments images with pre-defined visual effects + image: RGB image + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + assert max_resize is not None + + image = imread_color(path, augment_fn) if image is None else image # (w,h,3) image is RGB + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + + # resize image + w, h = image.shape[1], image.shape[0] + if max(w, h) > max_resize: + w_new, h_new = get_resized_wh(w, h, max_resize) # make max(w, h) to max_size + else: + w_new, h_new = w, h + + # w_new, h_new = get_divisible_wh(w_new, h_new, df) # make image divided by df and must <= max_size + image = cv2.resize(image, (w_new, h_new)) # (w',h',3) + gray = cv2.resize(gray, (w_new, h_new)) # (w',h',3) + scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float) + + # padding + mask = None + if padding: + image, _ = pad_bottom_right(image, (int(max_resize/1.5), max_resize), ret_mask=False) + gray, mask = pad_bottom_right(gray, (int(max_resize/1.5), max_resize), ret_mask=True) + mask = torch.from_numpy(mask) + + gray = torch.from_numpy(gray).float()[None] / 255 # (1,h,w) + image = torch.from_numpy(image).float() / 255 # (h,w,3) + image = image.permute(2,0,1) # (3,h,w) + + resize = [h_new, w_new] + + return gray, image, scale, resize, mask diff --git a/imcui/third_party/gim/datasets/gl3d/__init__.py b/imcui/third_party/gim/datasets/gl3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f37165aeb4cf3d3cea2d9ed1c271a998b226fee7 --- /dev/null +++ b/imcui/third_party/gim/datasets/gl3d/__init__.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +from os.path import join +from yacs.config import CfgNode as CN + +########################################## +#++++++++++++++++++++++++++++++++++++++++# +#+ +# +#+ GL3D +# +#+ +# +#++++++++++++++++++++++++++++++++++++++++# +########################################## + +_CN = CN() + +_CN.DATASET = CN() + +DATA_ROOT = 'data/GL3D/' +NPZ_ROOT = DATA_ROOT + +_CN.NJOBS = 8 + +# TRAIN +_CN.DATASET.TRAIN = CN() +_CN.DATASET.TRAIN.PADDING = None +_CN.DATASET.TRAIN.DATA_ROOT = None +_CN.DATASET.TRAIN.NPZ_ROOT = None +_CN.DATASET.TRAIN.MAX_SAMPLES = None +_CN.DATASET.TRAIN.MIN_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.MAX_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.AUGMENTATION_TYPE = None +_CN.DATASET.TRAIN.LIST_PATH = None + +# VALID +_CN.DATASET.VALID = CN() +_CN.DATASET.VALID.PADDING = None +_CN.DATASET.VALID.DATA_ROOT = None +_CN.DATASET.VALID.NPZ_ROOT = None +_CN.DATASET.VALID.MAX_SAMPLES = None +_CN.DATASET.VALID.MIN_OVERLAP_SCORE = None +_CN.DATASET.VALID.MAX_OVERLAP_SCORE = None +_CN.DATASET.VALID.AUGMENTATION_TYPE = None +_CN.DATASET.VALID.LIST_PATH = None + +# TESTS +_CN.DATASET.TESTS = CN() +_CN.DATASET.TESTS.PADDING = False +_CN.DATASET.TESTS.DATA_ROOT = DATA_ROOT +_CN.DATASET.TESTS.NPZ_ROOT = NPZ_ROOT +_CN.DATASET.TESTS.MAX_SAMPLES = 13 +_CN.DATASET.TESTS.MIN_OVERLAP_SCORE = 0.0 +_CN.DATASET.TESTS.MAX_OVERLAP_SCORE = 0.5 +_CN.DATASET.TESTS.AUGMENTATION_TYPE = None +_CN.DATASET.TESTS.LIST_PATH = 'datasets/_tests_/GL3D.txt' + +cfg = _CN diff --git a/imcui/third_party/gim/datasets/gl3d/gl3d.py b/imcui/third_party/gim/datasets/gl3d/gl3d.py new file mode 100644 index 0000000000000000000000000000000000000000..df06a509dbfe524e415d40f5fc400dd4cbd13b4a --- /dev/null +++ b/imcui/third_party/gim/datasets/gl3d/gl3d.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import glob +import 
torch +import imagesize +import torch.nn.functional as F + + +from os.path import join + +from torch.utils.data import Dataset + +from datasets.utils import read_images + + +class GL3DDataset(Dataset): + def __init__(self, + root_dir, # data root dit + npz_root, # data info, like, overlap, image_path, depth_path + seq_name, # current sequence + mode, # train or val or test + min_overlap_score, + max_overlap_score, + max_resize, # max edge after resize + df, # general is 8 for ResNet w/ pre 3-layers + padding, # padding image for batch training + augment_fn, # augmentation function + max_samples, # max sample in current sequence + **kwargs): + super().__init__() + + self.root = join('zeb', seq_name) + + paths = glob.glob(join(self.root, '*.txt')) + + lines = [] + for path in paths: + with open(path, 'r') as file: + scene_id = path.rpartition('/')[-1].rpartition('.')[0].split('_')[0] + line = file.readline().strip().split() + lines.append([scene_id] + line) + + self.pairs = sorted(lines) + + self.df = df + self.max_resize = max_resize + self.padding = padding + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + pair = self.pairs[idx] + + scene_id = pair[0] + + img_name0 = pair[1].rpartition('.')[0] + img_name1 = pair[2].rpartition('.')[0] + + img_path0 = join(self.root, '{}_{}.png'.format(scene_id, img_name0)) + img_path1 = join(self.root, '{}_{}.png'.format(scene_id, img_name1)) + + width0, height0 = imagesize.get(img_path0) + width1, height1 = imagesize.get(img_path1) + + image0, color0, scale0, resize0, mask0 = read_images( + img_path0, self.max_resize, self.df, self.padding, None) + image1, color1, scale1, resize1, mask1 = read_images( + img_path1, self.max_resize, self.df, self.padding, None) + + K0 = torch.tensor(list(map(float, pair[5:14])), dtype=torch.float).reshape(3, 3) + K1 = torch.tensor(list(map(float, pair[14:23])), dtype=torch.float).reshape(3, 3) + + # read image size + imsize0 = torch.tensor([height0, width0], dtype=torch.long) + imsize1 = torch.tensor([height1, width1], dtype=torch.long) + resize0 = torch.tensor(resize0, dtype=torch.long) + resize1 = torch.tensor(resize1, dtype=torch.long) + + T_0to1 = torch.tensor(list(map(float, pair[23:])), dtype=torch.float).reshape(4, 4) + + data = { + # image 0 + 'image0': image0, # (1, 3, h, w) + 'color0': color0, # (1, h, w) + 'imsize0': imsize0, # (2) - 2:(h, w) + 'resize0': resize0, # (2) - 2:(h, w) + + # image 1 + 'image1': image1, + 'color1': color1, + 'imsize1': imsize1, # (2) - 2:[h, w] + 'resize1': resize1, # (2) - 2:(h, w) + + # image transform + 'T_0to1': T_0to1, # (4, 4) + 'K0': K0, # (3, 3) + 'K1': K1, + # pair information + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'GL3D', + 'scene_id': scene_id, + 'pair_id': f'{idx}-{idx}', + 'pair_names': (img_name0, + img_name1), + 'covisible0': float(pair[3]), + 'covisible1': float(pair[4]), + } + + if mask0 is not None: # img_padding is True + if self.scale: + # noinspection PyArgumentList + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + # noinspection PyUnboundLocalVariable + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/gim/datasets/gl3d/utils.py b/imcui/third_party/gim/datasets/gl3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e02b12ed2a5b740059efe8bab3ef212103acce86 --- /dev/null +++ 
b/imcui/third_party/gim/datasets/gl3d/utils.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python +""" +Copyright 2017, Zixin Luo, HKUST. +IO tools. +""" + +from __future__ import print_function + +import os +import re +import cv2 +import numpy as np + +from struct import unpack + + +def get_pose(R, t): + T = np.zeros((4, 4), dtype=R.dtype) + T[:3,:3] = R + T[:3,3:] = t + T[ 3, 3] = 1 + return T + + +def load_pfm(pfm_path): + with open(pfm_path, 'rb') as fin: + color = None + width = None + height = None + scale = None + data_type = None + header = str(fin.readline().decode('UTF-8')).rstrip() + + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8')) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + scale = float((fin.readline().decode('UTF-8')).rstrip()) + if scale < 0: # little-endian + data_type = '= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + mask = None + if inp.ndim == 2: + padded = np.zeros((pad_size[0], pad_size[1]), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + elif inp.ndim == 3: + padded = np.zeros((pad_size[0], pad_size[1], inp.shape[-1]), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + else: + raise NotImplementedError() + + if ret_mask: + mask = np.zeros((pad_size[0], pad_size[1]), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + + return padded, mask + + +def read_depth(path): + # loads depth map D from png file + # and returns it as a numpy array, + # for details see readme.txt + + depth_png = np.array(Image.open(path), dtype=int) + # make sure we have a proper 16bit depth map here.. not 8bit! + assert(np.max(depth_png) > 255) + + depth = depth_png.astype(float) / 256. + depth[depth_png == 0] = -1. + + padded = np.zeros((400, 1300), dtype=depth.dtype) + padded[:depth.shape[0], :depth.shape[1]] = depth + + return padded + + +def read_images(path, max_resize, df, padding, augment_fn=None, image=None): + """ + Args: + path: string + max_resize (int): max image size after resied + df (int, optional): image size division factor. + NOTE: this will change the final image size after img_resize + padding (bool): If set to 'True', zero-pad resized images to squared size. 
+ augment_fn (callable, optional): augments images with pre-defined visual effects + image: RGB image + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + assert max_resize is not None + + image = imread_color(path, augment_fn) if image is None else image # (w,h,3) image is RGB + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + + # resize image + w, h = image.shape[1], image.shape[0] + if max(w, h) > max_resize: + w_new, h_new = get_resized_wh(w, h, max_resize) # make max(w, h) to max_size + else: + w_new, h_new = w, h + + # w_new, h_new = get_divisible_wh(w_new, h_new, df) # make image divided by df and must <= max_size + image = cv2.resize(image, (w_new, h_new)) # (w',h',3) + gray = cv2.resize(gray, (w_new, h_new)) # (w',h',3) + scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float) + + # padding + mask = None + if padding: + image, _ = pad_bottom_right(image, (int(max_resize/3.25), max_resize), ret_mask=False) + gray, mask = pad_bottom_right(gray, (int(max_resize/3.25), max_resize), ret_mask=True) + mask = torch.from_numpy(mask) + + gray = torch.from_numpy(gray).float()[None] / 255 # (1,h,w) + image = torch.from_numpy(image).float() / 255 # (h,w,3) + image = image.permute(2,0,1) # (3,h,w) + + resize = [h_new, w_new] + + return gray, image, scale, resize, mask diff --git a/imcui/third_party/gim/datasets/multifov/__init__.py b/imcui/third_party/gim/datasets/multifov/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2460e8138141ede3096755da8d6eef184a86a851 --- /dev/null +++ b/imcui/third_party/gim/datasets/multifov/__init__.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +from os.path import join +from yacs.config import CfgNode as CN + +########################################## +#++++++++++++++++++++++++++++++++++++++++# +#+ +# +#+ Multi-FoV +# +#+ +# +#++++++++++++++++++++++++++++++++++++++++# +########################################## + +_CN = CN() + +_CN.DATASET = CN() + +DATA_ROOT = 'data/Multi-FoV/' +NPZ_ROOT = DATA_ROOT + +_CN.NJOBS = 1 + +# TRAIN +_CN.DATASET.TRAIN = CN() +_CN.DATASET.TRAIN.PADDING = None +_CN.DATASET.TRAIN.DATA_ROOT = None +_CN.DATASET.TRAIN.NPZ_ROOT = None +_CN.DATASET.TRAIN.MAX_SAMPLES = None +_CN.DATASET.TRAIN.MIN_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.MAX_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.AUGMENTATION_TYPE = None +_CN.DATASET.TRAIN.LIST_PATH = None + +# VALID +_CN.DATASET.VALID = CN() +_CN.DATASET.VALID.PADDING = None +_CN.DATASET.VALID.DATA_ROOT = None +_CN.DATASET.VALID.NPZ_ROOT = None +_CN.DATASET.VALID.MAX_SAMPLES = None +_CN.DATASET.VALID.MIN_OVERLAP_SCORE = None +_CN.DATASET.VALID.MAX_OVERLAP_SCORE = None +_CN.DATASET.VALID.AUGMENTATION_TYPE = None +_CN.DATASET.VALID.LIST_PATH = None + +# TESTS +_CN.DATASET.TESTS = CN() +_CN.DATASET.TESTS.PADDING = False +_CN.DATASET.TESTS.DATA_ROOT = DATA_ROOT +_CN.DATASET.TESTS.NPZ_ROOT = NPZ_ROOT +_CN.DATASET.TESTS.MAX_SAMPLES = 5000 +_CN.DATASET.TESTS.MIN_OVERLAP_SCORE = 0.0 +_CN.DATASET.TESTS.MAX_OVERLAP_SCORE = 0.5 +_CN.DATASET.TESTS.AUGMENTATION_TYPE = None +_CN.DATASET.TESTS.LIST_PATH = 'datasets/_tests_/Multi-FoV.txt' + +cfg = _CN diff --git a/imcui/third_party/gim/datasets/multifov/multifov.py b/imcui/third_party/gim/datasets/multifov/multifov.py new file mode 100644 index 0000000000000000000000000000000000000000..773bf690b8c815614f786815a1b9410bd58d8ece --- /dev/null +++ b/imcui/third_party/gim/datasets/multifov/multifov.py @@ -0,0 +1,125 @@ +# -*- coding: 
utf-8 -*- +# @Author : xuelun + +import glob +import torch +import imagesize +import torch.nn.functional as F + + +from os.path import join + +from torch.utils.data import Dataset + +from datasets.utils import read_images + + +class MultiFoVDataset(Dataset): + def __init__(self, + root_dir, # data root dit + npz_root, # data info, like, overlap, image_path, depth_path + seq_name, # current sequence + mode, # train or val or test + min_overlap_score, + max_overlap_score, + max_resize, # max edge after resize + df, # general is 8 for ResNet w/ pre 3-layers + padding, # padding image for batch training + augment_fn, # augmentation function + max_samples, # max sample in current sequence + **kwargs): + super().__init__() + + self.root = join('zeb', seq_name) + + paths = glob.glob(join(self.root, '*.txt')) + + lines = [] + for path in paths: + with open(path, 'r') as file: + scene_id = path.rpartition('/')[-1].rpartition('.')[0].split('-')[0] + line = file.readline().strip().split() + lines.append([scene_id] + line) + + self.pairs = sorted(lines) + + self.scale = 1 / df + + self.df = df + self.max_resize = max_resize + self.padding = padding + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + pair = self.pairs[idx] + + scene_id = pair[0] + + img_name0 = pair[1] + img_name1 = pair[2] + + img_path0 = join(self.root, '{}-{}.png'.format(scene_id, img_name0)) + img_path1 = join(self.root, '{}-{}.png'.format(scene_id, img_name1)) + + width0, height0 = imagesize.get(img_path0) + width1, height1 = imagesize.get(img_path1) + + image0, color0, scale0, resize0, mask0 = read_images( + img_path0, self.max_resize, self.df, self.padding, None) + image1, color1, scale1, resize1, mask1 = read_images( + img_path1, self.max_resize, self.df, self.padding, None) + + K0 = torch.tensor(list(map(float, pair[5:14])), dtype=torch.float).reshape(3, 3) + K1 = torch.tensor(list(map(float, pair[14:23])), dtype=torch.float).reshape(3, 3) + + # read image size + imsize0 = torch.tensor([height0, width0], dtype=torch.long) + imsize1 = torch.tensor([height1, width1], dtype=torch.long) + resize0 = torch.tensor(resize0, dtype=torch.long) + resize1 = torch.tensor(resize1, dtype=torch.long) + + # read and compute relative poses + T_0to1 = torch.tensor(list(map(float, pair[23:])), dtype=torch.float).reshape(4, 4) + + data = { + # image 0 + 'image0': image0, # (1, 3, h, w) + 'color0': color0, # (1, h, w) + 'imsize0': imsize0, # (2) - 2:(h, w) + 'resize0': resize0, # (2) - 2:(h, w) + + # image 1 + 'image1': image1, + 'color1': color1, + 'imsize1': imsize1, # (2) - 2:[h, w] + 'resize1': resize1, # (2) - 2:(h, w) + + # image transform + 'T_0to1': T_0to1, # (4, 4) + 'K0': K0, # (3, 3) + 'K1': K1, + # pair information + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'MultiFoV', + 'scene_id': scene_id, + 'pair_id': f'{idx}-{idx}', + 'pair_names': (f'img/{img_name0}.png', + f'img/{img_name1}.png'), + 'covisible0': float(pair[3]), + 'covisible1': float(pair[4]), + } + + if mask0 is not None: # img_padding is True + if self.scale: + # noinspection PyArgumentList + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + # noinspection PyUnboundLocalVariable + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/gim/datasets/multifov/utils.py b/imcui/third_party/gim/datasets/multifov/utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..c28e3f956efd62fe226a081d738a9fe1f3191dd4 --- /dev/null +++ b/imcui/third_party/gim/datasets/multifov/utils.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import numpy as np + + +def convert(xyzw): + x, y, z, w = xyzw + R = np.array([ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ] + ]) + return R diff --git a/imcui/third_party/gim/datasets/robotcar/__init__.py b/imcui/third_party/gim/datasets/robotcar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aaeef715404c46da84478186192425d70514f224 --- /dev/null +++ b/imcui/third_party/gim/datasets/robotcar/__init__.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +from os.path import join +from yacs.config import CfgNode as CN + +########################################## +#++++++++++++++++++++++++++++++++++++++++# +#+ +# +#+ ROBOTCAR +# +#+ +# +#++++++++++++++++++++++++++++++++++++++++# +########################################## + +_CN = CN() + +_CN.DATASET = CN() + +DATA_ROOT = 'data/Robotcar/' +NPZ_ROOT = DATA_ROOT + +_CN.NJOBS = 1 + +# TRAIN +_CN.DATASET.TRAIN = CN() +_CN.DATASET.TRAIN.PADDING = None +_CN.DATASET.TRAIN.DATA_ROOT = None +_CN.DATASET.TRAIN.NPZ_ROOT = None +_CN.DATASET.TRAIN.MAX_SAMPLES = None +_CN.DATASET.TRAIN.MIN_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.MAX_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.AUGMENTATION_TYPE = None +_CN.DATASET.TRAIN.LIST_PATH = None + +# VALID +_CN.DATASET.VALID = CN() +_CN.DATASET.VALID.PADDING = None +_CN.DATASET.VALID.DATA_ROOT = None +_CN.DATASET.VALID.NPZ_ROOT = None +_CN.DATASET.VALID.MAX_SAMPLES = None +_CN.DATASET.VALID.MIN_OVERLAP_SCORE = None +_CN.DATASET.VALID.MAX_OVERLAP_SCORE = None +_CN.DATASET.VALID.AUGMENTATION_TYPE = None +_CN.DATASET.VALID.LIST_PATH = None + +# TESTS +_CN.DATASET.TESTS = CN() +_CN.DATASET.TESTS.PADDING = False +_CN.DATASET.TESTS.DATA_ROOT = DATA_ROOT +_CN.DATASET.TESTS.NPZ_ROOT = NPZ_ROOT +_CN.DATASET.TESTS.MAX_SAMPLES = 500 +_CN.DATASET.TESTS.MIN_OVERLAP_SCORE = 0.0 +_CN.DATASET.TESTS.MAX_OVERLAP_SCORE = 0.5 +_CN.DATASET.TESTS.AUGMENTATION_TYPE = None +_CN.DATASET.TESTS.LIST_PATH = None + +weather = _CN.clone() +weather.DATASET.TESTS.LIST_PATH = 'datasets/_tests_/RobotcarWeather.txt' + +season = _CN.clone() +season.DATASET.TESTS.LIST_PATH = 'datasets/_tests_/RobotcarSeason.txt' + +night = _CN.clone() +night.DATASET.TESTS.LIST_PATH = 'datasets/_tests_/RobotcarNight.txt' diff --git a/imcui/third_party/gim/datasets/robotcar/robotcar.py b/imcui/third_party/gim/datasets/robotcar/robotcar.py new file mode 100644 index 0000000000000000000000000000000000000000..7d08a317cc3337d1272f87e99cf13e9f47dfd4bf --- /dev/null +++ b/imcui/third_party/gim/datasets/robotcar/robotcar.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import glob +import torch +import imagesize +import torch.nn.functional as F + + +from os.path import join + +from torch.utils.data import Dataset + +from datasets.utils import read_images + + +class RobotcarDataset(Dataset): + def __init__(self, + root_dir, # data root dit + npz_root, # data info, like, overlap, image_path, depth_path + seq_name, # current sequence + mode, # train or val or test + min_overlap_score, + max_overlap_score, + max_resize, # max edge after resize + df, # general is 8 for ResNet w/ pre 3-layers + 
padding, # padding image for batch training + augment_fn, # augmentation function + max_samples, # max sample in current sequence + **kwargs): + super().__init__() + + self.root = join('zeb', seq_name) + + paths = glob.glob(join(self.root, '*.txt')) + + lines = [] + for path in paths: + with open(path, 'r') as file: + scene_id = path.rpartition('/')[-1].rpartition('.')[0].split('_')[0] + line = file.readline().strip().split() + lines.append([scene_id] + line) + + self.pairs = sorted(lines) + + self.scale = 1 / df + + self.df = df + self.max_resize = max_resize + self.padding = padding + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + pair = self.pairs[idx] + + scene_id = pair[0] + + timestamp0 = pair[1] + timestamp1 = pair[2] + + img_path0 = join(self.root, '{}_{}.png'.format(scene_id, timestamp0)) + img_path1 = join(self.root, '{}_{}.png'.format(scene_id, timestamp1)) + + width0, height0 = imagesize.get(img_path0) + width1, height1 = imagesize.get(img_path1) + + image0, color0, scale0, resize0, mask0 = read_images( + img_path0, self.max_resize, self.df, self.padding, None) + image1, color1, scale1, resize1, mask1 = read_images( + img_path1, self.max_resize, self.df, self.padding, None) + + K0 = torch.tensor(list(map(float, pair[5:14])), dtype=torch.float).reshape(3, 3) + K1 = torch.tensor(list(map(float, pair[14:23])), dtype=torch.float).reshape(3, 3) + + # read image size + imsize0 = torch.tensor([height0, width0], dtype=torch.long) + imsize1 = torch.tensor([height1, width1], dtype=torch.long) + resize0 = torch.tensor(resize0, dtype=torch.long) + resize1 = torch.tensor(resize1, dtype=torch.long) + + T_0to1 = torch.tensor(list(map(float, pair[23:])), dtype=torch.float).reshape(4, 4) + + data = { + # image 0 + 'image0': image0, # (1, 3, h, w) + 'color0': color0, # (1, h, w) + 'imsize0': imsize0, # (2) - 2:(h, w) + 'resize0': resize0, # (2) - 2:(h, w) + + # image 1 + 'image1': image1, + 'color1': color1, + 'imsize1': imsize1, # (2) - 2:[h, w] + 'resize1': resize1, # (2) - 2:(h, w) + + # image transform + 'T_0to1': T_0to1, # (4, 4) + 'K0': K0, # (3, 3) + 'K1': K1, + # pair information + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'Robotcar', + 'scene_id': scene_id, + 'pair_id': f'{idx}-{idx}', + 'pair_names': (str(timestamp0), + str(timestamp1)), + 'covisible0': float(pair[3]), + 'covisible1': float(pair[4]), + } + + if mask0 is not None: # img_padding is True + if self.scale: + # noinspection PyArgumentList + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + # noinspection PyUnboundLocalVariable + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/gim/datasets/scenenet/__init__.py b/imcui/third_party/gim/datasets/scenenet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf131412286907996bce3a2bb151a515adbd6586 --- /dev/null +++ b/imcui/third_party/gim/datasets/scenenet/__init__.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +from os.path import join +from yacs.config import CfgNode as CN + +########################################## +#++++++++++++++++++++++++++++++++++++++++# +#+ +# +#+ SceneNet-RGBD +# +#+ +# +#++++++++++++++++++++++++++++++++++++++++# +########################################## + +_CN = CN() + +_CN.DATASET = CN() + +DATA_ROOT = 'data/SceneNetRGBD/' +NPZ_ROOT = DATA_ROOT + +_CN.NJOBS 
= 1 + +# TRAIN +_CN.DATASET.TRAIN = CN() +_CN.DATASET.TRAIN.PADDING = None +_CN.DATASET.TRAIN.DATA_ROOT = None +_CN.DATASET.TRAIN.NPZ_ROOT = None +_CN.DATASET.TRAIN.MAX_SAMPLES = None +_CN.DATASET.TRAIN.MIN_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.MAX_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.AUGMENTATION_TYPE = None +_CN.DATASET.TRAIN.LIST_PATH = None + +# VALID +_CN.DATASET.VALID = CN() +_CN.DATASET.VALID.PADDING = None +_CN.DATASET.VALID.DATA_ROOT = None +_CN.DATASET.VALID.NPZ_ROOT = None +_CN.DATASET.VALID.MAX_SAMPLES = None +_CN.DATASET.VALID.MIN_OVERLAP_SCORE = None +_CN.DATASET.VALID.MAX_OVERLAP_SCORE = None +_CN.DATASET.VALID.AUGMENTATION_TYPE = None +_CN.DATASET.VALID.LIST_PATH = None + +# TESTS +_CN.DATASET.TESTS = CN() +_CN.DATASET.TESTS.PADDING = False +_CN.DATASET.TESTS.DATA_ROOT = join(DATA_ROOT, 'test') +_CN.DATASET.TESTS.NPZ_ROOT = NPZ_ROOT +_CN.DATASET.TESTS.MAX_SAMPLES = 30 +_CN.DATASET.TESTS.MIN_OVERLAP_SCORE = 0.0 +_CN.DATASET.TESTS.MAX_OVERLAP_SCORE = 0.5 +_CN.DATASET.TESTS.AUGMENTATION_TYPE = None +_CN.DATASET.TESTS.LIST_PATH = 'datasets/_tests_/SceneNetRGBD.txt' + +cfg = _CN diff --git a/imcui/third_party/gim/datasets/scenenet/scenenet.py b/imcui/third_party/gim/datasets/scenenet/scenenet.py new file mode 100644 index 0000000000000000000000000000000000000000..e58052e6c5cdb280799c5f3b732b5430c4189aa1 --- /dev/null +++ b/imcui/third_party/gim/datasets/scenenet/scenenet.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import glob +import torch +import imagesize +import torch.nn.functional as F + + +from os.path import join + +from torch.utils.data import Dataset + +from datasets.utils import read_images + + +class SceneNetDataset(Dataset): + def __init__(self, + root_dir, # data root dit + npz_root, # data info, like, overlap, image_path, depth_path + seq_name, # current sequence + mode, # train or val or test + min_overlap_score, + max_overlap_score, + max_resize, # max edge after resize + df, # general is 8 for ResNet w/ pre 3-layers + padding, # padding image for batch training + augment_fn, # augmentation function + max_samples, # max sample in current sequence + **kwargs): + super().__init__() + + self.root = join('zeb', seq_name) + + paths = glob.glob(join(self.root, '*.txt')) + + lines = [] + for path in paths: + with open(path, 'r') as file: + scene_id = path.rpartition('/')[-1].rpartition('.')[0].split('-')[0] + line = file.readline().strip().split() + lines.append([scene_id] + line) + + self.pairs = sorted(lines) + + self.scale = 1 / df + + self.df = df + self.max_resize = max_resize + self.padding = padding + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + pair = self.pairs[idx] + + scene_id = pair[0] + + img_name0 = pair[1] + img_name1 = pair[2] + + img_path0 = join(self.root, '{}-{}.png'.format(scene_id, img_name0)) + img_path1 = join(self.root, '{}-{}.png'.format(scene_id, img_name1)) + + width0, height0 = imagesize.get(img_path0) + width1, height1 = imagesize.get(img_path1) + + image0, color0, scale0, resize0, mask0 = read_images( + img_path0, self.max_resize, self.df, self.padding, None) + image1, color1, scale1, resize1, mask1 = read_images( + img_path1, self.max_resize, self.df, self.padding, None) + + K0 = torch.tensor(list(map(float, pair[5:14])), dtype=torch.float).reshape(3, 3) + K1 = torch.tensor(list(map(float, pair[14:23])), dtype=torch.float).reshape(3, 3) + + # read image size + imsize0 = torch.tensor([height0, width0], dtype=torch.long) + imsize1 = torch.tensor([height1, width1], 
dtype=torch.long) + resize0 = torch.tensor(resize0, dtype=torch.long) + resize1 = torch.tensor(resize1, dtype=torch.long) + + # read and compute relative poses + T_0to1 = torch.tensor(list(map(float, pair[23:])), dtype=torch.float).reshape(4, 4) + + data = { + # image 0 + 'image0': image0, # (1, 3, h, w) + 'color0': color0, # (1, h, w) + 'imsize0': imsize0, # (2) - 2:(h, w) + 'resize0': resize0, # (2) - 2:(h, w) + + # image 1 + 'image1': image1, + 'color1': color1, + 'imsize1': imsize1, # (2) - 2:[h, w] + 'resize1': resize1, # (2) - 2:(h, w) + + # image transform + 'T_0to1': T_0to1, # (4, 4) + 'K0': K0, # (3, 3) + 'K1': K1, + # pair information + 'scale0': scale0, # [scale_w, scale_h] + 'scale1': scale1, + 'dataset_name': 'SceneNet', + 'scene_id': scene_id, + 'pair_id': f'{idx}-{idx}', + 'pair_names': (img_name0+'.jpg', + img_name1+'.jpg'), + 'covisible0': float(pair[3]), + 'covisible1': float(pair[4]), + } + + if mask0 is not None: # img_padding is True + if self.scale: + # noinspection PyArgumentList + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + # noinspection PyUnboundLocalVariable + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + + return data diff --git a/imcui/third_party/gim/datasets/scenenet/utils.py b/imcui/third_party/gim/datasets/scenenet/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2a32b27fb961a6181128ee2c6aeb0c73bdeb83ed --- /dev/null +++ b/imcui/third_party/gim/datasets/scenenet/utils.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import math + +import numpy as np + +from imageio import imread + +import datasets.scenenet.scenenet_pb2 as sn + + +def camera_intrinsic_transform(vfov=45,hfov=60,pixel_width=320,pixel_height=240): + camera_intrinsics = np.zeros((3,4)) + camera_intrinsics[2,2] = 1 + camera_intrinsics[0,0] = (pixel_width/2.0)/math.tan(math.radians(hfov/2.0)) + camera_intrinsics[0,2] = pixel_width/2.0 + camera_intrinsics[1,1] = (pixel_height/2.0)/math.tan(math.radians(vfov/2.0)) + camera_intrinsics[1,2] = pixel_height/2.0 + return camera_intrinsics + + +def read_depth(filename): + depth = np.array(imread(filename)) + depth = depth.astype(np.float32) / 1000 + return depth + + +def position_to_np_array(position,homogenous=False): + if not homogenous: + return np.array([position.x,position.y,position.z]) + return np.array([position.x,position.y,position.z,1.0]) + + +def interpolate_poses(start_pose,end_pose,alpha): + assert alpha >= 0.0 + assert alpha <= 1.0 + camera_pose = alpha * position_to_np_array(end_pose.camera) + camera_pose += (1.0 - alpha) * position_to_np_array(start_pose.camera) + lookat_pose = alpha * position_to_np_array(end_pose.lookat) + lookat_pose += (1.0 - alpha) * position_to_np_array(start_pose.lookat) + timestamp = alpha * end_pose.timestamp + (1.0 - alpha) * start_pose.timestamp + pose = sn.Pose() + pose.camera.x = camera_pose[0] + pose.camera.y = camera_pose[1] + pose.camera.z = camera_pose[2] + pose.lookat.x = lookat_pose[0] + pose.lookat.y = lookat_pose[1] + pose.lookat.z = lookat_pose[2] + pose.timestamp = timestamp + return pose + + +def normalize(v): + return v/np.linalg.norm(v) + + +def world_to_camera_with_pose(view_pose): + lookat_pose = position_to_np_array(view_pose.lookat) + camera_pose = position_to_np_array(view_pose.camera) + up = np.array([0,1,0]) + R = np.diag(np.ones(4)) + R[2,:3] = normalize(lookat_pose - camera_pose) + R[0,:3] = 
normalize(np.cross(R[2,:3],up)) + R[1,:3] = -normalize(np.cross(R[0,:3],R[2,:3])) + T = np.diag(np.ones(4)) + T[:3,3] = -camera_pose + return R.dot(T) diff --git a/imcui/third_party/gim/datasets/utils.py b/imcui/third_party/gim/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0b1366777cc3316292986043fb11c0a67ed56d --- /dev/null +++ b/imcui/third_party/gim/datasets/utils.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import cv2 +import torch +import numpy as np + + +# ------------ +# DATA TOOLS +# ------------ +def imread_gray(path, augment_fn=None): + if augment_fn is None: + image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE) + else: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = augment_fn(image) + image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + return image # (h, w) + + +def imread_color(path, augment_fn=None): + if augment_fn is None: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + else: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = augment_fn(image) + return image # (h, w) + + +def get_resized_wh(w, h, resize=None): + if resize is not None: # resize the longer edge + scale = resize / max(h, w) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + else: + w_new, h_new = w, h + return w_new, h_new + + +def get_divisible_wh(w, h, df=None): + if df is not None: + w_new = max((w // df), 1) * df + h_new = max((h // df), 1) * df + # resize = int(max(max(w, h) // df, 1) * df) + # w_new, h_new = get_resized_wh(w, h, resize) + # scale = resize / x + # w_new, h_new = map(lambda x: int(max(x // df, 1) * df), [w, h]) + else: + w_new, h_new = w, h + return w_new, h_new + + +def pad_bottom_right(inp, pad_size, ret_mask=False): + assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" + mask = None + if inp.ndim == 2: + padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + elif inp.ndim == 3: + padded = np.zeros((pad_size, pad_size, inp.shape[-1]), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + else: + raise NotImplementedError() + + if ret_mask: + mask = np.zeros((pad_size, pad_size), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + + return padded, mask + + +def split(n, k): + d, r = divmod(n, k) + return [d + 1] * r + [d] * (k - r) + + +def read_images(path, max_resize, df, padding, augment_fn=None, image=None): + """ + Args: + path: string + max_resize (int): max image size after resied + df (int, optional): image size division factor. + NOTE: this will change the final image size after img_resize + padding (bool): If set to 'True', zero-pad resized images to squared size. 
+ augment_fn (callable, optional): augments images with pre-defined visual effects + image: RGB image + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + assert max_resize is not None + + image = imread_color(path, augment_fn) if image is None else image # (w,h,3) image is RGB + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + + # resize image + w, h = image.shape[1], image.shape[0] + if max(w, h) > max_resize: + w_new, h_new = get_resized_wh(w, h, max_resize) # make max(w, h) to max_size + else: + w_new, h_new = w, h + + w_new, h_new = get_divisible_wh(w_new, h_new, df) # make image divided by df and must <= max_size + image = cv2.resize(image, (w_new, h_new)) # (w',h',3) + gray = cv2.resize(gray, (w_new, h_new)) # (w',h',3) + scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float) + + # padding + mask = None + if padding: + image, _ = pad_bottom_right(image, max_resize, ret_mask=False) + gray, mask = pad_bottom_right(gray, max_resize, ret_mask=True) + mask = torch.from_numpy(mask) + + gray = torch.from_numpy(gray).float()[None] / 255 # (1,h,w) + image = torch.from_numpy(image).float() / 255 # (h,w,3) + image = image.permute(2,0,1) # (3,h,w) + + resize = [h_new, w_new] + + return gray, image, scale, resize, mask diff --git a/imcui/third_party/gim/datasets/walk/__init__.py b/imcui/third_party/gim/datasets/walk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcb35f7f1038d59807e518723d00e9cd9a58879 --- /dev/null +++ b/imcui/third_party/gim/datasets/walk/__init__.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +from os.path import join +from yacs.config import CfgNode as CN + +########################################## +#++++++++++++++++++++++++++++++++++++++++# +#+ +# +#+ WALK +# +#+ +# +#++++++++++++++++++++++++++++++++++++++++# +########################################## + +_CN = CN() + +_CN.DATASET = CN() + +DATA_ROOT = join('data', 'ZeroMatch') +NPZ_ROOT = join(DATA_ROOT, 'pseudo') + +_CN.NJOBS = 1 # x scenes + +# TRAIN +_CN.DATASET.TRAIN = CN() +_CN.DATASET.TRAIN.PADDING = True +_CN.DATASET.TRAIN.DATA_ROOT = join(DATA_ROOT, 'video_1080p') +_CN.DATASET.TRAIN.NPZ_ROOT = NPZ_ROOT +_CN.DATASET.TRAIN.MAX_SAMPLES = -1 +_CN.DATASET.TRAIN.MIN_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.MAX_OVERLAP_SCORE = None +_CN.DATASET.TRAIN.AUGMENTATION_TYPE = 'dark' +_CN.DATASET.TRAIN.LIST_PATH = 'datasets/_train_/100h.txt' + +# OTHERS +_CN.DATASET.TRAIN.STEP = 1000 +_CN.DATASET.TRAIN.PIX_THR = 1 +_CN.DATASET.TRAIN.MAX_CANDIDATE_MATCHES = -1 +_CN.DATASET.TRAIN.MIN_FINAL_MATCHES = 512 +_CN.DATASET.TRAIN.MIN_FILTER_MATCHES = 32 +_CN.DATASET.TRAIN.FIX_MATCHES = 100000 +_CN.DATASET.TRAIN.SOURCE_ROOT = join(DATA_ROOT, 'video_1080p') +_CN.DATASET.TRAIN.PROPAGATE_ROOT = join(DATA_ROOT, 'propagate') +_CN.DATASET.TRAIN.VIDEO_IMAGE_ROOT = join(DATA_ROOT, 'image_1080p') +_CN.DATASET.TRAIN.PSEUDO_LABELS = [ + 'WALK SIFT [R] F [S] 10', + 'WALK SIFT [R] F [S] 20', + 'WALK SIFT [R] F [S] 40', + 'WALK SIFT [R] F [S] 80', + 'WALK SIFT [R] T [S] 10', + 'WALK SIFT [R] T [S] 20', + 'WALK SIFT [R] T [S] 40', + 'WALK SIFT [R] T [S] 80', + + 'WALK GIM_DKM [R] F [S] 10', + 'WALK GIM_DKM [R] F [S] 20', + 'WALK GIM_DKM [R] F [S] 40', + 'WALK GIM_DKM [R] F [S] 80', + 'WALK GIM_DKM [R] T [S] 10', + 'WALK GIM_DKM [R] T [S] 20', + 'WALK GIM_DKM [R] T [S] 40', + 'WALK GIM_DKM [R] T [S] 80', + + 'WALK GIM_GLUE [R] F [S] 10', + 'WALK GIM_GLUE [R] F [S] 20', + 'WALK GIM_GLUE [R] F [S] 40', + 
'WALK GIM_GLUE [R] F [S] 80', + 'WALK GIM_GLUE [R] T [S] 10', + 'WALK GIM_GLUE [R] T [S] 20', + 'WALK GIM_GLUE [R] T [S] 40', + 'WALK GIM_GLUE [R] T [S] 80', + + 'WALK GIM_LOFTR [R] F [S] 10', + 'WALK GIM_LOFTR [R] F [S] 20', + 'WALK GIM_LOFTR [R] F [S] 40', + 'WALK GIM_LOFTR [R] F [S] 80', + 'WALK GIM_LOFTR [R] T [S] 10', + 'WALK GIM_LOFTR [R] T [S] 20', + 'WALK GIM_LOFTR [R] T [S] 40', + 'WALK GIM_LOFTR [R] T [S] 80', +] + +# VALID +_CN.DATASET.VALID = CN() +_CN.DATASET.VALID.PADDING = None +_CN.DATASET.VALID.DATA_ROOT = None +_CN.DATASET.VALID.NPZ_ROOT = None +_CN.DATASET.VALID.MAX_SAMPLES = None +_CN.DATASET.VALID.MIN_OVERLAP_SCORE = None +_CN.DATASET.VALID.MAX_OVERLAP_SCORE = None +_CN.DATASET.VALID.AUGMENTATION_TYPE = None +_CN.DATASET.VALID.LIST_PATH = None + +# TESTS +_CN.DATASET.TESTS = CN() +_CN.DATASET.TESTS.PADDING = None +_CN.DATASET.TESTS.DATA_ROOT = None +_CN.DATASET.TESTS.NPZ_ROOT = None +_CN.DATASET.TESTS.MAX_SAMPLES = None +_CN.DATASET.TESTS.MIN_OVERLAP_SCORE = None +_CN.DATASET.TESTS.MAX_OVERLAP_SCORE = None +_CN.DATASET.TESTS.AUGMENTATION_TYPE = None +_CN.DATASET.TESTS.LIST_PATH = None + +cfg = _CN diff --git a/imcui/third_party/gim/datasets/walk/propagate.py b/imcui/third_party/gim/datasets/walk/propagate.py new file mode 100644 index 0000000000000000000000000000000000000000..31cd42d187ea8a89a808a4f39723d74a4526fb52 --- /dev/null +++ b/imcui/third_party/gim/datasets/walk/propagate.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import os +from tqdm import tqdm +from argparse import ArgumentParser +from torch.utils.data import DataLoader + +from datasets.walk import cfg +from datasets.walk.walk import WALKDataset + + +def propagate(loader, seq): + for i, _ in enumerate(tqdm( + loader, ncols=80, bar_format="{l_bar}{bar:3}{r_bar}", total=len(loader), + desc=f'[ {seq[:min(10, len(seq)-1)]:<10} ] [ {len(loader):<5} ]')): + continue + + +def init_dataset(seq_name_): + train_cfg = cfg.DATASET.TRAIN + + base_input = { + 'df': 8, + 'mode': 'train', + 'augment_fn': None, + 'PROPAGATING': True, + 'seq_name': seq_name_, + 'max_resize': [1280, 720], + 'padding': cfg.DATASET.TRAIN.PADDING, + 'max_samples': cfg.DATASET.TRAIN.MAX_SAMPLES, + 'min_overlap_score': cfg.DATASET.TRAIN.MIN_OVERLAP_SCORE, + 'max_overlap_score': cfg.DATASET.TRAIN.MAX_OVERLAP_SCORE + } + + cfg_input = { + k: getattr(train_cfg, k) + for k in [ + 'DATA_ROOT', 'NPZ_ROOT', 'STEP', 'PIX_THR', 'FIX_MATCHES', 'SOURCE_ROOT', + 'MAX_CANDIDATE_MATCHES', 'MIN_FINAL_MATCHES', 'MIN_FILTER_MATCHES', + 'VIDEO_IMAGE_ROOT', 'PROPAGATE_ROOT', 'PSEUDO_LABELS' + ] + } + + # 合并配置 + input_ = { + **base_input, + **cfg_input, + 'root_dir': cfg_input['DATA_ROOT'], + 'npz_root': cfg_input['NPZ_ROOT'] + } + + dataset = WALKDataset(**input_) + + return dataset + + +# noinspection PyUnusedLocal +def collate_fn(batch): + return None + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('seq_names', type=str, nargs='+') + args = parser.parse_args() + + if os.path.isfile(args.seq_names[0]): + with open(args.seq_names[0], 'r') as f: + seq_names = [line.strip() for line in f.readlines()] + else: + seq_names = args.seq_names + + for seq_name in seq_names: + + dataset_ = init_dataset(seq_name) + + loader_params = {'batch_size': 1, 'shuffle': False, 'num_workers': 3, + 'pin_memory': True, 'drop_last': False} + loader_ = DataLoader(dataset_, collate_fn=collate_fn, **loader_params) + + propagate(loader_, seq_name) diff --git a/imcui/third_party/gim/datasets/walk/utils.py 
b/imcui/third_party/gim/datasets/walk/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b508111ad324f575a5a67f02aa6391c1cd8a7b --- /dev/null +++ b/imcui/third_party/gim/datasets/walk/utils.py @@ -0,0 +1,316 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import math + +import cv2 +import torch +import random +import numpy as np + +from albumentations.augmentations import functional as F + +from datasets.utils import get_divisible_wh + + +def fast_make_matching_robust_fitting_figure(data, b_id=0, transpose=False): + robust_fitting = True if 'inliers' in list(data.keys()) and data['inliers'] is not None else False + + gray0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.uint8) + gray1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.uint8) + kpts0 = data['mkpts0_f'] + kpts1 = data['mkpts1_f'] + + if 'scale0' in data: + kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy() + kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy() + + if transpose: + gray0 = cv2.rotate(gray0, cv2.ROTATE_90_COUNTERCLOCKWISE) + gray1 = cv2.rotate(gray1, cv2.ROTATE_90_COUNTERCLOCKWISE) + + h0, w0 = data['hw0_i'] + h1, w1 = data['hw1_i'] + kpts0_new = np.copy(kpts0) + kpts1_new = np.copy(kpts1) + kpts0_new[:, 0], kpts0_new[:, 1] = kpts0[:, 1], w0 - kpts0[:, 0] + kpts1_new[:, 0], kpts1_new[:, 1] = kpts1[:, 1], w1 - kpts1[:, 0] + kpts0, kpts1 = kpts0_new, kpts1_new + (h0, w0), (h1, w1) = (w0, h0), (w1, h1) + else: + (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i'] + + rows = 3 + margin = 2 + h, w = max(h0, h1), max(w0, w1) + H, W = margin * (rows + 1) + h * rows, margin * 3 + w * 2 + + # canvas + out = 255 * np.ones((H, W), np.uint8) + + wx = [margin, margin + w0, margin + w + margin, margin + w + margin + w1] + hx = lambda row: margin * row + h * (row-1) + out = np.stack([out] * 3, -1) + + sh = hx(row=1) + color0 = (data['color0'][b_id].permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8) + color1 = (data['color1'][b_id].permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8) + if transpose: + color0 = cv2.rotate(color0, cv2.ROTATE_90_COUNTERCLOCKWISE) + color1 = cv2.rotate(color1, cv2.ROTATE_90_COUNTERCLOCKWISE) + out[sh: sh + h0, wx[0]: wx[1]] = color0 + out[sh: sh + h1, wx[2]: wx[3]] = color1 + + # only show keypoints + sh = hx(row=2) + mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int) + out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1) + out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1) + for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1): + # display line end-points as circles + c = (230, 216, 132) + cv2.circle(out, (x0, y0+sh), 1, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + w, y1+sh), 1, c, -1, lineType=cv2.LINE_AA) + + # show keypoints and correspondences + sh = hx(row=3) + mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int) + out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1) + out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1) + for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1): + c = (159, 212, 252) + cv2.line(out, (x0, y0+sh), (x1 + margin + w, y1+sh), color=c, thickness=1, lineType=cv2.LINE_AA) + for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1): + # display line end-points as circles + c = (230, 216, 132) + cv2.circle(out, (x0, y0+sh), 2, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + w, y1+sh), 2, c, -1, lineType=cv2.LINE_AA) + + # Big text. 
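+    # The overlay below reports the total number of matches and, when an 'inliers' mask is present (i.e. robust fitting ran), the count of matches that survived it.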
+ text = [ + f' ', + f'#Matches {len(kpts0)}', + f'#Matches {sum(data["inliers"][b_id])}' if robust_fitting else '', + ] + sc = min(H / 640., 1.0) + Ht = int(30 * sc) # text height + txt_color_fg = (255, 255, 255) # white + txt_color_bg = (0, 0, 0) # black + for i, t in enumerate(text): + cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX, 1.0 * sc, txt_color_bg, 2, cv2.LINE_AA) + cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX, 1.0 * sc, txt_color_fg, 1, cv2.LINE_AA) + + fingerprint = [ + 'Dataset: {}'.format(data['dataset_name'][b_id]), + 'Scene ID: {}'.format(data['scene_id'][b_id]), + 'Pair ID: {}'.format(data['pair_id'][b_id]), + 'co-visible: {:.4f}/{:.4f}'.format(data['covisible0'], + data['covisible1']), + 'Image sizes: {} - {}'.format( + tuple(reversed(data['imsize0'][b_id])) if transpose and isinstance(data['imsize0'][b_id], (list, tuple, np.ndarray)) and len(data['imsize0'][b_id]) >= 2 else data['imsize0'][b_id], + tuple(reversed(data['imsize1'][b_id])) if transpose and isinstance(data['imsize1'][b_id], (list, tuple, np.ndarray)) and len(data['imsize1'][b_id]) >= 2 else data['imsize1'][b_id]), + 'Pair names: {}:{}'.format(data['pair_names'][0].split('/')[-1], + data['pair_names'][1].split('/')[-1]), + 'Rand Scale: {} - {}'.format(data['rands0'], + data['rands1']), + 'Offset: {} - {}'.format(data['offset0'].cpu().numpy(), + data['offset1'].cpu().numpy()), + 'Fliped: {} - {}'.format(data['hflip0'], + data['hflip1']), + 'Transposed: {}'.format(transpose) + ] + sc = min(H / 1280., 1.0) + Ht = int(18 * sc) # text height + txt_color_fg = (255, 255, 255) # white + txt_color_bg = (0, 0, 0) # black + for i, t in enumerate(reversed(fingerprint)): + cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX, .5 * sc, txt_color_bg, 2, cv2.LINE_AA) + cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX, .5 * sc, txt_color_fg, 1, cv2.LINE_AA) + + return out[h+margin:] + + +def eudist(a, b): + aa = np.sum(a ** 2, axis=-1, keepdims=True) + bb = np.sum(b ** 2, axis=-1, keepdims=True).T + cc = a @ b.T + dist = aa + bb - 2*cc + return dist + + +def covision(kpts, size): + return (kpts[:, 0].max() - kpts[:, 0].min()) * \ + (kpts[:, 1].max() - kpts[:, 1].min()) / \ + (size[0] * size[1] + 1e-8) + + +view = lambda x: x.view([('', x.dtype)] * x.shape[1]) + + +def intersected(x, y): + intersected_ = np.intersect1d(view(x), view(y)) + z = intersected_.view(x.dtype).reshape(-1, x.shape[1]) + return z + + +def imread_color(path, augment_fn=None, read_size=None, source=None): + if augment_fn is None: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) if source is None else source + image = cv2.resize(image, read_size) if read_size is not None else image + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if source is None else image + else: + image = cv2.imread(str(path), cv2.IMREAD_COLOR) if source is None else source + image = cv2.resize(image, read_size) if read_size is not None else image + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if source is None else image + image = augment_fn(image) + return image # (h, w) + + +def get_resized_wh(w, h, resize, aug_prob): + nh, nw = resize + sh, sw = nh / h, nw / w + # scale = min(sh, sw) + scale = random.choice([sh, sw]) if aug_prob != 1.0 else min(sh, sw) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + return w_new, h_new + + +def pad_bottom_right(inp, pad_size, ret_mask=False): + mask = None + if inp.ndim == 2: + padded = np.zeros((pad_size[0], 
pad_size[1]), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + elif inp.ndim == 3: + padded = np.zeros((pad_size[0], pad_size[1], inp.shape[-1]), dtype=inp.dtype) + padded[:inp.shape[0], :inp.shape[1]] = inp + else: + raise NotImplementedError() + + if ret_mask: + mask = np.zeros((pad_size[0], pad_size[1]), dtype=bool) + mask[:inp.shape[0], :inp.shape[1]] = True + + return padded, mask + + +def read_images(path, max_resize, df=None, padding=True, augment_fn=None, aug_prob=0.0, flip_prob=1.0, + is_left=None, upper_cornor=None, read_size=None, image=None): + """ + Args: + path: string + max_resize (int): max image size after resied + df (int, optional): image size division factor. + NOTE: this will change the final image size after img_resize + padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + aug_prob (float, optional): probability of applying augment_fn + flip_prob (float, optional): probability of flipping images + is_left (bool, optional): if set to 'True', it is left image, otherwise is right image + upper_cornor (tuple, optional): upper left corner of the image + read_size (int, optional): read image size + image (callable, optional): input image + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + assert max_resize is not None + assert isinstance(max_resize, list) + if len(max_resize) == 1: max_resize = max_resize * 2 + + w_new, h_new = get_divisible_wh(max_resize[0], max_resize[1], df) + max_resize = [h_new, w_new] + + image = imread_color(path, augment_fn, read_size, image) # (h,w,3) image is RGB + + # resize image + w, h = image.shape[1], image.shape[0] + if (h > max_resize[0]) or (w > max_resize[1]): + w_new, h_new = get_resized_wh(w, h, max_resize, aug_prob) # make max(w, h) to max_size + else: + w_new, h_new = w, h + + # random resize + if random.uniform(0, 1) > aug_prob: + # random rescale + ratio = max(h / max_resize[0], w / max_resize[1]) + if type(is_left) == bool: + if is_left: + low, upper = (0.6 / ratio, 1.0 / ratio) if ratio < 1.0 else (0.6, 1.0) + else: + low, upper = (1.0 / ratio, 1.4 / ratio) if ratio < 1.0 else (1.0, 1.4) + else: + low, upper = (0.6 / ratio, 1.4 / ratio) if ratio < 1.0 else (0.6, 1.4) + if not is_left and upper_cornor is not None: + corner = upper_cornor[2:] + upper = min(upper, min(max_resize[0]/corner[1], max_resize[1]/corner[0])) + rands = random.uniform(low, upper) + w_new, h_new = map(lambda x: x*rands, [w_new, h_new]) + w_new, h_new = get_divisible_wh(w_new, h_new, df) # make image divided by df and must <= max_size + else: + rands = 1 + w_new, h_new = get_divisible_wh(w_new, h_new, df) + # width, height = w_new, h_new + # h_start = w_start = 0 + + if upper_cornor is not None: + upper_cornor = upper_cornor[:2] + + # random crop + if h_new > max_resize[0]: + height = max_resize[0] + h_start = int(random.uniform(0, 1) * (h_new - max_resize[0])) + if upper_cornor is not None: + h_start = min(h_start, math.floor(upper_cornor[1]*(h_new/h))) + else: + height = h_new + h_start = 0 + + if w_new > max_resize[1]: + width = max_resize[1] + w_start = int(random.uniform(0, 1) * (w_new - max_resize[1])) + if upper_cornor is not None: + w_start = min(w_start, math.floor(upper_cornor[0]*(w_new/w))) + else: + width = w_new + w_start = 0 + + w_new, h_new = map(int, [w_new, h_new]) + width, height = map(int, [width, height]) + + image = cv2.resize(image, 
(w_new, h_new)) # (w',h',3) + image = image[h_start:h_start+height, w_start:w_start+width] + + scale = [w / w_new, h / h_new] + offset = [w_start, h_start] + + # vertical flip + if random.uniform(0, 1) > flip_prob: + hflip = F.hflip_cv2 if image.ndim == 3 and image.shape[2] > 1 and image.dtype == np.uint8 else F.hflip + image = hflip(image) + image = F.vflip(image) + hflip = True + vflip = True + else: + hflip = False + vflip = False + + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + + # padding + mask = None + if padding: + image, _ = pad_bottom_right(image, max_resize, ret_mask=False) + gray, mask = pad_bottom_right(gray, max_resize, ret_mask=True) + mask = torch.from_numpy(mask) + + gray = torch.from_numpy(gray).float()[None] / 255 # (1,h,w) + image = torch.from_numpy(image).float() / 255 # (h,w,3) + image = image.permute(2, 0, 1) # (3,h,w) + + offset = torch.tensor(offset, dtype=torch.float) + scale = torch.tensor(scale, dtype=torch.float) + resize = [height, width] + + return gray, image, scale, rands, offset, hflip, vflip, resize, mask diff --git a/imcui/third_party/gim/datasets/walk/video_loader.py b/imcui/third_party/gim/datasets/walk/video_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..71fe131fd2b60a383099d1bf29ebf7bdb52e95c2 --- /dev/null +++ b/imcui/third_party/gim/datasets/walk/video_loader.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import os +import cv2 +import torch + +from os.path import join +from torch.utils.data import Dataset + + +def collate_fn(batch): + batch = list(filter(lambda x: x is not None, batch)) + return torch.utils.data.dataloader.default_collate(batch) + + +class WALKDataset(Dataset): + + def __init__(self, data_root, vs, ids, checkpoint, opt): + super().__init__() + + self.vs = vs + self.ids = ids[checkpoint:] + + old_image_root = join(data_root, 'image_1080p', opt.scene_name) + new_image_root = join(data_root, 'image_1080p', opt.scene_name.strip()) + if not os.path.exists(new_image_root): + if os.path.exists(old_image_root): + os.rename(old_image_root, new_image_root) + else: + os.makedirs(new_image_root, exist_ok=True) + self.image_root = new_image_root + + def __len__(self): + return len(self.ids) + + def __getitem__(self, idx): + idx0, idx1 = self.ids[idx] + + # get image + img_path0 = join(self.image_root, '{}.png'.format(idx0)) + if not os.path.exists(img_path0): + rgb0 = self.vs[idx0] + rgb0_is_good = False + else: + rgb0 = cv2.imread(img_path0) + rgb0_is_good = True + if rgb0 is None: + rgb0 = self.vs[idx0] + rgb0_is_good = False + + img_path1 = join(self.image_root, '{}.png'.format(idx1)) + if not os.path.exists(img_path1): + rgb1 = self.vs[idx1] + rgb1_is_good = False + else: + rgb1 = cv2.imread(img_path1) + rgb1_is_good = True + if rgb1 is None: + rgb1 = self.vs[idx1] + rgb1_is_good = False + + return {'idx': idx, 'idx0': idx0, 'idx1': idx1, 'rgb0': rgb0, 'rgb1': rgb1, + 'img_path0': img_path0, 'img_path1': img_path1, + 'rgb0_is_good':rgb0_is_good, 'rgb1_is_good': rgb1_is_good} diff --git a/imcui/third_party/gim/datasets/walk/video_streamer.py b/imcui/third_party/gim/datasets/walk/video_streamer.py new file mode 100644 index 0000000000000000000000000000000000000000..aafb63966a1f350109069b894e8277311f7d8213 --- /dev/null +++ b/imcui/third_party/gim/datasets/walk/video_streamer.py @@ -0,0 +1,69 @@ +import math + +from pathlib import Path +from torchvision.io import VideoReader + + +class VideoStreamer: + """ Class to help process image streams. Four types of possible inputs:" + 1.) 
USB Webcam. + 2.) An IP camera + 3.) A directory of images (files in directory matching 'image_glob'). + 4.) A video file, such as an .mp4 or .avi file. + """ + def __init__(self, basedir, resize, df, skip, vrange=None, image_glob=None, max_length=1000000): + """ + The function takes in a directory, a resize value, a skip value, a glob value, and a + max length value. + + The function then checks if the directory is a number, if it is, it sets the cap to + a video capture of the directory. + + If the directory starts with http or rtsp, it sets the cap to a video capture of the + directory. + + If the directory is a directory, it sets the listing to a list of the directory. + + If the directory is a file, it sets the cap to a video capture of the directory. + + If the directory is none of the above, it raises a value error. + + If the directory is a camera and the cap is not opened, it raises an IO error. + + Args: + basedir: The directory where the images or video file are stored. + resize: The size of the image to be returned. + df: The frame rate of the video. + skip: This is the number of frames to skip between each frame that is read. + vrange: Video time range + image_glob: A list of glob patterns to match the images in the directory. + max_length: The maximum number of frames to read from the video. Defaults to + 1000000 + """ + if vrange is None: + vrange = [0, -1] + + self.listing = [] + self.skip = skip + + if Path(basedir).exists(): + self.video = VideoReader(basedir, 'video') + meta = self.video.get_metadata() + seconds = math.floor(meta['video']['duration'][0]) + self.fps = int(meta['video']['fps'][0]) + start, end = max(0, vrange[0]), min(seconds, vrange[1]) + end = seconds if end == -1 else end + assert start < end, 'Invalid video range' + self.range = [start, end] + self.listing = range(start*self.fps, end*self.fps+1) + self.listing = self.listing[::self.skip] + + else: + raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir)) + + def __len__(self): + return len(self.listing) + + def __getitem__(self, i): + image = next(self.video.seek(i/self.fps))['data'].permute(1, 2, 0).numpy() + return image diff --git a/imcui/third_party/gim/datasets/walk/walk.py b/imcui/third_party/gim/datasets/walk/walk.py new file mode 100644 index 0000000000000000000000000000000000000000..623b70b89e4075108eeb19a8308185e83b66d4e1 --- /dev/null +++ b/imcui/third_party/gim/datasets/walk/walk.py @@ -0,0 +1,516 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import os +import cv2 +import torch +import random +import numpy as np +import torch.nn.functional as F + +from tqdm import tqdm +from os import listdir +from pathlib import Path +from functools import reduce +from datetime import datetime +from argparse import ArgumentParser +from os.path import join, isdir, exists + +from datasets.dataset import RGBDDataset + +from datasets.walk import cfg +from datasets.walk.utils import covision, intersected, read_images +from datasets.walk.utils import fast_make_matching_robust_fitting_figure + +parse_mtd = lambda name: name.parent.stem.split()[1] +parse_skip = lambda name: int(str(name).split(os.sep)[-1].rpartition('SP')[-1].strip().rpartition(' ')[0]) +parse_resize = lambda name: str(name).split(os.sep)[-2].rpartition('[R]')[-1].rpartition('[S]')[0].strip() + +create_table = lambda x, y, w: dict(zip(np.round(x) + np.round(y) * w, list(range(len(x))))) + + +class WALKDataset(RGBDDataset): + def __init__(self, + root_dir, # data root dit + npz_root, # data info, like, overlap, 
image_path, depth_path + seq_name, # current sequence + mode, # train or val or test + max_resize, # max edge after resize + df, # general is 8 for ResNet w/ pre 3-layers + padding, # padding image for batch training + augment_fn, # augmentation function + max_samples, # max sample in current sequence + **kwargs): + super().__init__() + + self.mode = mode + self.root_dir = root_dir + self.scene_path = join(root_dir, seq_name) + + pseudo_labels = kwargs.get('PSEUDO_LABELS', None) + npz_paths = [join(npz_root, x) for x in pseudo_labels] + npz_paths = [x for x in npz_paths if exists(x)] + npz_names = [{d[:int(d.split()[-1])]: Path(path, d) for d in listdir(path) if isdir(join(path, d))} for path in npz_paths] + npz_paths = [name_dict[seq_name] for name_dict in npz_names if seq_name in name_dict.keys()] + + self.propagating = kwargs.get('PROPAGATING', False) + + if self.propagating and len(npz_paths) != 24: + print(f'{seq_name} has {len(npz_paths)} pseudo labels, but 24 are expected.') + exit(0) + + self.scale = 1 / df + self.scene_id = seq_name + self.skips = sorted(list({parse_skip(name) for name in npz_paths})) + self.resizes = sorted(list({parse_resize(name) for name in npz_paths})) + self.methods = sorted(list({parse_mtd(name) for name in npz_paths}))[::-1] + + self.min_final_matches = kwargs.get('MIN_FINAL_MATCHES', None) + self.min_filter_matches = kwargs.get('MIN_FILTER_MATCHES', None) + + pproot = kwargs.get('PROPAGATE_ROOT', None) + ppid = ' '.join(self.methods + list(map(str, self.skips)) + self.resizes + [f'FM {self.min_filter_matches}', f'PM {self.min_final_matches}']) + self.pproot = join(pproot, ppid, seq_name) + + if not self.propagating: + assert exists(self.pproot) + elif not exists(self.pproot): + os.makedirs(self.pproot, exist_ok=True) + + image_root = kwargs.get('VIDEO_IMAGE_ROOT', None) + self.image_root = join(image_root, seq_name) + if not exists(self.image_root): + os.makedirs(self.image_root, exist_ok=True) + + self.step = kwargs.get('STEP', None) + self.pix_thr = kwargs.get('PIX_THR', None) + self.fix_matches = kwargs.get('FIX_MATCHES', None) + + source_root = kwargs.get('SOURCE_ROOT', None) + + scap = cv2.VideoCapture(join(source_root, seq_name + '.mp4')) + self.pseudo_size = [int(scap.get(3)), int(scap.get(4))] + source_fps = int(scap.get(5)) + + video_path = join(root_dir, seq_name + '.mp4') + vcap = cv2.VideoCapture(video_path) + self.frame_size = [int(vcap.get(3)), int(vcap.get(4))] + + if self.propagating: + nums = {skip: [] for skip in self.skips} + idxs = {skip: [] for skip in self.skips} + self.path = {skip: [] for skip in self.skips} + for npz_path in npz_paths: + skip = parse_skip(npz_path) + assert exists(npz_path / 'nums.npy') + with open(npz_path / 'nums.npy', 'rb') as f: + npz = np.load(f) + nums[skip].append(npz) + assert exists(npz_path / 'idxs.npy') + with open(npz_path / 'idxs.npy', 'rb') as f: + npz = np.load(f) + idxs[skip].append(npz) + self.path[skip].append(npz_path) + + ids1 = reduce(intersected, [idxs[nums > self.min_filter_matches] for nums, idxs in zip(nums[self.skips[-1]], idxs[self.skips[-1]])]) + continue1 = np.array([x in ids1[:, 0] for x in (ids1[:, 0] + self.skips[-1] * 1)]) + ids2 = reduce(intersected, idxs[self.skips[-2]]) + continue2 = np.array([x in ids2[:, 0] for x in ids1[:, 0]]) + continue2 = continue2 & np.array([x in ids2[:, 0] for x in (ids1[:, 0] + self.skips[-2] * 1)]) + ids3 = reduce(intersected, idxs[self.skips[-3]]) + continue3 = np.array([x in ids3[:, 0] for x in ids1[:, 0]]) + continue3 = continue3 & np.array([x in 
ids3[:, 0] for x in (ids1[:, 0] + self.skips[-3] * 1)]) + continue3 = continue3 & np.array([x in ids3[:, 0] for x in (ids1[:, 0] + self.skips[-3] * 2)]) + continue3 = continue3 & np.array([x in ids3[:, 0] for x in (ids1[:, 0] + self.skips[-3] * 3)]) + continues = continue1 & continue2 & continue3 + ids = ids1[continues] + pair_ids = np.array(list(zip(ids[:, 0], np.clip(ids[:, 0]+self.step*self.skips[-1], a_min=ids[0, 0], a_max=ids[-1, 1])))) if self.step > 0 else ids + pair_ids = pair_ids[(pair_ids[:, 1] - pair_ids[:, 0]) >= self.skips[-1]] + else: + pair_ids = np.array([tuple(map(int, x.split('.npy')[0].split('_'))) for x in os.listdir(self.pproot) if x.endswith('.npy')]) + + if (max_samples > 0) and (len(pair_ids) > max_samples): + random_state = random.getstate() + np_random_state = np.random.get_state() + random.seed(3407) + np.random.seed(3407) + pair_ids = pair_ids[sorted(np.random.randint(len(pair_ids), size=max_samples))] + random.setstate(random_state) + np.random.set_state(np_random_state) + + # remove unvalid pairs from self.pproot/bad_pairs.txt + pair_ids = set(map(tuple, pair_ids.tolist())) + + if self.propagating: + assert not exists(join(self.pproot, 'bad_pairs.txt')) + + if exists(join(self.pproot, 'bad_pairs.txt')): + with open(join(self.pproot, 'bad_pairs.txt'), 'r') as f: + unvalid_pairs = set([tuple(map(int, line.split())) for line in f.readlines()]) + self.unvalid_pairs_num = len(unvalid_pairs) if not self.propagating else 'N/A' + pair_ids = pair_ids - unvalid_pairs + + self.valid_pairs_num = len(pair_ids) if not self.propagating else 'N/A' + + self.pair_ids = list(map(list, pair_ids)) # List[List[int, int]] + + # parameters for image resizing, padding and depthmap padding + if mode == 'train': assert max_resize is not None + + self.df = df + self.max_resize = max_resize + self.padding = padding + + # for training LoFTR + self.augment_fn = augment_fn if mode == 'train' else None + + def __len__(self): + return len(self.pair_ids) + + def propagate(self, idx0, idx1, skips): + """ + Args: + idx0: (int) index of the first frame + idx1: (int) index of the second frame + skips: (List) + + Returns: + """ + skip = skips[-1] # 40 + indices = [skip * (i + 1) + idx0 for i in range((idx1 - idx0) // skip)] + if (not indices) or (idx0 != indices[0]): indices = [idx0] + indices + if idx1 != indices[-1]: indices = indices + [idx1] + indices = list(zip(indices[:-1], indices[1:])) + + # [(N', 4), (N'', 4), ...] 
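+        # Collect per-pair pseudo labels along the frame chain: each chunk comes from dump() at the coarsest skip, optionally augmented by a recursive propagate() call over the finer skips.
+        # Adjacent chunks are then merged by link(), which keeps matches that share a keypoint in the common middle frame; if linking fails, the pair's right endpoint is pulled back by the finest skip and retried.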
+ labels = [] + ids = [idx0] + while indices: + pair = indices.pop(0) # (tuple) + if pair[0] == pair[1]: break + label = [] + if (pair[-1] - pair[0]) == skip: + tmp = self.dump(skip, pair) + if len(tmp) > 0: label.append(tmp) # (ndarray) (N, 4) + if skips[:-1]: + _label_, id0, id1 = self.propagate(pair[0], pair[1], skips[:-1]) + if (id0, id1) == pair: label.append(_label_) # (ndarray) (M, 4) + if label: + label = np.concatenate(label, axis=0) # (ndarray) (N+M, 4) + labels.append(label) + ids += [pair[1]] + if len(labels) > 1: + _labels_ = self.link(labels[0], labels[1]) + if _labels_ is not None: + labels = [_labels_] + ids = [ids[0], ids[-1]] + else: + labels.pop(-1) + ids.pop(-1) + indices = [(pair[0], pair[1]-skips[0])] + + if len(labels) == 1 and len(ids) == 2: + return labels[0], ids[0], ids[-1] + else: + return None, None, None + + def link(self, label0, label1): + """ + Args: + label0: (ndarray) N x 4 + label1: (ndarray) M x 4 + + Returns: (ndarray) (N', 4) + """ + # get keypoints in left, middle and right frame + left_t0 = label0[:, :2] # (N, 2) + mid_t0 = label0[:, 2:] # (N, 2) + mid_t1 = label1[:, :2] # (M, 2) + right_t1 = label1[:, 2:] # (M, 2) + + mid0_table = create_table(mid_t0[:, 0], mid_t0[:, 1], self.pseudo_size[0]) + mid1_table = create_table(mid_t1[:, 0], mid_t1[:, 1], self.pseudo_size[0]) + + keys = {*mid0_table} & {*mid1_table} + + i = np.array([mid0_table[k] for k in keys]) + j = np.array([mid1_table[k] for k in keys]) + + # remove repeat matches + ij = np.unique(np.vstack((i, j)), axis=1) + + if ij.shape[1] < self.min_final_matches: return None + + # get the new pseudo labels + pseudo_label = np.concatenate([left_t0[ij[0]], right_t1[ij[1]]], axis=1) # (N', 4) + + return pseudo_label + + def dump(self, skip, pair): + """ + Args: + skip: + pair: + + Returns: pseudo_label (N, 4) + """ + labels = [] + for path in self.path[skip]: + p = path / '{}.npy'.format(str(np.array(pair))) + if exists(p): + with open(p, 'rb') as f: + labels.append(np.load(f)) + + if len(labels) > 0: labels = np.concatenate(labels, axis=0).astype(np.float32) # (N, 4) + + return labels + + def __getitem__(self, idx): + idx0, idx1 = self.pair_ids[idx] + + pppath = join(self.pproot, '{}_{}.npy'.format(idx0, idx1)) + + if self.propagating and exists(pppath): + return None + + # check propagation + if not self.propagating: + assert exists(pppath), f'{pppath} does not exist' + + if not exists(pppath): + pseudo_label, idx0, idx1 = self.propagate(idx0, idx1, self.skips) + + if idx1 - idx0 == self.skips[-1]: + pseudo_label, idx0, idx1 = self.propagate(idx0, idx1, self.skips[:-1]) + + if idx1 - idx0 == self.skips[-2]: + pseudo_label, idx0, idx1 = self.propagate(idx0, idx1, self.skips[:-2]) + + if pseudo_label is None: + _idx0_, _idx1_ = self.pair_ids[idx] + with open(join(self.pproot, 'bad_pairs.txt'), 'a') as f: + f.write('{} {}\n'.format(_idx0_, _idx1_)) + return None + + _, mask = cv2.findFundamentalMat(pseudo_label[:, :2], pseudo_label[:, 2:], cv2.USAC_MAGSAC, ransacReprojThreshold=1.0, confidence=0.999999, maxIters=1000) + mask = mask.ravel() > 0 + pseudo_label = pseudo_label[mask] + + if len(pseudo_label) < 64 or (idx1 - idx0) == self.skips[-3]: + _idx0_, _idx1_ = self.pair_ids[idx] + with open(join(self.pproot, 'bad_pairs.txt'), 'a') as f: + f.write('{} {}\n'.format(_idx0_, _idx1_)) + return None + else: + with open(pppath, 'wb') as f: + np.save(f, np.concatenate((np.array([[idx0, idx1, idx0, idx1]]).astype(np.float32), pseudo_label), axis=0)) + else: + with open(pppath, 'rb') as f: + pseudo_label = 
np.load(f) + idx0, idx1 = pseudo_label[0].astype(np.int64)[:2].tolist() + pseudo_label = pseudo_label[1:] + + if self.propagating: + return None + + pseudo_label *= (np.array(self.frame_size * 2) / np.array(self.pseudo_size * 2))[None] + + # get image + img_path0 = join(self.image_root, '{}.png'.format(idx0)) + color0 = cv2.imread(img_path0) + + img_path1 = join(self.image_root, '{}.png'.format(idx1)) + color1 = cv2.imread(img_path1) + + width0, height0 = self.frame_size + width1, height1 = self.frame_size + + left_upper_cornor = pseudo_label[:, :2].min(axis=0) + left_low_corner = pseudo_label[:, :2].max(axis=0) + left_corner = np.concatenate([left_upper_cornor, left_low_corner], axis=0) + right_upper_cornor = pseudo_label[:, 2:].min(axis=0) + right_low_corner = pseudo_label[:, 2:].max(axis=0) + right_corner = np.concatenate([right_upper_cornor, right_low_corner], axis=0) + + # Prepare variables + image0, color0, scale0, rands0, offset0, hlip0, vflip0, resize0, mask0 = read_images( + None, self.max_resize, self.df, self.padding, + np.random.choice([self.augment_fn, None], p=[0.5, 0.5]), + aug_prob=1.0, is_left=True, + upper_cornor=left_corner, + read_size=self.frame_size, image=color0) + image1, color1, scale1, rands1, offset1, hlip1, vflip1, resize1, mask1 = read_images( + None, self.max_resize, self.df, self.padding, + np.random.choice([self.augment_fn, None], p=[0.5, 0.5]), + aug_prob=1.0, is_left=False, + upper_cornor=right_corner, + read_size=self.frame_size, image=color1) + + # warp keypoints by scale, offset and hlip + pseudo_label = torch.tensor(pseudo_label, dtype=torch.float) + left = (pseudo_label[:, :2] / scale0[None] - offset0[None]) + left[:, 0] = resize0[1] - 1 - left[:, 0] if hlip0 else left[:, 0] + left[:, 1] = resize0[0] - 1 - left[:, 1] if vflip0 else left[:, 1] + right = (pseudo_label[:, 2:] / scale1[None] - offset1[None]) + right[:, 0] = resize1[1] - 1 - right[:, 0] if hlip1 else right[:, 0] + right[:, 1] = resize1[0] - 1 - right[:, 1] if vflip1 else right[:, 1] + + mask = (left[:, 0] >= 0) & (left[:, 0]*self.scale <= (resize0[1]*self.scale - 1)) & \ + (left[:, 1] >= 0) & (left[:, 1]*self.scale <= (resize0[0]*self.scale - 1)) & \ + (right[:, 0] >= 0) & (right[:, 0]*self.scale <= (resize1[1]*self.scale - 1)) & \ + (right[:, 1] >= 0) & (right[:, 1]*self.scale <= (resize1[0]*self.scale - 1)) + left, right = left[mask], right[mask] + + pseudo_label = torch.cat([left, right], dim=1) + pseudo_label = torch.unique(pseudo_label, dim=0) + + fix_pseudo_label = torch.zeros(self.fix_matches, 4, dtype=pseudo_label.dtype) + fix_pseudo_label[:len(pseudo_label)] = pseudo_label + + # read image size + imsize0 = torch.tensor([height0, width0], dtype=torch.long) + imsize1 = torch.tensor([height1, width1], dtype=torch.long) + resize0 = torch.tensor(resize0, dtype=torch.long) + resize1 = torch.tensor(resize1, dtype=torch.long) + + data = { + # image 0 + 'image0': image0, + 'color0': color0, + 'imsize0': imsize0, + 'offset0': offset0, + 'resize0': resize0, + 'depth0': torch.ones((1600, 1600), dtype=torch.float), + 'hflip0': hlip0, + 'vflip0': vflip0, + + # image 1 + 'image1': image1, + 'color1': color1, + 'imsize1': imsize1, + 'offset1': offset1, + 'resize1': resize1, + 'depth1': torch.ones((1600, 1600), dtype=torch.float), + 'hflip1': hlip1, + 'vflip1': vflip1, + + # image transform + 'pseudo_labels': fix_pseudo_label, + 'gt': False, + 'zs': True, + + # image transform + 'T_0to1': torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=torch.float), + 'T_1to0': 
torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=torch.float), + 'K0': torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float), + 'K1': torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float), + # pair information + 'scale0': scale0 / scale0, + 'scale1': scale1 / scale1, + 'rands0': rands0, + 'rands1': rands1, + 'dataset_name': 'WALK', + 'scene_id': '{:30}'.format(self.scene_id[:min(30, len(self.scene_id)-1)]), + 'pair_id': f'{idx0}-{idx1}', + 'pair_names': ('{}.png'.format(idx0), + '{}.png'.format(idx1)), + 'covisible0': covision(pseudo_label[:, :2], resize0).item(), + 'covisible1': covision(pseudo_label[:, 2:], resize1).item(), + } + + item = super(WALKDataset, self).__getitem__(idx) + item.update(data) + data = item + + if mask0 is not None: + if self.scale: + # noinspection PyArgumentList + [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), + scale_factor=self.scale, + mode='nearest', + recompute_scale_factor=False)[0].bool() + data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) + data.update({'mask0_i': mask0, 'mask1_i': mask1}) + + return data + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('seq_names', type=str, nargs='+') + args = parser.parse_args() + + train_cfg = cfg.DATASET.TRAIN + + base_input = { + 'df': 8, + 'mode': 'train', + 'augment_fn': None, + 'max_resize': [1280, 720], + 'padding': cfg.DATASET.TRAIN.PADDING, + 'max_samples': cfg.DATASET.TRAIN.MAX_SAMPLES, + 'min_overlap_score': cfg.DATASET.TRAIN.MIN_OVERLAP_SCORE, + 'max_overlap_score': cfg.DATASET.TRAIN.MAX_OVERLAP_SCORE + } + + cfg_input = { + k: getattr(train_cfg, k) + for k in [ + 'DATA_ROOT', 'NPZ_ROOT', 'STEP', 'PIX_THR', 'FIX_MATCHES', 'SOURCE_ROOT', + 'MAX_CANDIDATE_MATCHES', 'MIN_FINAL_MATCHES', 'MIN_FILTER_MATCHES', + 'VIDEO_IMAGE_ROOT', 'PROPAGATE_ROOT', 'PSEUDO_LABELS' + ] + } + + if os.path.isfile(args.seq_names[0]): + with open(args.seq_names[0], 'r') as f: + seq_names = [line.strip() for line in f.readlines()] + else: + seq_names = args.seq_names + + for seq_name in seq_names: + input_ = { + **base_input, + **cfg_input, + 'root_dir': cfg_input['DATA_ROOT'], + 'npz_root': cfg_input['NPZ_ROOT'], + 'seq_name': seq_name + } + + dataset = WALKDataset(**input_) + + random.seed(3407) + np.random.seed(3407) + + samples = list(range(len(dataset))) + num = 10 + samples = random.sample(samples, num) + for idx_ in tqdm(samples[:num], ncols=80, bar_format="{l_bar}{bar:3}{r_bar}", total=num, + desc=f'[ {seq_name[:min(10, len(seq_name)-1)]:<10} ] [ {dataset.valid_pairs_num:<5} / {dataset.valid_pairs_num+dataset.unvalid_pairs_num:<5} ]',): + data_ = dataset[idx_] + + if data_ is None: continue + + pseudo_labels_ = data_['pseudo_labels'] + mask_ = pseudo_labels_.sum(dim=1) > 0 + pseudo_label_ = pseudo_labels_[mask_].cpu().numpy() + data_['mkpts0_f'] = pseudo_label_[:, :2] + data_['mkpts1_f'] = pseudo_label_[:, 2:] + data_['hw0_i'] = data_['image0'].shape[-2:] + data_['hw1_i'] = data_['image1'].shape[-2:] + data_['image0'] = data_['image0'][None] + data_['image1'] = data_['image1'][None] + data_['color0'] = data_['color0'][None] + data_['color1'] = data_['color1'][None] + idx0_, idx1_ = data_['pair_id'].split('-') + idx0_, idx1_ = map(int, [idx0_, idx1_]) + + out = fast_make_matching_robust_fitting_figure(data_, transpose=True) + save_dir = Path('dump/walk') / seq_name + if not exists(save_dir): save_dir.mkdir(parents=True, exist_ok=True) + cv2.imwrite(join(save_dir, '{:8d} [{}] {:8d} 
{:3d}.png'.format( + idx0_, + datetime.utcnow().strftime('%Y-%m-%d %H-%M-%S %f')[:-3], + idx1_, + idx1_ - idx0_ + )), cv2.cvtColor(out, cv2.COLOR_RGB2BGR)) diff --git a/imcui/third_party/gim/demo.py b/imcui/third_party/gim/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..4af940a9719931852f0d517c2f44732a3d724846 --- /dev/null +++ b/imcui/third_party/gim/demo.py @@ -0,0 +1,524 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import cv2 +import torch +import argparse +import warnings +import numpy as np +import matplotlib.pyplot as plt +import torchvision.transforms.functional as F + +from os.path import join +from tools import get_padding_size +from networks.loftr.loftr import LoFTR +from networks.loftr.misc import lower_config +from networks.loftr.config import get_cfg_defaults +from networks.dkm.models.model_zoo.DKMv3 import DKMv3 +from networks.lightglue.superpoint import SuperPoint +from networks.lightglue.models.matchers.lightglue import LightGlue + +DEFAULT_MIN_NUM_MATCHES = 4 +DEFAULT_RANSAC_MAX_ITER = 10000 +DEFAULT_RANSAC_CONFIDENCE = 0.999 +DEFAULT_RANSAC_REPROJ_THRESHOLD = 8 +DEFAULT_RANSAC_METHOD = "USAC_MAGSAC" + +RANSAC_ZOO = { + "RANSAC": cv2.RANSAC, + "USAC_FAST": cv2.USAC_FAST, + "USAC_MAGSAC": cv2.USAC_MAGSAC, + "USAC_PROSAC": cv2.USAC_PROSAC, + "USAC_DEFAULT": cv2.USAC_DEFAULT, + "USAC_FM_8PTS": cv2.USAC_FM_8PTS, + "USAC_ACCURATE": cv2.USAC_ACCURATE, + "USAC_PARALLEL": cv2.USAC_PARALLEL, +} + + +def read_image(path, grayscale=False): + if grayscale: + mode = cv2.IMREAD_GRAYSCALE + else: + mode = cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise ValueError(f'Cannot read image {path}.') + if not grayscale and len(image.shape) == 3: + image = image[:, :, ::-1] # BGR to RGB + return image + + +def resize_image(image, size, interp): + assert interp.startswith('cv2_') + if interp.startswith('cv2_'): + interp = getattr(cv2, 'INTER_'+interp[len('cv2_'):].upper()) + h, w = image.shape[:2] + if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]): + interp = cv2.INTER_LINEAR + resized = cv2.resize(image, size, interpolation=interp) + # elif interp.startswith('pil_'): + # interp = getattr(PIL.Image, interp[len('pil_'):].upper()) + # resized = PIL.Image.fromarray(image.astype(np.uint8)) + # resized = resized.resize(size, resample=interp) + # resized = np.asarray(resized, dtype=image.dtype) + else: + raise ValueError( + f'Unknown interpolation {interp}.') + return resized + + +def fast_make_matching_figure(data, b_id): + color0 = (data['color0'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8) # (rH, rW, 3) + color1 = (data['color1'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8) # (rH, rW, 3) + gray0 = cv2.cvtColor(color0, cv2.COLOR_RGB2GRAY) + gray1 = cv2.cvtColor(color1, cv2.COLOR_RGB2GRAY) + kpts0 = data['mkpts0_f'].cpu().detach().numpy() + kpts1 = data['mkpts1_f'].cpu().detach().numpy() + mconf = data['mconf'].cpu().detach().numpy() + inliers = data['inliers'] + + rows = 2 + margin = 2 + (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i'] + h = max(h0, h1) + H, W = margin * (rows + 1) + h * rows, margin * 3 + w0 + w1 + + # canvas + out = 255 * np.ones((H, W), np.uint8) + + wx = [margin, margin + w0, margin + w0 + margin, margin + w0 + margin + w1] + hx = lambda row: margin * row + h * (row-1) + out = np.stack([out] * 3, -1) + + sh = hx(row=1) + out[sh: sh + h0, wx[0]: wx[1]] = color0 + out[sh: sh + h1, wx[2]: wx[3]] = color1 + + sh = hx(row=2) + out[sh: sh 
+ h0, wx[0]: wx[1]] = color0 + out[sh: sh + h1, wx[2]: wx[3]] = color1 + mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int) + for (x0, y0), (x1, y1) in zip(mkpts0[inliers], mkpts1[inliers]): + c = (0, 255, 0) + cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + w0, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA) + + return out + + +def fast_make_matching_overlay(data, b_id): + color0 = (data['color0'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8) # (rH, rW, 3) + color1 = (data['color1'][b_id].permute(1, 2, 0).cpu().detach().numpy() * 255).round().astype(np.uint8) # (rH, rW, 3) + gray0 = cv2.cvtColor(color0, cv2.COLOR_RGB2GRAY) + gray1 = cv2.cvtColor(color1, cv2.COLOR_RGB2GRAY) + kpts0 = data['mkpts0_f'].cpu().detach().numpy() + kpts1 = data['mkpts1_f'].cpu().detach().numpy() + mconf = data['mconf'].cpu().detach().numpy() + inliers = data['inliers'] + + rows = 2 + margin = 2 + (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i'] + h = max(h0, h1) + H, W = margin * (rows + 1) + h * rows, margin * 3 + w0 + w1 + + # canvas + out = 255 * np.ones((H, W), np.uint8) + + wx = [margin, margin + w0, margin + w0 + margin, margin + w0 + margin + w1] + hx = lambda row: margin * row + h * (row-1) + out = np.stack([out] * 3, -1) + + sh = hx(row=1) + out[sh: sh + h0, wx[0]: wx[1]] = color0 + out[sh: sh + h1, wx[2]: wx[3]] = color1 + + sh = hx(row=2) + out[sh: sh + h0, wx[0]: wx[1]] = color0 + out[sh: sh + h1, wx[2]: wx[3]] = color1 + mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int) + for (x0, y0), (x1, y1) in zip(mkpts0[inliers], mkpts1[inliers]): + c = (0, 255, 0) + cv2.line(out, (x0, y0 + sh), (x1 + margin + w0, y1 + sh), color=c, thickness=1, lineType=cv2.LINE_AA) + cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + w0, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA) + + return out + + +def preprocess(image: np.ndarray, grayscale: bool = False, resize_max: int = None, + dfactor: int = 8): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + + if resize_max: + scale = resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x*scale)) for x in size) + image = resize_image(image, size_new, 'cv2_area') + scale = np.array(size) / np.array(size_new) + + if grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + + # assure that the size is divisible by dfactor + size_new = tuple(map( + lambda x: int(x // dfactor * dfactor), + image.shape[-2:])) + image = F.resize(image, size=size_new) + scale = np.array(size) / np.array(size_new)[::-1] + return image, scale + + +def compute_geom(data, + ransac_method=DEFAULT_RANSAC_METHOD, + ransac_reproj_threshold=DEFAULT_RANSAC_REPROJ_THRESHOLD, + ransac_confidence=DEFAULT_RANSAC_CONFIDENCE, + ransac_max_iter=DEFAULT_RANSAC_MAX_ITER, + ) -> dict: + + mkpts0 = data["mkpts0_f"].cpu().detach().numpy() + mkpts1 = data["mkpts1_f"].cpu().detach().numpy() + + if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES: + return {} + + h1, w1 = data["hw0_i"] + + geo_info = {} + + F, inliers = cv2.findFundamentalMat( + mkpts0, + mkpts1, + method=RANSAC_ZOO[ransac_method], + ransacReprojThreshold=ransac_reproj_threshold, + confidence=ransac_confidence, + maxIters=ransac_max_iter, + ) + if F is not None: + geo_info["Fundamental"] = F.tolist() + + H, _ 
= cv2.findHomography( + mkpts1, + mkpts0, + method=RANSAC_ZOO[ransac_method], + ransacReprojThreshold=ransac_reproj_threshold, + confidence=ransac_confidence, + maxIters=ransac_max_iter, + ) + if H is not None: + geo_info["Homography"] = H.tolist() + _, H1, H2 = cv2.stereoRectifyUncalibrated( + mkpts0.reshape(-1, 2), + mkpts1.reshape(-1, 2), + F, + imgSize=(w1, h1), + ) + geo_info["H1"] = H1.tolist() + geo_info["H2"] = H2.tolist() + + return geo_info + + +def wrap_images(img0, img1, geo_info, geom_type): + img0 = img0[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1] + img1 = img1[0].permute((1, 2, 0)).cpu().detach().numpy()[..., ::-1] + + h1, w1, _ = img0.shape + h2, w2, _ = img1.shape + + rectified_image0 = img0 + rectified_image1 = None + H = np.array(geo_info["Homography"]) + F = np.array(geo_info["Fundamental"]) + + title = [] + if geom_type == "Homography": + rectified_image1 = cv2.warpPerspective( + img1, H, (img0.shape[1], img0.shape[0]) + ) + title = ["Image 0", "Image 1 - warped"] + elif geom_type == "Fundamental": + H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"]) + rectified_image0 = cv2.warpPerspective(img0, H1, (w1, h1)) + rectified_image1 = cv2.warpPerspective(img1, H2, (w2, h2)) + title = ["Image 0 - warped", "Image 1 - warped"] + else: + print("Error: Unknown geometry type") + + fig = plot_images( + [rectified_image0.squeeze(), rectified_image1.squeeze()], + title, + dpi=300, + ) + + img = fig2im(fig) + + plt.close(fig) + + return img + + +def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5): + """Plot a set of images horizontally. + Args: + imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. 
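+        dpi, size, pad: forwarded to matplotlib (figure dpi, per-image panel width in inches, and tight_layout padding, respectively).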
+ dpi: + size: + pad: + """ + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + figsize = (size * n, size * 6 / 5) if size is not None else None + fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi) + + if n == 1: + ax = [ax] + for i in range(n): + ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) + ax[i].get_yaxis().set_ticks([]) + ax[i].get_xaxis().set_ticks([]) + ax[i].set_axis_off() + for spine in ax[i].spines.values(): # remove frame + spine.set_visible(False) + if titles: + ax[i].set_title(titles[i]) + + fig.tight_layout(pad=pad) + + return fig + + +def fig2im(fig): + fig.canvas.draw() + w, h = fig.canvas.get_width_height() + buf_ndarray = np.frombuffer(fig.canvas.buffer_rgba(), dtype="u1") + # noinspection PyArgumentList + im = buf_ndarray.reshape(h, w, 4) + return im + + +if __name__ == '__main__': + model_zoo = ['gim_dkm', 'gim_loftr', 'gim_lightglue'] + + # model + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, default='gim_dkm', choices=model_zoo) + args = parser.parse_args() + + # device + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # load model + ckpt = None + model = None + detector = None + if args.model == 'gim_dkm': + ckpt = 'gim_dkm_100h.ckpt' + model = DKMv3(weights=None, h=672, w=896) + elif args.model == 'gim_loftr': + ckpt = 'gim_loftr_50h.ckpt' + model = LoFTR(lower_config(get_cfg_defaults())['loftr']) + elif args.model == 'gim_lightglue': + ckpt = 'gim_lightglue_100h.ckpt' + detector = SuperPoint({ + 'max_num_keypoints': 2048, + 'force_num_keypoints': True, + 'detection_threshold': 0.0, + 'nms_radius': 3, + 'trainable': False, + }) + model = LightGlue({ + 'filter_threshold': 0.1, + 'flash': False, + 'checkpointed': True, + }) + + # weights path + checkpoints_path = join('weights', ckpt) + + # load state dict + if args.model == 'gim_dkm': + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + if 'encoder.net.fc' in k: + state_dict.pop(k) + model.load_state_dict(state_dict) + + elif args.model == 'gim_loftr': + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + model.load_state_dict(state_dict) + + elif args.model == 'gim_lightglue': + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict.pop(k) + if k.startswith('superpoint.'): + state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k) + detector.load_state_dict(state_dict) + + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('superpoint.'): + state_dict.pop(k) + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + model.load_state_dict(state_dict) + + # eval mode + if detector is not None: + detector = detector.eval().to(device) + model = model.eval().to(device) + + name0 = 'a1' + name1 = 'a2' + postfix = '.png' + image_dir = join('assets', 'demo') + img_path0 = join(image_dir, name0 + postfix) + img_path1 = join(image_dir, name1 + postfix) + + image0 = read_image(img_path0) + image1 = 
read_image(img_path1) + image0, scale0 = preprocess(image0) + image1, scale1 = preprocess(image1) + + image0 = image0.to(device)[None] + image1 = image1.to(device)[None] + + b_ids, mconf, kpts0, kpts1 = None, None, None, None + data = dict(color0=image0, color1=image1, image0=image0, image1=image1) + + if args.model == 'gim_dkm': + orig_width0, orig_height0, pad_left0, pad_right0, pad_top0, pad_bottom0 = get_padding_size(image0, 672, 896) + orig_width1, orig_height1, pad_left1, pad_right1, pad_top1, pad_bottom1 = get_padding_size(image1, 672, 896) + image0_ = torch.nn.functional.pad(image0, (pad_left0, pad_right0, pad_top0, pad_bottom0)) + image1_ = torch.nn.functional.pad(image1, (pad_left1, pad_right1, pad_top1, pad_bottom1)) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + dense_matches, dense_certainty = model.match(image0_, image1_) + sparse_matches, mconf = model.sample(dense_matches, dense_certainty, 5000) + + height0, width0 = image0_.shape[-2:] + height1, width1 = image1_.shape[-2:] + + kpts0 = sparse_matches[:, :2] + kpts0 = torch.stack(( + width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1,) + kpts1 = sparse_matches[:, 2:] + kpts1 = torch.stack(( + width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1,) + b_ids = torch.where(mconf[None])[0] + + # before padding + kpts0 -= kpts0.new_tensor((pad_left0, pad_top0))[None] + kpts1 -= kpts1.new_tensor((pad_left1, pad_top1))[None] + mask_ = (kpts0[:, 0] > 0) & \ + (kpts0[:, 1] > 0) & \ + (kpts1[:, 0] > 0) & \ + (kpts1[:, 1] > 0) + mask_ = mask_ & \ + (kpts0[:, 0] <= (orig_width0 - 1)) & \ + (kpts1[:, 0] <= (orig_width1 - 1)) & \ + (kpts0[:, 1] <= (orig_height0 - 1)) & \ + (kpts1[:, 1] <= (orig_height1 - 1)) + + mconf = mconf[mask_] + b_ids = b_ids[mask_] + kpts0 = kpts0[mask_] + kpts1 = kpts1[mask_] + + elif args.model == 'gim_loftr': + with torch.no_grad(): + model(data) + kpts0 = data['mkpts0_f'] + kpts1 = data['mkpts1_f'] + b_ids = data['m_bids'] + mconf = data['mconf'] + + elif args.model == 'gim_lightglue': + gray0 = read_image(img_path0, grayscale=True) + gray1 = read_image(img_path1, grayscale=True) + gray0 = preprocess(gray0, grayscale=True)[0] + gray1 = preprocess(gray1, grayscale=True)[0] + + gray0 = gray0.to(device)[None] + gray1 = gray1.to(device)[None] + scale0 = torch.tensor(scale0).to(device)[None] + scale1 = torch.tensor(scale1).to(device)[None] + + data.update(dict(gray0=gray0, gray1=gray1)) + + size0 = torch.tensor(data["gray0"].shape[-2:][::-1])[None] + size1 = torch.tensor(data["gray1"].shape[-2:][::-1])[None] + + data.update(dict(size0=size0, size1=size1)) + data.update(dict(scale0=scale0, scale1=scale1)) + + pred = {} + with torch.no_grad(): + pred.update({k + '0': v for k, v in detector({ + "image": data["gray0"], + }).items()}) + pred.update({k + '1': v for k, v in detector({ + "image": data["gray1"], + }).items()}) + pred.update(model({**pred, **data, + **{'image_size0': data['size0'], + 'image_size1': data['size1']}})) + + kpts0 = torch.cat([kp * s for kp, s in zip(pred['keypoints0'], data['scale0'][:, None])]) + kpts1 = torch.cat([kp * s for kp, s in zip(pred['keypoints1'], data['scale1'][:, None])]) + m_bids = torch.nonzero(pred['keypoints0'].sum(dim=2) > -1)[:, 0] + matches = pred['matches'] + bs = data['image0'].size(0) + kpts0 = torch.cat([kpts0[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)]) + kpts1 = torch.cat([kpts1[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)]) + b_ids = torch.cat([m_bids[m_bids == 
b_id][matches[b_id][..., 0]] for b_id in range(bs)]) + mconf = torch.cat(pred['scores']) + + # robust fitting + _, mask = cv2.findFundamentalMat(kpts0.cpu().detach().numpy(), + kpts1.cpu().detach().numpy(), + cv2.USAC_MAGSAC, ransacReprojThreshold=1.0, + confidence=0.999999, maxIters=10000) + mask = mask.ravel() > 0 + + data.update({ + 'hw0_i': image0.shape[-2:], + 'hw1_i': image1.shape[-2:], + 'mkpts0_f': kpts0, + 'mkpts1_f': kpts1, + 'm_bids': b_ids, + 'mconf': mconf, + 'inliers': mask, + }) + + # save visualization + alpha = 0.5 + out = fast_make_matching_figure(data, b_id=0) + overlay = fast_make_matching_overlay(data, b_id=0) + out = cv2.addWeighted(out, 1 - alpha, overlay, alpha, 0) + cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_match.png'), out[..., ::-1]) + + geom_info = compute_geom(data) + wrapped_images = wrap_images(image0, image1, geom_info, + "Homography") + cv2.imwrite(join(image_dir, f'{name0}_{name1}_{args.model}_warp.png'), wrapped_images) diff --git a/imcui/third_party/gim/hloc/__init__.py b/imcui/third_party/gim/hloc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f1296f84f73f31af302dbd1e407bc179569563 --- /dev/null +++ b/imcui/third_party/gim/hloc/__init__.py @@ -0,0 +1,30 @@ +import logging +from packaging import version + +__version__ = '1.5' + +formatter = logging.Formatter( + fmt='[%(asctime)s %(name)s %(levelname)s] %(message)s', + datefmt='%Y/%m/%d %H:%M:%S') +handler = logging.StreamHandler() +handler.setFormatter(formatter) +handler.setLevel(logging.INFO) + +logger = logging.getLogger("hloc") +logger.setLevel(logging.INFO) +logger.addHandler(handler) +logger.propagate = False + +try: + import pycolmap +except ImportError: + logger.warning('pycolmap is not installed, some features may not work.') +else: + minimal_version = version.parse('0.3.0') + found_version = pycolmap.__version__ + if found_version != 'dev': + if version.parse(found_version) < minimal_version: + logger.warning( + 'hloc now requires pycolmap>=%s but found pycolmap==%s, ' + 'please upgrade with `pip install --upgrade pycolmap`', + minimal_version, found_version) diff --git a/imcui/third_party/gim/hloc/extract_features.py b/imcui/third_party/gim/hloc/extract_features.py new file mode 100644 index 0000000000000000000000000000000000000000..46f765dc40ef0a28adf0a672cb950e820102e097 --- /dev/null +++ b/imcui/third_party/gim/hloc/extract_features.py @@ -0,0 +1,326 @@ +import argparse +import torch +from pathlib import Path +from typing import Dict, List, Union, Optional +import h5py +from types import SimpleNamespace +import cv2 +import numpy as np +from tqdm import tqdm +import pprint +import collections.abc as collections +import PIL.Image +import glob + +from . import extractors, logger +from .utils.base_model import dynamic_load +from .utils.parsers import parse_image_lists +from .utils.io import read_image, list_h5_names + + +''' +A set of standard configurations that can be directly selected from the command +line using their name. Each is a dictionary with the following entries: + - output: the name of the feature file that will be generated. + - model: the model configuration, as passed to a feature extractor. + - preprocessing: how to preprocess the images read from disk. 
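+
+For example, a configuration can be selected by name and handed straight to the
+`main` function defined below (a minimal sketch; the import path and the
+directories are placeholders, not part of this module):
+
+    from pathlib import Path
+    from hloc import extract_features
+
+    conf = extract_features.confs['superpoint_aachen']
+    feature_path = extract_features.main(
+        conf, image_dir=Path('datasets/images'), export_dir=Path('outputs'))
+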
+''' +confs = { + 'gim_superpoint': { + 'output': 'feats-gim-superpoint-n2048-r1920', + 'model': { + 'name': 'superpoint', + 'nms_radius': 3, + 'max_keypoints': 2048, + }, + 'preprocessing': { + 'grayscale': True, + 'resize_max': 1920, + }, + }, + 'superpoint_aachen': { + 'output': 'feats-superpoint-n4096-r1024', + 'model': { + 'name': 'superpoint', + 'nms_radius': 3, + 'max_keypoints': 4096, + }, + 'preprocessing': { + 'grayscale': True, + 'resize_max': 1024, + }, + }, + # Resize images to 1600px even if they are originally smaller. + # Improves the keypoint localization if the images are of good quality. + 'superpoint_max': { + 'output': 'feats-superpoint-n4096-rmax1600', + 'model': { + 'name': 'superpoint', + 'nms_radius': 3, + 'max_keypoints': 4096, + }, + 'preprocessing': { + 'grayscale': True, + 'resize_max': 1600, + 'resize_force': True, + }, + }, + 'superpoint_inloc': { + 'output': 'feats-superpoint-n4096-r1600', + 'model': { + 'name': 'superpoint', + 'nms_radius': 4, + 'max_keypoints': 4096, + }, + 'preprocessing': { + 'grayscale': True, + 'resize_max': 2048, + }, + }, + 'r2d2': { + 'output': 'feats-r2d2-n5000-r1024', + 'model': { + 'name': 'r2d2', + 'max_keypoints': 5000, + }, + 'preprocessing': { + 'grayscale': False, + 'resize_max': 1024, + }, + }, + 'd2net-ss': { + 'output': 'feats-d2net-ss', + 'model': { + 'name': 'd2net', + 'multiscale': False, + }, + 'preprocessing': { + 'grayscale': False, + 'resize_max': 1600, + }, + }, + 'sift': { + 'output': 'feats-sift', + 'model': { + 'name': 'dog' + }, + 'preprocessing': { + 'grayscale': True, + 'resize_max': 1600, + }, + }, + 'sosnet': { + 'output': 'feats-sosnet', + 'model': { + 'name': 'dog', + 'descriptor': 'sosnet' + }, + 'preprocessing': { + 'grayscale': True, + 'resize_max': 1600, + }, + }, + 'disk': { + 'output': 'feats-disk', + 'model': { + 'name': 'disk', + 'max_keypoints': 5000, + }, + 'preprocessing': { + 'grayscale': False, + 'resize_max': 1600, + }, + }, + # Global descriptors + 'dir': { + 'output': 'global-feats-dir', + 'model': {'name': 'dir'}, + 'preprocessing': {'resize_max': 1024}, + }, + 'netvlad': { + 'output': 'global-feats-netvlad', + 'model': {'name': 'netvlad'}, + 'preprocessing': {'resize_max': 1024}, + }, + 'openibl': { + 'output': 'global-feats-openibl', + 'model': {'name': 'openibl'}, + 'preprocessing': {'resize_max': 1024}, + }, + 'cosplace': { + 'output': 'global-feats-cosplace', + 'model': {'name': 'cosplace'}, + 'preprocessing': {'resize_max': 1024}, + } +} + + +def resize_image(image, size, interp): + if interp.startswith('cv2_'): + interp = getattr(cv2, 'INTER_'+interp[len('cv2_'):].upper()) + h, w = image.shape[:2] + if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]): + interp = cv2.INTER_LINEAR + resized = cv2.resize(image, size, interpolation=interp) + elif interp.startswith('pil_'): + interp = getattr(PIL.Image, interp[len('pil_'):].upper()) + resized = PIL.Image.fromarray(image.astype(np.uint8)) + resized = resized.resize(size, resample=interp) + resized = np.asarray(resized, dtype=image.dtype) + else: + raise ValueError( + f'Unknown interpolation {interp}.') + return resized + + +class ImageDataset(torch.utils.data.Dataset): + default_conf = { + 'globs': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'], + 'grayscale': False, + 'resize_max': None, + 'resize_force': False, + 'interpolation': 'cv2_area', # pil_linear is more accurate but slower + } + + def __init__(self, root, conf, paths=None): + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + self.root = root + + 
if paths is None: + paths = [] + for g in conf.globs: + paths += glob.glob( + (Path(root) / '**' / g).as_posix(), recursive=True) + if len(paths) == 0: + raise ValueError(f'Could not find any image in root: {root}.') + paths = sorted(set(paths)) + self.names = [Path(p).relative_to(root).as_posix() for p in paths] + logger.info(f'Found {len(self.names)} images in root {root}.') + else: + if isinstance(paths, (Path, str)): + self.names = parse_image_lists(paths) + elif isinstance(paths, collections.Iterable): + self.names = [p.as_posix() if isinstance(p, Path) else p + for p in paths] + else: + raise ValueError(f'Unknown format for path argument {paths}.') + + for name in self.names: + if not (root / name).exists(): + raise ValueError( + f'Image {name} does not exists in root: {root}.') + + def __getitem__(self, idx): + name = self.names[idx] + image = read_image(self.root / name, self.conf.grayscale) + image = image.astype(np.float32) + size = image.shape[:2][::-1] + + if self.conf.resize_max and (self.conf.resize_force + or max(size) > self.conf.resize_max): + scale = self.conf.resize_max / max(size) + size_new = tuple(int(round(x*scale)) for x in size) + image = resize_image(image, size_new, self.conf.interpolation) + + if self.conf.grayscale: + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = image / 255. + + data = { + 'image': image, + 'original_size': np.array(size), + } + return data + + def __len__(self): + return len(self.names) + + +@torch.no_grad() +def main(conf: Dict, + image_dir: Path, + export_dir: Optional[Path] = None, + as_half: bool = True, + image_list: Optional[Union[Path, List[str]]] = None, + feature_path: Optional[Path] = None, + overwrite: bool = False, + model=None) -> Path: + logger.info('Extracting local features with configuration:' + f'\n{pprint.pformat(conf)}') + + dataset = ImageDataset(image_dir, conf['preprocessing'], image_list) + if feature_path is None: + feature_path = Path(export_dir, conf['output']+'.h5') + feature_path.parent.mkdir(exist_ok=True, parents=True) + skip_names = set(list_h5_names(feature_path) + if feature_path.exists() and not overwrite else ()) + dataset.names = [n for n in dataset.names if n not in skip_names] + if len(dataset.names) == 0: + logger.info('Skipping the extraction.') + return feature_path + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + if model is None: + Model = dynamic_load(extractors, conf['model']['name']) + model = Model(conf['model']) + model = model.eval().to(device) + + loader = torch.utils.data.DataLoader( + dataset, num_workers=1, shuffle=False, pin_memory=True) + for idx, data in enumerate(tqdm(loader)): + name = dataset.names[idx] + pred = model({'image': data['image'].to(device, non_blocking=True)}) + pred = {k: v[0].cpu().numpy() for k, v in pred.items()} + + pred['image_size'] = original_size = data['original_size'][0].numpy() + if 'keypoints' in pred: + size = np.array(data['image'].shape[-2:][::-1]) + scales = (original_size / size).astype(np.float32) + pred['keypoints'] = (pred['keypoints'] + .5) * scales[None] - .5 + if 'scales' in pred: + pred['scales'] *= scales.mean() + # add keypoint uncertainties scaled to the original resolution + uncertainty = getattr(model, 'detection_noise', 1) * scales.mean() + + if as_half: + for k in pred: + dt = pred[k].dtype + if (dt == np.float32) and (dt != np.float16): + pred[k] = pred[k].astype(np.float16) + + with h5py.File(str(feature_path), 'a', libver='latest') as fd: + try: + if name in fd: + del fd[name] + 
grp = fd.create_group(name) + for k, v in pred.items(): + grp.create_dataset(k, data=v) + if 'keypoints' in pred: + grp['keypoints'].attrs['uncertainty'] = uncertainty + except OSError as error: + if 'No space left on device' in error.args[0]: + logger.error( + 'Out of disk space: storing features on disk can take ' + 'significant space, did you enable the as_half flag?') + del grp, fd[name] + raise error + + del pred + + logger.info('Finished exporting features.') + return feature_path + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--image_dir', type=Path, required=True) + parser.add_argument('--export_dir', type=Path, required=True) + parser.add_argument('--conf', type=str, default='superpoint_aachen', + choices=list(confs.keys())) + parser.add_argument('--as_half', action='store_true') + parser.add_argument('--image_list', type=Path) + parser.add_argument('--feature_path', type=Path) + args = parser.parse_args() + main(confs[args.conf], args.image_dir, args.export_dir, args.as_half) diff --git a/imcui/third_party/gim/hloc/extractors/__init__.py b/imcui/third_party/gim/hloc/extractors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/gim/hloc/match_dense.py b/imcui/third_party/gim/hloc/match_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..219240def7ab5d4788623cc711520b6acb4825f9 --- /dev/null +++ b/imcui/third_party/gim/hloc/match_dense.py @@ -0,0 +1,549 @@ +import os +import shutil +from tqdm import tqdm +import numpy as np +import h5py +import torch +from pathlib import Path +from typing import Dict, Iterable, Optional, List, Tuple, Union, Set +import pprint +import argparse +import torchvision.transforms.functional as F +from types import SimpleNamespace +from collections import defaultdict +from scipy.spatial import KDTree +from collections import Counter +from itertools import chain + +from . 
import matchers, logger +from .utils.base_model import dynamic_load +from .utils.parsers import parse_retrieval, names_to_pair +from .match_features import find_unique_new_pairs +from .extract_features import read_image, resize_image +from .utils.io import list_h5_names + +confs = { + 'gim_dkm': { + 'output': 'matches-gim', + 'model': { + 'name': 'dkm', + 'weights': 'gim_dkm_100h.ckpt' + }, + 'preprocessing': { + 'grayscale': False, + 'resize_max': None, + 'dfactor': 1 + }, + 'max_error': 2, # max error for assigned keypoints (in px) + 'cell_size': 8, # size of quantization patch (max 1 kp/patch) + }, +} + + +def to_cpts(kpts, ps): + if ps > 0.0: + kpts = np.round(np.round((kpts + 0.5) / ps) * ps - 0.5, 2) + return [tuple(cpt) for cpt in kpts] + + +def assign_keypoints(kpts: np.ndarray, + other_cpts: Union[List[Tuple], np.ndarray], + max_error: float, + update: bool = False, + ref_bins: Optional[List[Counter]] = None, + scores: Optional[np.ndarray] = None, + cell_size: Optional[int] = None): + if not update: + if len(other_cpts) == 0: return np.array([], dtype=np.int64) + # Without update this is just a NN search + dist, kpt_ids = KDTree(np.array(other_cpts)).query(kpts) + valid = (dist <= max_error) + kpt_ids[~valid] = -1 + return kpt_ids + else: + ps = cell_size if cell_size is not None else max_error + ps = max(ps, max_error) + # With update we quantize and bin (optionally) + assert isinstance(other_cpts, list) + kpt_ids = [] + cpts = to_cpts(kpts, ps) + bpts = to_cpts(kpts, int(max_error)) + cp_to_id = {val: i for i, val in enumerate(other_cpts)} + for i, (cpt, bpt) in enumerate(zip(cpts, bpts)): + try: + kid = cp_to_id[cpt] + except KeyError: + kid = len(cp_to_id) + cp_to_id[cpt] = kid + other_cpts.append(cpt) + if ref_bins is not None: + ref_bins.append(Counter()) + if ref_bins is not None: + score = scores[i] if scores is not None else 1 + ref_bins[cp_to_id[cpt]][bpt] += score + kpt_ids.append(kid) + return np.array(kpt_ids) + + +def get_grouped_ids(array): + # Group array indices based on its values + # all duplicates are grouped as a set + idx_sort = np.argsort(array) + sorted_array = array[idx_sort] + _, ids, _ = np.unique(sorted_array, return_counts=True, + return_index=True) + res = np.split(idx_sort, ids[1:]) + return res + + +def get_unique_matches(match_ids, scores): + if len(match_ids.shape) == 1: + return [0] + + isets1 = get_grouped_ids(match_ids[:, 0]) + isets2 = get_grouped_ids(match_ids[:, 1]) + uid1s = [ids[scores[ids].argmax()] for ids in isets1 if len(ids) > 0] + uid2s = [ids[scores[ids].argmax()] for ids in isets2 if len(ids) > 0] + uids = list(set(uid1s).intersection(uid2s)) + return match_ids[uids], scores[uids] + + +def matches_to_matches0(matches, scores): + if len(matches) == 0: + return np.zeros(0, dtype=np.int32), np.zeros(0, dtype=np.float16) + n_kps0 = np.max(matches[:, 0]) + 1 + matches0 = -np.ones((n_kps0,)) + scores0 = np.zeros((n_kps0,)) + matches0[matches[:, 0]] = matches[:, 1] + scores0[matches[:, 0]] = scores + return matches0.astype(np.int32), scores0.astype(np.float16) + + +def kpids_to_matches0(kpt_ids0, kpt_ids1, scores): + valid = (kpt_ids0 != -1) & (kpt_ids1 != -1) + matches = np.dstack([kpt_ids0[valid], kpt_ids1[valid]]) + matches = matches.reshape(-1, 2) + scores = scores[valid] + + # Remove n-to-1 matches + matches, scores = get_unique_matches(matches, scores) + return matches_to_matches0(matches, scores) + + +def scale_keypoints(kpts, scale): + if np.any(scale != 1.0): + kpts *= kpts.new_tensor(scale) + return kpts + + +class 
ImagePairDataset(torch.utils.data.Dataset): + default_conf = { + 'grayscale': True, + 'resize_max': 1024, + 'dfactor': 8, + 'cache_images': False, + } + + def __init__(self, image_dir, conf, pairs): + self.image_dir = image_dir + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + self.pairs = sorted(pairs) if pairs else pairs + if self.conf.cache_images: + image_names = set(sum(pairs, ())) # unique image names in pairs + logger.info( + f'Loading and caching {len(image_names)} unique images.') + self.images = {} + self.scales = {} + for name in tqdm(image_names): + image = read_image(self.image_dir / name, self.conf.grayscale) + self.images[name], self.scales[name] = self.preprocess(image) + + def preprocess(self, image: np.ndarray): + image = image.astype(np.float32, copy=False) + size = image.shape[:2][::-1] + scale = np.array([1.0, 1.0]) + + if self.conf.resize_max: + scale = self.conf.resize_max / max(size) + if scale < 1.0: + size_new = tuple(int(round(x*scale)) for x in size) + image = resize_image(image, size_new, 'cv2_area') + scale = np.array(size) / np.array(size_new) + + if self.conf.grayscale: + assert image.ndim == 2, image.shape + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = torch.from_numpy(image / 255.0).float() + + # assure that the size is divisible by dfactor + size_new = tuple(map( + lambda x: int(x // self.conf.dfactor * self.conf.dfactor), + image.shape[-2:])) + image = F.resize(image, size=size_new) + scale = np.array(size) / np.array(size_new)[::-1] + return image, scale + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + name0, name1 = self.pairs[idx] + if self.conf.cache_images: + image0, scale0 = self.images[name0], self.scales[name0] + image1, scale1 = self.images[name1], self.scales[name1] + else: + image0 = read_image(self.image_dir / name0, self.conf.grayscale) + image1 = read_image(self.image_dir / name1, self.conf.grayscale) + image0, scale0 = self.preprocess(image0) + image1, scale1 = self.preprocess(image1) + return image0, image1, scale0, scale1, name0, name1 + + +@torch.no_grad() +def match_dense(conf: Dict, + pairs: List[Tuple[str, str]], + image_dir: Path, + match_path: Path, # out + existing_refs: Optional[List] = []): + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + Model = dynamic_load(matchers, conf['model']['name']) + model = Model(conf['model']).eval().to(device) + + dataset = ImagePairDataset(image_dir, conf["preprocessing"], pairs) + loader = torch.utils.data.DataLoader( + dataset, num_workers=16, batch_size=1, shuffle=False) + + logger.info("Performing dense matching...") + with h5py.File(str(match_path), 'a') as fd: + for data in tqdm(loader, smoothing=.1): + # load image-pair data + image0, image1, scale0, scale1, (name0,), (name1,) = data + scale0, scale1 = scale0[0].numpy(), scale1[0].numpy() + image0, image1 = image0.to(device), image1.to(device) + + # match semi-dense + # for consistency with pairs_from_*: refine kpts of image0 + if name0 in existing_refs: + # special case: flip to enable refinement in query image + pred = model({'image0': image1, 'image1': image0, 'name0': name1, 'name1': name0}) + pred = {**pred, + 'keypoints0': pred['keypoints1'], + 'keypoints1': pred['keypoints0']} + else: + # usual case + # # 在 image1 上 grid sample 关键点, 在 image0 上预测 refine 关键点 + pred = model({'image0': image0, 'image1': image1, 'name0': name0, 'name1': name1}) + + # Rescale keypoints and move to cpu + kpts0, kpts1 = pred['keypoints0'], 
pred['keypoints1'] + kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5 + kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5 + kpts0 = kpts0.cpu().numpy() + kpts1 = kpts1.cpu().numpy() + scores = pred['scores'].cpu().numpy() + + # Write matches and matching scores in hloc format + pair = names_to_pair(name0, name1) + if pair in fd: + del fd[pair] + grp = fd.create_group(pair) + + # Write dense matching output + grp.create_dataset('keypoints0', data=kpts0) + grp.create_dataset('keypoints1', data=kpts1) + grp.create_dataset('scores', data=scores) + del model, loader + + +# default: quantize all! +def load_keypoints(conf: Dict, + feature_paths_refs: List[Path], + quantize: Optional[set] = None): + name2ref = {n: i for i, p in enumerate(feature_paths_refs) + for n in list_h5_names(p)} + + existing_refs = set(name2ref.keys()) + if quantize is None: + quantize = existing_refs # quantize all + if len(existing_refs) > 0: + logger.info(f'Loading keypoints from {len(existing_refs)} images.') + + # Load query keypoints + cpdict = defaultdict(list) + bindict = defaultdict(list) + for name in existing_refs: + with h5py.File(str(feature_paths_refs[name2ref[name]]), 'r') as fd: + kps = fd[name]['keypoints'].__array__() + if name not in quantize: + cpdict[name] = kps + else: + if 'scores' in fd[name].keys(): + kp_scores = fd[name]['scores'].__array__() + else: + # we set the score to 1.0 if not provided + # increase for more weight on reference keypoints for + # stronger anchoring + kp_scores = \ + [1.0 for _ in range(kps.shape[0])] + # bin existing keypoints of reference images for association + assign_keypoints( + kps, cpdict[name], conf['max_error'], True, bindict[name], + kp_scores, conf['cell_size']) + return cpdict, bindict + + +def aggregate_matches( + conf: Dict, + pairs: List[Tuple[str, str]], + match_path: Path, + feature_path: Path, + required_queries: Optional[Set[str]] = None, + max_kps: Optional[int] = None, + cpdict: Dict[str, Iterable] = defaultdict(list), + bindict: Dict[str, List[Counter]] = defaultdict(list)): + if required_queries is None: + required_queries = set(sum(pairs, ())) + # default: do not overwrite existing features in feature_path! + required_queries -= set(list_h5_names(feature_path)) + + # if an entry in cpdict is provided as np.ndarray we assume it is fixed + required_queries -= set( + [k for k, v in cpdict.items() if isinstance(v, np.ndarray)]) + + # sort pairs for reduced RAM + pairs_per_q = Counter(list(chain(*pairs))) + pairs_score = [min(pairs_per_q[i], pairs_per_q[j]) for i, j in pairs] + pairs = [p for _, p in sorted(zip(pairs_score, pairs))] + + if len(required_queries) > 0: + logger.info(f'Aggregating keypoints for {len(required_queries)} images.') + n_kps = 0 + with h5py.File(str(match_path), 'a') as fd: + for name0, name1 in tqdm(pairs, smoothing=.1): + pair = names_to_pair(name0, name1) + grp = fd[pair] + kpts0 = grp['keypoints0'].__array__() + kpts1 = grp['keypoints1'].__array__() + scores = grp['scores'].__array__() + + # Aggregate local features + update0 = name0 in required_queries + update1 = name1 in required_queries + + # in localization we do not want to bin the query kp + # assumes that the query is name0! 
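+            # Added note: when only the query side (name0) needs new keypoints
+            # and no top-k selection will follow, the branch below disables
+            # quantization (zero max_error/cell_size), so each query keypoint
+            # is kept (up to rounding) at its matched location.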
+ if update0 and not update1 and max_kps is None: + max_error0 = cell_size0 = 0.0 + else: + max_error0 = conf['max_error'] + cell_size0 = conf['cell_size'] + + # Get match ids and extend query keypoints (cpdict) + mkp_ids0 = assign_keypoints(kpts0, cpdict[name0], max_error0, + update0, bindict[name0], scores, + cell_size0) + mkp_ids1 = assign_keypoints(kpts1, cpdict[name1], conf['max_error'], + update1, bindict[name1], scores, + conf['cell_size']) + + # Build matches from assignments + matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1, scores) + + assert kpts0.shape[0] == scores.shape[0] + # del grp['matches0'], grp['matching_scores0'] + grp.create_dataset('matches0', data=matches0) + grp.create_dataset('matching_scores0', data=scores0) + + # Convert bins to kps if finished, and store them + for name in (name0, name1): + pairs_per_q[name] -= 1 + if pairs_per_q[name] > 0 or name not in required_queries: + continue + kp_score = [c.most_common(1)[0][1] for c in bindict[name]] + cpdict[name] = [c.most_common(1)[0][0] for c in bindict[name]] + cpdict[name] = np.array(cpdict[name], dtype=np.float32) + + # Select top-k query kps by score (reassign matches later) + if max_kps: + top_k = min(max_kps, cpdict[name].shape[0]) + top_k = np.argsort(kp_score)[::-1][:top_k] + cpdict[name] = cpdict[name][top_k] + kp_score = np.array(kp_score)[top_k] + + # Write query keypoints + with h5py.File(feature_path, 'a') as kfd: + if name in kfd: + del kfd[name] + kgrp = kfd.create_group(name) + kgrp.create_dataset('keypoints', data=cpdict[name]) + kgrp.create_dataset('score', data=kp_score) + n_kps += cpdict[name].shape[0] + del bindict[name] + + if len(required_queries) > 0: + avg_kp_per_image = round(n_kps / len(required_queries), 1) + logger.info(f'Finished assignment, found {avg_kp_per_image} ' + f'keypoints/image (avg.), total {n_kps}.') + return cpdict + + +def assign_matches( + pairs: List[Tuple[str, str]], + match_path: Path, + keypoints: Union[List[Path], Dict[str, np.array]], + max_error: float): + if isinstance(keypoints, list): + keypoints = load_keypoints({}, keypoints, quantize=set([])) + assert len(set(sum(pairs, ())) - set(keypoints.keys())) == 0 + with h5py.File(str(match_path), 'a') as fd: + for name0, name1 in tqdm(pairs): + pair = names_to_pair(name0, name1) + grp = fd[pair] + kpts0 = grp['keypoints0'].__array__() + kpts1 = grp['keypoints1'].__array__() + scores = grp['scores'].__array__() + + # NN search across cell boundaries + mkp_ids0 = assign_keypoints(kpts0, keypoints[name0], max_error) + mkp_ids1 = assign_keypoints(kpts1, keypoints[name1], max_error) + + matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1, + scores) + + # overwrite matches0 and matching_scores0 + del grp['matches0'], grp['matching_scores0'] + grp.create_dataset('matches0', data=matches0) + grp.create_dataset('matching_scores0', data=scores0) + + +@torch.no_grad() +def match_and_assign(conf: Dict, + pairs_path: Path, + image_dir: Path, + match_path: Path, # out + feature_path_q: Path, # out + feature_paths_refs: Optional[List[Path]] = [], + max_kps: Optional[int] = 8192, + overwrite: bool = False) -> Path: + for path in feature_paths_refs: + if not path.exists(): + raise FileNotFoundError(f'Reference feature file {path}.') + pairs = parse_retrieval(pairs_path) + pairs = [(q, r) for q, rs in pairs.items() for r in rs] + pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) + required_queries = set(sum(pairs, ())) + + name2ref = {n: i for i, p in enumerate(feature_paths_refs) + for n in 
list_h5_names(p)} + existing_refs = required_queries.intersection(set(name2ref.keys())) + + # images which require feature extraction + required_queries = required_queries - existing_refs + + if feature_path_q.exists(): + existing_queries = set(list_h5_names(feature_path_q)) + feature_paths_refs.append(feature_path_q) + existing_refs = set.union(existing_refs, existing_queries) + if not overwrite: + required_queries = required_queries - existing_queries + + if len(pairs) == 0 and len(required_queries) == 0: + logger.info("All pairs exist. Skipping dense matching.") + return + + # extract semi-dense matches + parts = list(match_path.parts) + match_cache_base = os.sep.join(parts[:-1] + ['cache']) + match_cache_path = os.path.join(match_cache_base, parts[-1]) + if not os.path.exists(match_cache_path): + match_dense(conf, pairs, image_dir, match_path, + existing_refs=existing_refs) + if not os.path.exists(match_cache_base): os.mkdir(match_cache_base) + shutil.copy(str(match_path), str(match_cache_path)) + else: + shutil.copy(str(match_cache_path), str(match_path)) + + logger.info("Assigning matches...") + + # Pre-load existing keypoints + cpdict, bindict = load_keypoints( + conf, feature_paths_refs, + quantize=required_queries) + + # Reassign matches by aggregation + cpdict = aggregate_matches( + conf, pairs, match_path, feature_path=feature_path_q, + required_queries=required_queries, max_kps=max_kps, cpdict=cpdict, + bindict=bindict) + + # Invalidate matches that are far from selected bin by reassignment + if max_kps is not None: + logger.info(f'Reassign matches with max_error={conf["max_error"]}.') + assign_matches(pairs, match_path, cpdict, + max_error=conf['max_error']) + + +@torch.no_grad() +def main(conf: Dict, + pairs: Path, + image_dir: Path, + export_dir: Optional[Path] = None, + matches: Optional[Path] = None, # out + features: Optional[Path] = None, # out + features_ref: Optional[Path] = None, + max_kps: Optional[int] = 8192, + overwrite: bool = False) -> Path: + logger.info('Extracting semi-dense features with configuration:' + f'\n{pprint.pformat(conf)}') + + if features is None: + features = 'feats_' + + if isinstance(features, Path): + features_q = features + if matches is None: + raise ValueError('Either provide both features and matches as Path' + ' or both as names.') + else: + if export_dir is None: + raise ValueError('Provide an export_dir if features and matches' + f' are not file paths: {features}, {matches}.') + features_q = Path(export_dir, + f'{features}{conf["output"]}.h5') + if matches is None: + matches = Path( + export_dir, f'{conf["output"]}_{pairs.stem}.h5') + + if features_ref is None: + features_ref = [] + elif isinstance(features_ref, list): + features_ref = list(features_ref) + elif isinstance(features_ref, Path): + features_ref = [features_ref] + else: + raise TypeError(str(features_ref)) + + match_and_assign(conf, pairs, image_dir, matches, + features_q, features_ref, + max_kps, overwrite) + + return features_q, matches + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--pairs', type=Path, required=True) + parser.add_argument('--image_dir', type=Path, required=True) + parser.add_argument('--export_dir', type=Path, required=True) + parser.add_argument('--matches', type=Path, + default=confs['loftr']['output']) + parser.add_argument('--features', type=str, + default='feats_' + confs['loftr']['output']) + parser.add_argument('--conf', type=str, default='loftr', + choices=list(confs.keys())) + args = 
parser.parse_args() + main(confs[args.conf], args.pairs, args.image_dir, args.export_dir, + args.matches, args.features) diff --git a/imcui/third_party/gim/hloc/match_features.py b/imcui/third_party/gim/hloc/match_features.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a68e6a8cf11c597ba2fce465c40f5d3df8814f --- /dev/null +++ b/imcui/third_party/gim/hloc/match_features.py @@ -0,0 +1,269 @@ +import argparse +from typing import Union, Optional, Dict, List, Tuple +from pathlib import Path +import pprint +from queue import Queue +from threading import Thread +from functools import partial +from tqdm import tqdm +import h5py +import torch + +from . import matchers, logger +from .utils.base_model import dynamic_load +from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval + + +''' +A set of standard configurations that can be directly selected from the command +line using their name. Each is a dictionary with the following entries: + - output: the name of the match file that will be generated. + - model: the model configuration, as passed to a feature matcher. +''' +confs = { + 'gim_lightglue': { + 'output': 'matches-gim-lightglue', + 'model': { + 'name': 'lightglue', + 'weights': 'gim_lightglue_100h', + }, + 'preprocessing': { # for segmentation + 'grayscale': False, + 'resize_max': None, + 'dfactor': 1 + }, + }, + 'superpoint+lightglue': { + 'output': 'matches-superpoint-lightglue', + 'model': { + 'name': 'lightglue', + 'features': 'superpoint', + }, + }, + 'disk+lightglue': { + 'output': 'matches-disk-lightglue', + 'model': { + 'name': 'lightglue', + 'features': 'disk', + }, + }, + 'superpoint+superglue': { + 'output': 'matches-superglue', + 'model': { + 'name': 'superglue', + 'weights': 'outdoor', + 'sinkhorn_iterations': 50, + }, + }, + 'superglue-fast': { + 'output': 'matches-superglue-it5', + 'model': { + 'name': 'superglue', + 'weights': 'outdoor', + 'sinkhorn_iterations': 5, + }, + }, + 'NN-superpoint': { + 'output': 'matches-NN-mutual-dist.7', + 'model': { + 'name': 'nearest_neighbor', + 'do_mutual_check': True, + 'distance_threshold': 0.7, + }, + }, + 'NN-ratio': { + 'output': 'matches-NN-mutual-ratio.8', + 'model': { + 'name': 'nearest_neighbor', + 'do_mutual_check': True, + 'ratio_threshold': 0.8, + } + }, + 'NN-mutual': { + 'output': 'matches-NN-mutual', + 'model': { + 'name': 'nearest_neighbor', + 'do_mutual_check': True, + }, + }, + 'adalam': { + 'output': 'matches-adalam', + 'model': { + 'name': 'adalam' + }, + } +} + + +class WorkQueue(): + def __init__(self, work_fn, num_threads=1): + self.queue = Queue(num_threads) + self.threads = [ + Thread(target=self.thread_fn, args=(work_fn,)) + for _ in range(num_threads) + ] + for thread in self.threads: + thread.start() + + def join(self): + for thread in self.threads: + self.queue.put(None) + for thread in self.threads: + thread.join() + + def thread_fn(self, work_fn): + item = self.queue.get() + while item is not None: + work_fn(item) + item = self.queue.get() + + def put(self, data): + self.queue.put(data) + + +class FeaturePairsDataset(torch.utils.data.Dataset): + def __init__(self, pairs, feature_path_q, feature_path_r): + self.pairs = pairs + self.feature_path_q = feature_path_q + self.feature_path_r = feature_path_r + + def __getitem__(self, idx): + name0, name1 = self.pairs[idx] + data = {} + with h5py.File(self.feature_path_q, 'r') as fd: + grp = fd[name0] + for k, v in grp.items(): + data[k+'0'] = torch.from_numpy(v.__array__()).float() + # some matchers might expect an image 
but only use its size + data['image0'] = torch.empty((1,)+tuple(grp['image_size'])[::-1]) + with h5py.File(self.feature_path_r, 'r') as fd: + grp = fd[name1] + for k, v in grp.items(): + data[k+'1'] = torch.from_numpy(v.__array__()).float() + data['image1'] = torch.empty((1,)+tuple(grp['image_size'])[::-1]) + return data + + def __len__(self): + return len(self.pairs) + + +def writer_fn(inp, match_path): + pair, pred = inp + with h5py.File(str(match_path), 'a', libver='latest') as fd: + if pair in fd: + del fd[pair] + grp = fd.create_group(pair) + matches = pred['matches0'][0].cpu().short().numpy() + grp.create_dataset('matches0', data=matches) + if 'matching_scores0' in pred: + scores = pred['matching_scores0'][0].cpu().half().numpy() + grp.create_dataset('matching_scores0', data=scores) + + +def main(conf: Dict, + pairs: Path, features: Union[Path, str], + export_dir: Optional[Path] = None, + matches: Optional[Path] = None, + features_ref: Optional[Path] = None, + overwrite: bool = False, + model = None) -> Path: + + if isinstance(features, Path) or Path(features).exists(): + features_q = features + if matches is None: + raise ValueError('Either provide both features and matches as Path' + ' or both as names.') + else: + if export_dir is None: + raise ValueError('Provide an export_dir if features is not' + f' a file path: {features}.') + features_q = Path(export_dir, features+'.h5') + if matches is None: + matches = Path( + export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5') + + if features_ref is None: + features_ref = features_q + match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite, model=model) + + return matches + + +def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None): + '''Avoid to recompute duplicates to save time.''' + pairs = set() + for i, j in pairs_all: + if (j, i) not in pairs: + pairs.add((i, j)) + pairs = list(pairs) + if match_path is not None and match_path.exists(): + with h5py.File(str(match_path), 'r', libver='latest') as fd: + pairs_filtered = [] + for i, j in pairs: + if (names_to_pair(i, j) in fd or + names_to_pair(j, i) in fd or + names_to_pair_old(i, j) in fd or + names_to_pair_old(j, i) in fd): + continue + pairs_filtered.append((i, j)) + return pairs_filtered + return pairs + + +@torch.no_grad() +def match_from_paths(conf: Dict, + pairs_path: Path, + match_path: Path, + feature_path_q: Path, + feature_path_ref: Path, + overwrite: bool = False, + model = None) -> Path: + logger.info('Matching local features with configuration:' + f'\n{pprint.pformat(conf)}') + + if not feature_path_q.exists(): + raise FileNotFoundError(f'Query feature file {feature_path_q}.') + if not feature_path_ref.exists(): + raise FileNotFoundError(f'Reference feature file {feature_path_ref}.') + match_path.parent.mkdir(exist_ok=True, parents=True) + + assert pairs_path.exists(), pairs_path + pairs = parse_retrieval(pairs_path) + pairs = [(q, r) for q, rs in pairs.items() for r in rs] + pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) + if len(pairs) == 0: + logger.info('Skipping the matching.') + return + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + if model is None: + Model = dynamic_load(matchers, conf['model']['name']) + model = Model(conf['model']) + model = model.eval().to(device) + + dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref) + loader = torch.utils.data.DataLoader( + dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True) + writer_queue = 
WorkQueue(partial(writer_fn, match_path=match_path), 5) + + for idx, data in enumerate(tqdm(loader, smoothing=.1)): + data = {k: v if k.startswith('image') + else v.to(device, non_blocking=True) for k, v in data.items()} + pred = model(data) + pair = names_to_pair(*pairs[idx]) + writer_queue.put((pair, pred)) + writer_queue.join() + logger.info('Finished exporting matches.') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--pairs', type=Path, required=True) + parser.add_argument('--export_dir', type=Path) + parser.add_argument('--features', type=str, + default='feats-superpoint-n4096-r1024') + parser.add_argument('--matches', type=Path) + parser.add_argument('--conf', type=str, default='superglue', + choices=list(confs.keys())) + args = parser.parse_args() + main(confs[args.conf], args.pairs, args.features, args.export_dir) diff --git a/imcui/third_party/gim/hloc/matchers/__init__.py b/imcui/third_party/gim/hloc/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7edac76f912b1e5ebb0401b6cc7a5d3c64ce963a --- /dev/null +++ b/imcui/third_party/gim/hloc/matchers/__init__.py @@ -0,0 +1,3 @@ +def get_matcher(matcher): + mod = __import__(f'{__name__}.{matcher}', fromlist=['']) + return getattr(mod, 'Model') diff --git a/imcui/third_party/gim/hloc/matchers/dkm.py b/imcui/third_party/gim/hloc/matchers/dkm.py new file mode 100644 index 0000000000000000000000000000000000000000..6afa3d6dfd2af7a9c3adccea2d29ffd92d31e1d0 --- /dev/null +++ b/imcui/third_party/gim/hloc/matchers/dkm.py @@ -0,0 +1,154 @@ +import os +import cv2 +import torch +import warnings +import numpy as np +from os.path import join +from pathlib import Path + +from tools import get_padding_size +from hloc.utils import CLS_DICT, exclude +from ..utils.base_model import BaseModel +from networks.dkm.models.model_zoo.DKMv3 import DKMv3 + + +class LoFTR(BaseModel): + default_conf = { + 'max_num_matches': None, + } + required_inputs = [ + 'image0', + 'image1' + ] + + def _init(self, conf): + self.h = 672 + self.w = 896 + model = DKMv3(None, self.h, self.w, upsample_preds=True) + + checkpoints_path = join('weights', conf['weights']) + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + if 'encoder.net.fc' in k: + state_dict.pop(k) + model.load_state_dict(state_dict) + + self.net = model + + def _forward(self, data): + outputs = Path(os.environ['GIMRECONSTRUCTION']) + segment_root = outputs / '..' / 'segment' + + # For consistency with hloc pairs, we refine kpts in image0! 
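+        # Added note: the dict below swaps every *0/*1 key so the pair is fed to
+        # the matcher in reversed order; the predictions are renamed back at the
+        # end of _forward, which keeps keypoint refinement on image0 as stated
+        # in the comment above.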
+ rename = { + 'keypoints0': 'keypoints1', + 'keypoints1': 'keypoints0', + 'image0': 'image1', + 'image1': 'image0', + 'mask0': 'mask1', + 'mask1': 'mask0', + 'name0': 'name1', + 'name1': 'name0', + } + data_ = {rename[k]: v for k, v in data.items()} + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + image0, image1 = data_['image0'], data_['image1'] + img0, img1 = data_['name0'], data_['name1'] + + # segment image + seg_path0 = join(segment_root, '{}.npy'.format(img0[:-4])) + mask0 = np.load(seg_path0) + if mask0.shape[:2] != image0.shape[-2:]: + mask0 = cv2.resize(mask0, image0.shape[-2:][::-1], + interpolation=cv2.INTER_NEAREST) + mask_0 = mask0 != CLS_DICT[exclude[0]] + for cls in exclude[1:]: + mask_0 = mask_0 & (mask0 != CLS_DICT[cls]) + mask_0 = mask0 + mask_0 = mask_0.astype(np.uint8) + mask_0 = torch.from_numpy((mask_0 == 0).astype(np.uint8)).to(image0.device) + mask_0 = mask_0.float()[None, None] == 0 + image0 = image0 * mask_0 + # segment image + seg_path1 = join(segment_root, '{}.npy'.format(img1[:-4])) + mask1 = np.load(seg_path1) + if mask1.shape != image1.shape[-2:]: + mask1 = cv2.resize(mask1, image1.shape[-2:][::-1], + interpolation=cv2.INTER_NEAREST) + mask_1 = mask1 != CLS_DICT[exclude[0]] + for cls in exclude[1:]: + mask_1 = mask_1 & (mask1 != CLS_DICT[cls]) + mask_1 = mask1 + mask_1 = mask_1.astype(np.uint8) + mask_1 = torch.from_numpy((mask_1 == 0).astype(np.uint8)).to(image1.device) + mask_1 = mask_1.float()[None, None] == 0 + image1 = image1 * mask_1 + + orig_width0, orig_height0, pad_left0, pad_right0, pad_top0, pad_bottom0 = get_padding_size(image0, self.h, self.w) + orig_width1, orig_height1, pad_left1, pad_right1, pad_top1, pad_bottom1 = get_padding_size(image1, self.h, self.w) + image0 = torch.nn.functional.pad(image0, (pad_left0, pad_right0, pad_top0, pad_bottom0)) + image1 = torch.nn.functional.pad(image1, (pad_left1, pad_right1, pad_top1, pad_bottom1)) + + dense_matches, dense_certainty = self.net.match(image0, image1) + sparse_matches, mconf = self.net.sample(dense_matches, dense_certainty, 8192) + + m = mconf > 0 + mconf = mconf[m] + sparse_matches = sparse_matches[m] + + height0, width0 = image0.shape[-2:] + height1, width1 = image1.shape[-2:] + + kpts0 = sparse_matches[:, :2] + kpts0 = torch.stack((width0 * (kpts0[:, 0] + 1) / 2, + height0 * (kpts0[:, 1] + 1) / 2), dim=-1, ) + kpts1 = sparse_matches[:, 2:] + kpts1 = torch.stack((width1 * (kpts1[:, 0] + 1) / 2, + height1 * (kpts1[:, 1] + 1) / 2), dim=-1, ) + b_ids, i_ids = torch.where(mconf[None]) + + # before padding + kpts0 -= kpts0.new_tensor((pad_left0, pad_top0))[None] + kpts1 -= kpts1.new_tensor((pad_left1, pad_top1))[None] + mask = (kpts0[:, 0] > 0) & \ + (kpts0[:, 1] > 0) & \ + (kpts1[:, 0] > 0) & \ + (kpts1[:, 1] > 0) + mask = mask & \ + (kpts0[:, 0] <= (orig_width0 - 1)) & \ + (kpts1[:, 0] <= (orig_width1 - 1)) & \ + (kpts0[:, 1] <= (orig_height0 - 1)) & \ + (kpts1[:, 1] <= (orig_height1 - 1)) + + pred = { + 'keypoints0': kpts0[i_ids], + 'keypoints1': kpts1[i_ids], + 'confidence': mconf[i_ids], + 'batch_indexes': b_ids, + } + + # noinspection PyUnresolvedReferences + scores, b_ids = pred['confidence'], pred['batch_indexes'] + kpts0, kpts1 = pred['keypoints0'], pred['keypoints1'] + pred['confidence'], pred['batch_indexes'] = scores[mask], b_ids[mask] + pred['keypoints0'], pred['keypoints1'] = kpts0[mask], kpts1[mask] + + scores = pred['confidence'] + + top_k = self.conf['max_num_matches'] + if top_k is not None and len(scores) > top_k: + keep = torch.argsort(scores, 
descending=True)[:top_k] + pred['keypoints0'], pred['keypoints1'] =\ + pred['keypoints0'][keep], pred['keypoints1'][keep] + scores = scores[keep] + + # Switch back indices + pred = {(rename[k] if k in rename else k): v for k, v in pred.items()} + pred['scores'] = scores + del pred['confidence'] + return pred diff --git a/imcui/third_party/gim/hloc/pairs_from_exhaustive.py b/imcui/third_party/gim/hloc/pairs_from_exhaustive.py new file mode 100644 index 0000000000000000000000000000000000000000..9dffbd1d69a1c4e063786413a68555b3cded013d --- /dev/null +++ b/imcui/third_party/gim/hloc/pairs_from_exhaustive.py @@ -0,0 +1,74 @@ +import argparse +import collections.abc as collections +import os +from pathlib import Path +from typing import Optional, Union, List + +from . import logger +from .utils.parsers import parse_image_lists +from .utils.io import list_h5_names + + +def main( + output: Path, + image_list: Optional[Union[Path, List[str]]] = None, + features: Optional[Path] = None, + ref_list: Optional[Union[Path, List[str]]] = None, + ref_features: Optional[Path] = None): + + if image_list is not None: + if isinstance(image_list, (str, Path)): + if image_list.is_dir(): + names_q = [x for x in os.listdir(str(image_list)) if x.endswith('.jpg') or x.endswith('.png')] + names_q.sort() + else: + names_q = parse_image_lists(image_list) + elif isinstance(image_list, collections.Iterable): + names_q = list(image_list) + else: + raise ValueError(f'Unknown type for image list: {image_list}') + elif features is not None: + names_q = list_h5_names(features) + else: + raise ValueError('Provide either a list of images or a feature file.') + + self_matching = False + if ref_list is not None: + if isinstance(ref_list, (str, Path)): + names_ref = parse_image_lists(ref_list) + elif isinstance(image_list, collections.Iterable): + names_ref = list(ref_list) + else: + raise ValueError( + f'Unknown type for reference image list: {ref_list}') + elif ref_features is not None: + names_ref = list_h5_names(ref_features) + else: + self_matching = True + names_ref = names_q + + pairs = [] + for i, n1 in enumerate(names_q): + for j, n2 in enumerate(names_ref): + if self_matching and j <= i: + continue + # if j - i > 5: + # continue + pairs.append((n1, n2)) + + logger.info(f'Found {len(pairs)} pairs.') + with open(output, 'w') as f: + f.write('\n'.join(' '.join([i, j]) for i, j in pairs)) + + return pairs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--output', required=True, type=Path) + parser.add_argument('--image_list', type=Path) + parser.add_argument('--features', type=Path) + parser.add_argument('--ref_list', type=Path) + parser.add_argument('--ref_features', type=Path) + args = parser.parse_args() + main(**args.__dict__) diff --git a/imcui/third_party/gim/hloc/reconstruction.py b/imcui/third_party/gim/hloc/reconstruction.py new file mode 100644 index 0000000000000000000000000000000000000000..943920541be7f202f3743cced2f8e8fd0f15f184 --- /dev/null +++ b/imcui/third_party/gim/hloc/reconstruction.py @@ -0,0 +1,158 @@ +import argparse +import shutil +from typing import Optional, List, Dict, Any +import multiprocessing +from pathlib import Path +import pycolmap + +from . 
import logger +from .utils.database import COLMAPDatabase +from .triangulation import ( + import_features, import_matches, estimation_and_geometric_verification, + OutputCapture, parse_option_args) + + +def create_empty_db(database_path: Path): + if database_path.exists(): + logger.warning('The database already exists, deleting it.') + database_path.unlink() + logger.info('Creating an empty database...') + db = COLMAPDatabase.connect(database_path) + db.create_tables() + db.commit() + db.close() + + +def import_images(image_dir: Path, + database_path: Path, + camera_mode: pycolmap.CameraMode, + image_list: Optional[List[str]] = None, + options: Optional[Dict[str, Any]] = None): + logger.info('Importing images into the database...') + if options is None: + options = {} + images = list(image_dir.iterdir()) + if len(images) == 0: + raise IOError(f'No images found in {image_dir}.') + with pycolmap.ostream(): + pycolmap.import_images(database_path, image_dir, camera_mode, + image_list=image_list or [], + options=options) + + +def get_image_ids(database_path: Path) -> Dict[str, int]: + db = COLMAPDatabase.connect(database_path) + images = {} + for name, image_id in db.execute("SELECT name, image_id FROM images;"): + images[name] = image_id + db.close() + return images + + +def run_reconstruction(sfm_dir: Path, + database_path: Path, + image_dir: Path, + verbose: bool = False, + options: Optional[Dict[str, Any]] = None, + ) -> pycolmap.Reconstruction: + models_path = sfm_dir / 'models' + models_path.mkdir(exist_ok=True, parents=True) + logger.info('Running 3D reconstruction...') + if options is None: + options = {} + options = {'num_threads': min(multiprocessing.cpu_count(), 16), **options} + with OutputCapture(verbose): + with pycolmap.ostream(): + reconstructions = pycolmap.incremental_mapping( + database_path, image_dir, models_path, options=options) + + if len(reconstructions) == 0: + logger.error('Could not reconstruct any model!') + return None + logger.info(f'Reconstructed {len(reconstructions)} model(s).') + + largest_index = None + largest_num_images = 0 + for index, rec in reconstructions.items(): + num_images = rec.num_reg_images() + if num_images > largest_num_images: + largest_index = index + largest_num_images = num_images + assert largest_index is not None + logger.info(f'Largest model is #{largest_index} ' + f'with {largest_num_images} images.') + + for filename in ['images.bin', 'cameras.bin', 'points3D.bin']: + if (sfm_dir / filename).exists(): + (sfm_dir / filename).unlink() + shutil.move( + str(models_path / str(largest_index) / filename), str(sfm_dir)) + return reconstructions[largest_index] + + +def main(sfm_dir: Path, + image_dir: Path, + pairs: Path, + features: Path, + matches: Path, + camera_mode: pycolmap.CameraMode = pycolmap.CameraMode.AUTO, + verbose: bool = False, + skip_geometric_verification: bool = False, + min_match_score: Optional[float] = None, + image_list: Optional[List[str]] = None, + image_options: Optional[Dict[str, Any]] = None, + mapper_options: Optional[Dict[str, Any]] = None, + ) -> pycolmap.Reconstruction: + + assert features.exists(), features + assert pairs.exists(), pairs + assert matches.exists(), matches + + sfm_dir.mkdir(parents=True, exist_ok=True) + database = sfm_dir / 'database.db' + + create_empty_db(database) + import_images(image_dir, database, camera_mode, image_list, image_options) + image_ids = get_image_ids(database) + import_features(image_ids, database, features) + import_matches(image_ids, database, pairs, matches, + 
min_match_score, skip_geometric_verification) + if not skip_geometric_verification: + estimation_and_geometric_verification(database, pairs, verbose) + reconstruction = run_reconstruction( + sfm_dir, database, image_dir, verbose, mapper_options) + if reconstruction is not None: + logger.info(f'Reconstruction statistics:\n{reconstruction.summary()}' + + f'\n\tnum_input_images = {len(image_ids)}') + return reconstruction + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--sfm_dir', type=Path, required=True) + parser.add_argument('--image_dir', type=Path, required=True) + + parser.add_argument('--pairs', type=Path, required=True) + parser.add_argument('--features', type=Path, required=True) + parser.add_argument('--matches', type=Path, required=True) + + parser.add_argument('--camera_mode', type=str, default="AUTO", + choices=list(pycolmap.CameraMode.__members__.keys())) + parser.add_argument('--skip_geometric_verification', action='store_true') + parser.add_argument('--min_match_score', type=float) + parser.add_argument('--verbose', action='store_true') + + parser.add_argument('--image_options', nargs='+', default=[], + help='List of key=value from {}'.format( + pycolmap.ImageReaderOptions().todict())) + parser.add_argument('--mapper_options', nargs='+', default=[], + help='List of key=value from {}'.format( + pycolmap.IncrementalMapperOptions().todict())) + args = parser.parse_args().__dict__ + + image_options = parse_option_args( + args.pop("image_options"), pycolmap.ImageReaderOptions()) + mapper_options = parse_option_args( + args.pop("mapper_options"), pycolmap.IncrementalMapperOptions()) + + main(**args, image_options=image_options, mapper_options=mapper_options) diff --git a/imcui/third_party/gim/hloc/triangulation.py b/imcui/third_party/gim/hloc/triangulation.py new file mode 100644 index 0000000000000000000000000000000000000000..9a659f3b465bf98346e8e4c840ed74df8fe1e950 --- /dev/null +++ b/imcui/third_party/gim/hloc/triangulation.py @@ -0,0 +1,277 @@ +import argparse +import contextlib +from typing import Optional, List, Dict, Any +import io +import sys +from pathlib import Path +import numpy as np +from tqdm import tqdm +import pycolmap + +from . 
import logger +from .utils.database import COLMAPDatabase +from .utils.io import get_keypoints, get_matches +from .utils.parsers import parse_retrieval +from .utils.geometry import compute_epipolar_errors + + +class OutputCapture: + def __init__(self, verbose: bool): + self.verbose = verbose + + def __enter__(self): + if not self.verbose: + self.capture = contextlib.redirect_stdout(io.StringIO()) + self.out = self.capture.__enter__() + + def __exit__(self, exc_type, *args): + if not self.verbose: + self.capture.__exit__(exc_type, *args) + if exc_type is not None: + logger.error('Failed with output:\n%s', self.out.getvalue()) + sys.stdout.flush() + + +def create_db_from_model(reconstruction: pycolmap.Reconstruction, + database_path: Path) -> Dict[str, int]: + if database_path.exists(): + logger.warning('The database already exists, deleting it.') + database_path.unlink() + + db = COLMAPDatabase.connect(database_path) + db.create_tables() + + for i, camera in reconstruction.cameras.items(): + db.add_camera( + camera.model_id, camera.width, camera.height, camera.params, + camera_id=i, prior_focal_length=True) + + for i, image in reconstruction.images.items(): + db.add_image(image.name, image.camera_id, image_id=i) + + db.commit() + db.close() + return {image.name: i for i, image in reconstruction.images.items()} + + +def import_features(image_ids: Dict[str, int], + database_path: Path, + features_path: Path): + logger.info('Importing features into the database...') + db = COLMAPDatabase.connect(database_path) + + for image_name, image_id in tqdm(image_ids.items()): + keypoints = get_keypoints(features_path, image_name) + keypoints += 0.5 # COLMAP origin + db.add_keypoints(image_id, keypoints) + + db.commit() + db.close() + + +def import_matches(image_ids: Dict[str, int], + database_path: Path, + pairs_path: Path, + matches_path: Path, + min_match_score: Optional[float] = None, + skip_geometric_verification: bool = False): + logger.info('Importing matches into the database...') + + with open(str(pairs_path), 'r') as f: + pairs = [p.split() for p in f.readlines()] + + db = COLMAPDatabase.connect(database_path) + + matched = set() + for name0, name1 in tqdm(pairs): + id0, id1 = image_ids[name0], image_ids[name1] + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matches, scores = get_matches(matches_path, name0, name1) + if min_match_score: + matches = matches[scores > min_match_score] + db.add_matches(id0, id1, matches) + matched |= {(id0, id1), (id1, id0)} + + if skip_geometric_verification: + db.add_two_view_geometry(id0, id1, matches) + + db.commit() + db.close() + + +def estimation_and_geometric_verification(database_path: Path, + pairs_path: Path, + verbose: bool = False): + logger.info('Performing geometric verification of the matches...') + with OutputCapture(verbose): + with pycolmap.ostream(): + pycolmap.verify_matches( + database_path, pairs_path, + options=dict(ransac=dict(max_num_trials=20000, min_inlier_ratio=0.1)),) + + +def geometric_verification(image_ids: Dict[str, int], + reference: pycolmap.Reconstruction, + database_path: Path, + features_path: Path, + pairs_path: Path, + matches_path: Path, + max_error: float = 4.0): + logger.info('Performing geometric verification of the matches...') + + pairs = parse_retrieval(pairs_path) + db = COLMAPDatabase.connect(database_path) + + inlier_ratios = [] + matched = set() + for name0 in tqdm(pairs): + id0 = image_ids[name0] + image0 = reference.images[id0] + cam0 = reference.cameras[image0.camera_id] + kps0, noise0 = 
get_keypoints( + features_path, name0, return_uncertainty=True) + noise0 = 1.0 if noise0 is None else noise0 + if len(kps0) > 0: + kps0 = np.stack(cam0.image_to_world(kps0)) + else: + kps0 = np.zeros((0, 2)) + + for name1 in pairs[name0]: + id1 = image_ids[name1] + image1 = reference.images[id1] + cam1 = reference.cameras[image1.camera_id] + kps1, noise1 = get_keypoints( + features_path, name1, return_uncertainty=True) + noise1 = 1.0 if noise1 is None else noise1 + if len(kps1) > 0: + kps1 = np.stack(cam1.image_to_world(kps1)) + else: + kps1 = np.zeros((0, 2)) + + matches = get_matches(matches_path, name0, name1)[0] + + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matched |= {(id0, id1), (id1, id0)} + + if matches.shape[0] == 0: + db.add_two_view_geometry(id0, id1, matches) + continue + + qvec_01, tvec_01 = pycolmap.relative_pose( + image0.qvec, image0.tvec, image1.qvec, image1.tvec) + _, errors0, errors1 = compute_epipolar_errors( + qvec_01, tvec_01, kps0[matches[:, 0]], kps1[matches[:, 1]]) + valid_matches = np.logical_and( + errors0 <= max_error * noise0 / cam0.mean_focal_length(), + errors1 <= max_error * noise1 / cam1.mean_focal_length()) + # TODO: We could also add E to the database, but we need + # to reverse the transformations if id0 > id1 in utils/database.py. + db.add_two_view_geometry(id0, id1, matches[valid_matches, :]) + inlier_ratios.append(np.mean(valid_matches)) + logger.info('mean/med/min/max valid matches %.2f/%.2f/%.2f/%.2f%%.', + np.mean(inlier_ratios) * 100, np.median(inlier_ratios) * 100, + np.min(inlier_ratios) * 100, np.max(inlier_ratios) * 100) + + db.commit() + db.close() + + +def run_triangulation(model_path: Path, + database_path: Path, + image_dir: Path, + reference_model: pycolmap.Reconstruction, + verbose: bool = False, + options: Optional[Dict[str, Any]] = None, + ) -> pycolmap.Reconstruction: + model_path.mkdir(parents=True, exist_ok=True) + logger.info('Running 3D triangulation...') + if options is None: + options = {} + with OutputCapture(verbose): + with pycolmap.ostream(): + reconstruction = pycolmap.triangulate_points( + reference_model, database_path, image_dir, model_path, + options=options) + return reconstruction + + +def main(sfm_dir: Path, + reference_model: Path, + image_dir: Path, + pairs: Path, + features: Path, + matches: Path, + skip_geometric_verification: bool = False, + estimate_two_view_geometries: bool = False, + min_match_score: Optional[float] = None, + verbose: bool = False, + mapper_options: Optional[Dict[str, Any]] = None, + ) -> pycolmap.Reconstruction: + + assert reference_model.exists(), reference_model + assert features.exists(), features + assert pairs.exists(), pairs + assert matches.exists(), matches + + sfm_dir.mkdir(parents=True, exist_ok=True) + database = sfm_dir / 'database.db' + reference = pycolmap.Reconstruction(reference_model) + + image_ids = create_db_from_model(reference, database) + import_features(image_ids, database, features) + import_matches(image_ids, database, pairs, matches, + min_match_score, skip_geometric_verification) + if not skip_geometric_verification: + if estimate_two_view_geometries: + estimation_and_geometric_verification(database, pairs, verbose) + else: + geometric_verification( + image_ids, reference, database, features, pairs, matches) + reconstruction = run_triangulation(sfm_dir, database, image_dir, reference, + verbose, mapper_options) + logger.info('Finished the triangulation with statistics:\n%s', + reconstruction.summary()) + return reconstruction + + +def 
parse_option_args(args: List[str], default_options) -> Dict[str, Any]: + options = {} + for arg in args: + idx = arg.find('=') + if idx == -1: + raise ValueError('Options format: key1=value1 key2=value2 etc.') + key, value = arg[:idx], arg[idx+1:] + if not hasattr(default_options, key): + raise ValueError( + f'Unknown option "{key}", allowed options and default values' + f' for {default_options.summary()}') + value = eval(value) + target_type = type(getattr(default_options, key)) + if not isinstance(value, target_type): + raise ValueError(f'Incorrect type for option "{key}":' + f' {type(value)} vs {target_type}') + options[key] = value + return options + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--sfm_dir', type=Path, required=True) + parser.add_argument('--reference_sfm_model', type=Path, required=True) + parser.add_argument('--image_dir', type=Path, required=True) + + parser.add_argument('--pairs', type=Path, required=True) + parser.add_argument('--features', type=Path, required=True) + parser.add_argument('--matches', type=Path, required=True) + + parser.add_argument('--skip_geometric_verification', action='store_true') + parser.add_argument('--min_match_score', type=float) + parser.add_argument('--verbose', action='store_true') + args = parser.parse_args().__dict__ + + mapper_options = parse_option_args( + args.pop("mapper_options"), pycolmap.IncrementalMapperOptions()) + + main(**args, mapper_options=mapper_options) diff --git a/imcui/third_party/gim/hloc/utils/__init__.py b/imcui/third_party/gim/hloc/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dabfebe727522676bcd78645d199ff6e2d40bb3f --- /dev/null +++ b/imcui/third_party/gim/hloc/utils/__init__.py @@ -0,0 +1,49 @@ +import cv2 +import csv +import torch +import numpy as np + +CLS_DICT = {} +with open('weights/object150_info.csv') as f: + reader = csv.reader(f) + next(reader) + for row in reader: + name = row[5].split(";")[0] + if name == 'screen': + name = '_'.join(row[5].split(";")[:2]) + CLS_DICT[name] = int(row[0]) - 1 + +exclude = ['person', 'sky', 'car'] + + +def read_deeplab_image(img, size): + width, height = img.shape[1], img.shape[0] + + if max(width, height) > size: + if width > height: + img = cv2.resize(img, (size, int(size * height / width)), interpolation=cv2.INTER_AREA) + else: + img = cv2.resize(img, (int(size * width / height), size), interpolation=cv2.INTER_AREA) + + img = (torch.from_numpy(img.copy()).float() / 255).permute(2, 0, 1)[None] + + return img + + +def read_segmentation_image(img, size): + img = read_deeplab_image(img, size=size)[0] + # img = (torch.from_numpy(img).float() / 255).permute(2, 0, 1) + img = img - torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) + img = img / torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) + return img + + +def segment(rgb, size, device, segmentation_module): + img_data = read_segmentation_image(rgb, size=size) + singleton_batch = {'img_data': img_data[None].to(device)} + output_size = img_data.shape[1:] + # Run the segmentation at the highest resolution. 
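+    # Note: `segmentation_module` is assumed to follow the MIT CSAIL
+    # semantic-segmentation interface (a callable taking a feed dict with
+    # 'img_data' and a `segSize` keyword) and to score the 150 ADE20K
+    # classes listed in weights/object150_info.csv; classes in `exclude`
+    # (person, sky, car) can then be masked out by the caller.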
+ scores = segmentation_module(singleton_batch, segSize=output_size) + # Get the predicted scores for each pixel + _, pred = torch.max(scores, dim=1) + return pred.cpu()[0].numpy().astype(np.uint8) diff --git a/imcui/third_party/gim/hloc/utils/base_model.py b/imcui/third_party/gim/hloc/utils/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..caf17f050c5fb675e3d435b4f170243f813484d3 --- /dev/null +++ b/imcui/third_party/gim/hloc/utils/base_model.py @@ -0,0 +1,47 @@ +import sys +from abc import ABCMeta, abstractmethod +from torch import nn +from copy import copy +import inspect + + +class BaseModel(nn.Module, metaclass=ABCMeta): + default_conf = {} + required_inputs = [] + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + self.conf = conf = {**self.default_conf, **conf} + self.required_inputs = copy(self.required_inputs) + self._init(conf) + sys.stdout.flush() + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + for key in self.required_inputs: + assert key in data, 'Missing key {} in data'.format(key) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + +def dynamic_load(root, model): + module_path = f'{root.__name__}.{model}' + module = __import__(module_path, fromlist=['']) + classes = inspect.getmembers(module, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == module_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseModel)] + assert len(classes) == 1, classes + return classes[0][1] + # return getattr(module, 'Model') diff --git a/imcui/third_party/gim/hloc/utils/database.py b/imcui/third_party/gim/hloc/utils/database.py new file mode 100644 index 0000000000000000000000000000000000000000..870a8c4fd43e28beb9c423564b34cb6457b27887 --- /dev/null +++ b/imcui/third_party/gim/hloc/utils/database.py @@ -0,0 +1,360 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +# This script is based on an original implementation by True Price. + +import sys +import sqlite3 +import numpy as np + + +IS_PYTHON3 = sys.version_info[0] >= 3 + +MAX_IMAGE_ID = 2**31 - 1 + +CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( + camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + model INTEGER NOT NULL, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + params BLOB, + prior_focal_length INTEGER NOT NULL)""" + +CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" + +CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( + image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + name TEXT NOT NULL UNIQUE, + camera_id INTEGER NOT NULL, + prior_qw REAL, + prior_qx REAL, + prior_qy REAL, + prior_qz REAL, + prior_tx REAL, + prior_ty REAL, + prior_tz REAL, + CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}), + FOREIGN KEY(camera_id) REFERENCES cameras(camera_id)) +""".format(MAX_IMAGE_ID) + +CREATE_TWO_VIEW_GEOMETRIES_TABLE = """ +CREATE TABLE IF NOT EXISTS two_view_geometries ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + config INTEGER NOT NULL, + F BLOB, + E BLOB, + H BLOB, + qvec BLOB, + tvec BLOB) +""" + +CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE) +""" + +CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB)""" + +CREATE_NAME_INDEX = \ + "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" + +CREATE_ALL = "; ".join([ + CREATE_CAMERAS_TABLE, + CREATE_IMAGES_TABLE, + CREATE_KEYPOINTS_TABLE, + CREATE_DESCRIPTORS_TABLE, + CREATE_MATCHES_TABLE, + CREATE_TWO_VIEW_GEOMETRIES_TABLE, + CREATE_NAME_INDEX +]) + + +def image_ids_to_pair_id(image_id1, image_id2): + if image_id1 > image_id2: + image_id1, image_id2 = image_id2, image_id1 + return image_id1 * MAX_IMAGE_ID + image_id2 + + +def pair_id_to_image_ids(pair_id): + image_id2 = pair_id % MAX_IMAGE_ID + image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID + return image_id1, image_id2 + + +def array_to_blob(array): + if IS_PYTHON3: + return array.tobytes() + else: + return np.getbuffer(array) + + +def blob_to_array(blob, dtype, shape=(-1,)): + if IS_PYTHON3: + return np.fromstring(blob, dtype=dtype).reshape(*shape) + else: + return np.frombuffer(blob, dtype=dtype).reshape(*shape) + + +class COLMAPDatabase(sqlite3.Connection): + + @staticmethod + def connect(database_path): + return sqlite3.connect(str(database_path), factory=COLMAPDatabase) + 
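+    # Usage note: open the database through COLMAPDatabase.connect(path)
+    # rather than the constructor, so sqlite3 uses this class as the
+    # connection factory. Pair-keyed tables (matches, two_view_geometries)
+    # index rows by image_ids_to_pair_id(id1, id2) with the smaller id
+    # first, e.g. (1, 2) -> 1 * MAX_IMAGE_ID + 2 = 2147483649.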
+ + def __init__(self, *args, **kwargs): + super(COLMAPDatabase, self).__init__(*args, **kwargs) + + self.create_tables = lambda: self.executescript(CREATE_ALL) + self.create_cameras_table = \ + lambda: self.executescript(CREATE_CAMERAS_TABLE) + self.create_descriptors_table = \ + lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) + self.create_images_table = \ + lambda: self.executescript(CREATE_IMAGES_TABLE) + self.create_two_view_geometries_table = \ + lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE) + self.create_keypoints_table = \ + lambda: self.executescript(CREATE_KEYPOINTS_TABLE) + self.create_matches_table = \ + lambda: self.executescript(CREATE_MATCHES_TABLE) + self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) + + def add_camera(self, model, width, height, params, + prior_focal_length=False, camera_id=None): + params = np.asarray(params, np.float64) + cursor = self.execute( + "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", + (camera_id, model, width, height, array_to_blob(params), + prior_focal_length)) + return cursor.lastrowid + + def add_image(self, name, camera_id, + prior_q=np.full(4, np.NaN), prior_t=np.full(3, np.NaN), + image_id=None): + cursor = self.execute( + "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2], + prior_q[3], prior_t[0], prior_t[1], prior_t[2])) + return cursor.lastrowid + + def add_keypoints(self, image_id, keypoints): + assert(len(keypoints.shape) == 2) + assert(keypoints.shape[1] in [2, 4, 6]) + + keypoints = np.asarray(keypoints, np.float32) + self.execute( + "INSERT INTO keypoints VALUES (?, ?, ?, ?)", + (image_id,) + keypoints.shape + (array_to_blob(keypoints),)) + + def add_descriptors(self, image_id, descriptors): + descriptors = np.ascontiguousarray(descriptors, np.uint8) + self.execute( + "INSERT INTO descriptors VALUES (?, ?, ?, ?)", + (image_id,) + descriptors.shape + (array_to_blob(descriptors),)) + + def add_matches(self, image_id1, image_id2, matches): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + self.execute( + "INSERT INTO matches VALUES (?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches),)) + + def add_two_view_geometry(self, image_id1, image_id2, matches, + F=np.eye(3), E=np.eye(3), H=np.eye(3), + qvec=np.array([1.0, 0.0, 0.0, 0.0]), + tvec=np.zeros(3), config=2): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + F = np.asarray(F, dtype=np.float64) + E = np.asarray(E, dtype=np.float64) + H = np.asarray(H, dtype=np.float64) + qvec = np.asarray(qvec, dtype=np.float64) + tvec = np.asarray(tvec, dtype=np.float64) + self.execute( + "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches), config, + array_to_blob(F), array_to_blob(E), array_to_blob(H), + array_to_blob(qvec), array_to_blob(tvec))) + + +def example_usage(): + import os + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--database_path", default="database.db") + args = parser.parse_args() + + if os.path.exists(args.database_path): + print("ERROR: database path already exists -- will not modify it.") + 
return + + # Open the database. + + db = COLMAPDatabase.connect(args.database_path) + + # For convenience, try creating all the tables upfront. + + db.create_tables() + + # Create dummy cameras. + + model1, width1, height1, params1 = \ + 0, 1024, 768, np.array((1024., 512., 384.)) + model2, width2, height2, params2 = \ + 2, 1024, 768, np.array((1024., 512., 384., 0.1)) + + camera_id1 = db.add_camera(model1, width1, height1, params1) + camera_id2 = db.add_camera(model2, width2, height2, params2) + + # Create dummy images. + + image_id1 = db.add_image("image1.png", camera_id1) + image_id2 = db.add_image("image2.png", camera_id1) + image_id3 = db.add_image("image3.png", camera_id2) + image_id4 = db.add_image("image4.png", camera_id2) + + # Create dummy keypoints. + # + # Note that COLMAP supports: + # - 2D keypoints: (x, y) + # - 4D keypoints: (x, y, theta, scale) + # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) + + num_keypoints = 1000 + keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2) + keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2) + + db.add_keypoints(image_id1, keypoints1) + db.add_keypoints(image_id2, keypoints2) + db.add_keypoints(image_id3, keypoints3) + db.add_keypoints(image_id4, keypoints4) + + # Create dummy matches. + + M = 50 + matches12 = np.random.randint(num_keypoints, size=(M, 2)) + matches23 = np.random.randint(num_keypoints, size=(M, 2)) + matches34 = np.random.randint(num_keypoints, size=(M, 2)) + + db.add_matches(image_id1, image_id2, matches12) + db.add_matches(image_id2, image_id3, matches23) + db.add_matches(image_id3, image_id4, matches34) + + # Commit the data to the file. + + db.commit() + + # Read and check cameras. + + rows = db.execute("SELECT * FROM cameras") + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id1 + assert model == model1 and width == width1 and height == height1 + assert np.allclose(params, params1) + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id2 + assert model == model2 and width == width2 and height == height2 + assert np.allclose(params, params2) + + # Read and check keypoints. + + keypoints = dict( + (image_id, blob_to_array(data, np.float32, (-1, 2))) + for image_id, data in db.execute( + "SELECT image_id, data FROM keypoints")) + + assert np.allclose(keypoints[image_id1], keypoints1) + assert np.allclose(keypoints[image_id2], keypoints2) + assert np.allclose(keypoints[image_id3], keypoints3) + assert np.allclose(keypoints[image_id4], keypoints4) + + # Read and check matches. + + pair_ids = [image_ids_to_pair_id(*pair) for pair in + ((image_id1, image_id2), + (image_id2, image_id3), + (image_id3, image_id4))] + + matches = dict( + (pair_id_to_image_ids(pair_id), + blob_to_array(data, np.uint32, (-1, 2))) + for pair_id, data in db.execute("SELECT pair_id, data FROM matches") + ) + + assert np.all(matches[(image_id1, image_id2)] == matches12) + assert np.all(matches[(image_id2, image_id3)] == matches23) + assert np.all(matches[(image_id3, image_id4)] == matches34) + + # Clean up. 
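+    # The database created above is only for demonstration: close the
+    # connection, then delete the file so repeated runs start from scratch.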
+ + db.close() + + if os.path.exists(args.database_path): + os.remove(args.database_path) + + +if __name__ == "__main__": + example_usage() diff --git a/imcui/third_party/gim/hloc/utils/geometry.py b/imcui/third_party/gim/hloc/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..7f5ce101d463da35d8d661de083ff9eabcbc5f76 --- /dev/null +++ b/imcui/third_party/gim/hloc/utils/geometry.py @@ -0,0 +1,37 @@ +import numpy as np +import pycolmap + + +def to_homogeneous(p): + return np.pad(p, ((0, 0),) * (p.ndim - 1) + ((0, 1),), constant_values=1) + + +def vector_to_cross_product_matrix(v): + return np.array([ + [0, -v[2], v[1]], + [v[2], 0, -v[0]], + [-v[1], v[0], 0] + ]) + + +def compute_epipolar_errors(qvec_r2t, tvec_r2t, p2d_r, p2d_t): + T_r2t = pose_matrix_from_qvec_tvec(qvec_r2t, tvec_r2t) + # Compute errors in normalized plane to avoid distortion. + E = vector_to_cross_product_matrix(T_r2t[: 3, -1]) @ T_r2t[: 3, : 3] + l2d_r2t = (E @ to_homogeneous(p2d_r).T).T + l2d_t2r = (E.T @ to_homogeneous(p2d_t).T).T + errors_r = ( + np.abs(np.sum(to_homogeneous(p2d_r) * l2d_t2r, axis=1)) / + np.linalg.norm(l2d_t2r[:, : 2], axis=1)) + errors_t = ( + np.abs(np.sum(to_homogeneous(p2d_t) * l2d_r2t, axis=1)) / + np.linalg.norm(l2d_r2t[:, : 2], axis=1)) + return E, errors_r, errors_t + + +def pose_matrix_from_qvec_tvec(qvec, tvec): + pose = np.zeros((4, 4)) + pose[: 3, : 3] = pycolmap.qvec_to_rotmat(qvec) + pose[: 3, -1] = tvec + pose[-1, -1] = 1 + return pose diff --git a/imcui/third_party/gim/hloc/utils/io.py b/imcui/third_party/gim/hloc/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..92958e9643f172664f06b6c45b0b078347952863 --- /dev/null +++ b/imcui/third_party/gim/hloc/utils/io.py @@ -0,0 +1,73 @@ +from typing import Tuple +from pathlib import Path +import numpy as np +import cv2 +import h5py + +from .parsers import names_to_pair, names_to_pair_old + + +def read_image(path, grayscale=False): + if grayscale: + mode = cv2.IMREAD_GRAYSCALE + else: + mode = cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise ValueError(f'Cannot read image {path}.') + if not grayscale and len(image.shape) == 3: + image = image[:, :, ::-1] # BGR to RGB + return image + + +def list_h5_names(path): + names = [] + with h5py.File(str(path), 'r', libver='latest') as fd: + def visit_fn(_, obj): + if isinstance(obj, h5py.Dataset): + names.append(obj.parent.name.strip('/')) + fd.visititems(visit_fn) + return list(set(names)) + + +def get_keypoints(path: Path, name: str, + return_uncertainty: bool = False) -> np.ndarray: + with h5py.File(str(path), 'r', libver='latest') as hfile: + dset = hfile[name]['keypoints'] + p = dset.__array__() + uncertainty = dset.attrs.get('uncertainty') + if return_uncertainty: + return p, uncertainty + return p + + +def find_pair(hfile: h5py.File, name0: str, name1: str): + pair = names_to_pair(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair(name1, name0) + if pair in hfile: + return pair, True + # older, less efficient format + pair = names_to_pair_old(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair_old(name1, name0) + if pair in hfile: + return pair, True + raise ValueError( + f'Could not find pair {(name0, name1)}... ' + 'Maybe you matched with a different list of pairs? 
') + + +def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]: + with h5py.File(str(path), 'r', libver='latest') as hfile: + pair, reverse = find_pair(hfile, name0, name1) + matches = hfile[pair]['matches0'].__array__() + scores = hfile[pair]['matching_scores0'].__array__() + idx = np.where(matches != -1)[0] + matches = np.stack([idx, matches[idx]], -1) + if reverse: + matches = np.flip(matches, -1) + scores = scores[idx] + return matches, scores diff --git a/imcui/third_party/gim/hloc/utils/parsers.py b/imcui/third_party/gim/hloc/utils/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..1f4d9c194c28bded8906ea7ffca980a71271d59c --- /dev/null +++ b/imcui/third_party/gim/hloc/utils/parsers.py @@ -0,0 +1,56 @@ +from pathlib import Path +import logging +import numpy as np +from collections import defaultdict +import pycolmap + +logger = logging.getLogger(__name__) + + +def parse_image_list(path, with_intrinsics=False): + images = [] + with open(path, 'r') as f: + for line in f: + line = line.strip('\n') + if len(line) == 0 or line[0] == '#': + continue + name, *data = line.split() + if with_intrinsics: + model, width, height, *params = data + params = np.array(params, float) + cam = pycolmap.Camera(model, int(width), int(height), params) + images.append((name, cam)) + else: + images.append(name) + + assert len(images) > 0 + logger.info(f'Imported {len(images)} images from {path.name}') + return images + + +def parse_image_lists(paths, with_intrinsics=False): + images = [] + files = list(Path(paths.parent).glob(paths.name)) + assert len(files) > 0 + for lfile in files: + images += parse_image_list(lfile, with_intrinsics=with_intrinsics) + return images + + +def parse_retrieval(path): + retrieval = defaultdict(list) + with open(path, 'r') as f: + for p in f.read().rstrip('\n').split('\n'): + if len(p) == 0: + continue + q, r = p.split() + retrieval[q].append(r) + return dict(retrieval) + + +def names_to_pair(name0, name1, separator='/'): + return separator.join((name0.replace('/', '-'), name1.replace('/', '-'))) + + +def names_to_pair_old(name0, name1): + return names_to_pair(name0, name1, separator='_') diff --git a/imcui/third_party/gim/networks/__init__.py b/imcui/third_party/gim/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 --- /dev/null +++ b/imcui/third_party/gim/networks/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun diff --git a/imcui/third_party/gim/networks/dkm/__init__.py b/imcui/third_party/gim/networks/dkm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b47632780acc7762bcccc348e2025fe99f3726 --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/__init__.py @@ -0,0 +1,4 @@ +from .models import ( + DKMv3_outdoor, + DKMv3_indoor, + ) diff --git a/imcui/third_party/gim/networks/dkm/models/__init__.py b/imcui/third_party/gim/networks/dkm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4fc321ec70fd116beca23e94248cb6bbe771523 --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/models/__init__.py @@ -0,0 +1,4 @@ +from .model_zoo import ( + DKMv3_outdoor, + DKMv3_indoor, +) diff --git a/imcui/third_party/gim/networks/dkm/models/dkm.py b/imcui/third_party/gim/networks/dkm/models/dkm.py new file mode 100644 index 0000000000000000000000000000000000000000..62fbb9a1000995d940ba816ab9c9c5bf9b5d0895 --- /dev/null +++ 
b/imcui/third_party/gim/networks/dkm/models/dkm.py @@ -0,0 +1,752 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from networks.dkm.utils import get_tuple_transform_ops +from einops import rearrange +from networks.dkm.utils.local_correlation import local_correlation +from networks.dkm.utils.kde import kde + + +class ConvRefiner(nn.Module): + def __init__( + self, + in_dim=6, + hidden_dim=16, + out_dim=2, + dw=False, + kernel_size=5, + hidden_blocks=3, + displacement_emb = None, + displacement_emb_dim = None, + local_corr_radius = None, + corr_in_other = None, + no_support_fm = False, + ): + super().__init__() + self.block1 = self.create_block( + in_dim, hidden_dim, dw=dw, kernel_size=kernel_size + ) + self.hidden_blocks = nn.Sequential( + *[ + self.create_block( + hidden_dim, + hidden_dim, + dw=dw, + kernel_size=kernel_size, + ) + for hb in range(hidden_blocks) + ] + ) + self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0) + if displacement_emb: + self.has_displacement_emb = True + self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0) + else: + self.has_displacement_emb = False + self.local_corr_radius = local_corr_radius + self.corr_in_other = corr_in_other + self.no_support_fm = no_support_fm + def create_block( + self, + in_dim, + out_dim, + dw=False, + kernel_size=5, + ): + num_groups = 1 if not dw else in_dim + if dw: + assert ( + out_dim % in_dim == 0 + ), "outdim must be divisible by indim for depthwise" + conv1 = nn.Conv2d( + in_dim, + out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=num_groups, + ) + norm = nn.BatchNorm2d(out_dim) + relu = nn.ReLU(inplace=True) + conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0) + return nn.Sequential(conv1, norm, relu, conv2) + + def forward(self, x, y, flow): + """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them + + Args: + x ([type]): [description] + y ([type]): [description] + flow ([type]): [description] + + Returns: + [type]: [description] + """ + device = x.device + b,c,hs,ws = x.shape + with torch.no_grad(): + x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False) + if self.has_displacement_emb: + query_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), + ) + ) + query_coords = torch.stack((query_coords[1], query_coords[0])) + query_coords = query_coords[None].expand(b, 2, hs, ws) + in_displacement = flow-query_coords + emb_in_displacement = self.disp_emb(in_displacement) + if self.local_corr_radius: + #TODO: should corr have gradient? 
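+                # local_correlation returns a per-pixel cost volume with
+                # (2 * local_corr_radius + 1) ** 2 channels, which is
+                # concatenated below with the warped support features and
+                # the displacement embedding before refinement.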
+ if self.corr_in_other: + # Corr in other means take a kxk grid around the predicted coordinate in other image + local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow) + else: + # Otherwise we use the warp to sample in the first image + # This is actually different operations, especially for large viewpoint changes + local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,) + if self.no_support_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1) + else: + d = torch.cat((x, x_hat, emb_in_displacement), dim=1) + else: + if self.no_support_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat), dim=1) + d = self.block1(d) + d = self.hidden_blocks(d) + d = self.out_conv(d) + certainty, displacement = d[:, :-2], d[:, -2:] + return certainty, displacement + + +class CosKernel(nn.Module): # similar to softmax kernel + def __init__(self, T, learn_temperature=False): + super().__init__() + self.learn_temperature = learn_temperature + if self.learn_temperature: + self.T = nn.Parameter(torch.tensor(T)) + else: + self.T = T + + def __call__(self, x, y, eps=1e-6): + c = torch.einsum("bnd,bmd->bnm", x, y) / ( + x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps + ) + if self.learn_temperature: + T = self.T.abs() + 0.01 + else: + T = torch.tensor(self.T, device=c.device) + K = ((c - 1.0) / T).exp() + return K + + +class CAB(nn.Module): + def __init__(self, in_channels, out_channels): + super(CAB, self).__init__() + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.sigmod = nn.Sigmoid() + + def forward(self, x): + x1, x2 = x # high, low (old, new) + x = torch.cat([x1, x2], dim=1) + x = self.global_pooling(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.sigmod(x) + x2 = x * x2 + res = x2 + x1 + return res + + +class RRB(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(RRB, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + ) + self.relu = nn.ReLU() + self.bn = nn.BatchNorm2d(out_channels) + self.conv3 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + ) + + def forward(self, x): + x = self.conv1(x) + res = self.conv2(x) + res = self.bn(res) + res = self.relu(res) + res = self.conv3(res) + return self.relu(x + res) + + +class DFN(nn.Module): + def __init__( + self, + internal_dim, + feat_input_modules, + pred_input_modules, + rrb_d_dict, + cab_dict, + rrb_u_dict, + use_global_context=False, + global_dim=None, + terminal_module=None, + upsample_mode="bilinear", + align_corners=False, + ): + super().__init__() + if use_global_context: + assert ( + global_dim is not None + ), "Global dim must be provided when using global context" + self.align_corners = align_corners + self.internal_dim = internal_dim + self.feat_input_modules = feat_input_modules + self.pred_input_modules = pred_input_modules + self.rrb_d = rrb_d_dict + self.cab = cab_dict + self.rrb_u = rrb_u_dict + self.use_global_context = use_global_context + if use_global_context: + self.global_to_internal = 
nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.terminal_module = ( + terminal_module if terminal_module is not None else nn.Identity() + ) + self.upsample_mode = upsample_mode + self._scales = [int(key) for key in self.terminal_module.keys()] + + def scales(self): + return self._scales.copy() + + def forward(self, embeddings, feats, context, key): + feats = self.feat_input_modules[str(key)](feats) + embeddings = torch.cat([feats, embeddings], dim=1) + embeddings = self.rrb_d[str(key)](embeddings) + context = self.cab[str(key)]([context, embeddings]) + context = self.rrb_u[str(key)](context) + preds = self.terminal_module[str(key)](context) + pred_coord = preds[:, -2:] + pred_certainty = preds[:, :-2] + return pred_coord, pred_certainty, context + + +class GP(nn.Module): + def __init__( + self, + kernel, + T=1, + learn_temperature=False, + only_attention=False, + gp_dim=64, + basis="fourier", + covar_size=5, + only_nearest_neighbour=False, + sigma_noise=0.1, + no_cov=False, + predict_features = False, + ): + super().__init__() + self.K = kernel(T=T, learn_temperature=learn_temperature) + self.sigma_noise = sigma_noise + self.covar_size = covar_size + self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1) + self.only_attention = only_attention + self.only_nearest_neighbour = only_nearest_neighbour + self.basis = basis + self.no_cov = no_cov + self.dim = gp_dim + self.predict_features = predict_features + + def get_local_cov(self, cov): + K = self.covar_size + b, h, w, h, w = cov.shape + hw = h * w + cov = F.pad(cov, 4 * (K // 2,)) # pad v_q + delta = torch.stack( + torch.meshgrid( + torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1) + ), + dim=-1, + ) + positions = torch.stack( + torch.meshgrid( + torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2) + ), + dim=-1, + ) + neighbours = positions[:, :, None, None, :] + delta[None, :, :] + points = torch.arange(hw)[:, None].expand(hw, K**2) + local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[ + :, + points.flatten(), + neighbours[..., 0].flatten(), + neighbours[..., 1].flatten(), + ].reshape(b, h, w, K**2) + return local_cov + + def reshape(self, x): + return rearrange(x, "b d h w -> b (h w) d") + + def project_to_basis(self, x): + if self.basis == "fourier": + return torch.cos(8 * math.pi * self.pos_conv(x)) + elif self.basis == "linear": + return self.pos_conv(x) + else: + raise ValueError( + "No other bases other than fourier and linear currently supported in public release" + ) + + def get_pos_enc(self, y): + b, c, h, w = y.shape + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device), + ) + ) + + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + coarse_embedded_coords = self.project_to_basis(coarse_coords) + return coarse_embedded_coords + + def forward(self, x, y, **kwargs): + b, c, h1, w1 = x.shape + b, c, h2, w2 = y.shape + f = self.get_pos_enc(y) + if self.predict_features: + f = f + y[:,:self.dim] # Stupid way to predict features + b, d, h2, w2 = f.shape + #assert x.shape == y.shape + x, y, f = self.reshape(x), self.reshape(y), self.reshape(f) + K_xx = self.K(x, x) + K_yy = self.K(y, y) + K_xy = self.K(x, y) + K_yx = K_xy.permute(0, 2, 1) + sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :] + # 
Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large + if len(K_yy[0]) > 2000: + K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)]) + else: + K_yy_inv = torch.linalg.inv(K_yy + sigma_noise) + + mu_x = K_xy.matmul(K_yy_inv.matmul(f)) + mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1) + if not self.no_cov: + cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx)) + cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1) + local_cov_x = self.get_local_cov(cov_x) + local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w") + gp_feats = torch.cat((mu_x, local_cov_x), dim=1) + else: + gp_feats = mu_x + return gp_feats + + +class Encoder(nn.Module): + def __init__(self, resnet): + super().__init__() + self.resnet = resnet + def forward(self, x): + x0 = x + b, c, h, w = x.shape + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x1 = self.resnet.relu(x) + + x = self.resnet.maxpool(x1) + x2 = self.resnet.layer1(x) + + x3 = self.resnet.layer2(x2) + + x4 = self.resnet.layer3(x3) + + x5 = self.resnet.layer4(x4) + feats = {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0} + return feats + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + + +class Decoder(nn.Module): + def __init__( + self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None, + ): + super().__init__() + self.embedding_decoder = embedding_decoder + self.gps = gps + self.proj = proj + self.conv_refiner = conv_refiner + self.detach = detach + if scales == "all": + self.scales = ["32", "16", "8", "4", "2", "1"] + else: + self.scales = scales + + def upsample_preds(self, flow, certainty, query, support): + b, hs, ws, d = flow.shape + b, c, h, w = query.shape + flow = flow.permute(0, 3, 1, 2) + certainty = F.interpolate( + certainty, size=(h, w), align_corners=False, mode="bilinear" + ) + flow = F.interpolate( + flow, size=(h, w), align_corners=False, mode="bilinear" + ) + delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow) + flow = torch.stack( + ( + flow[:, 0] + delta_flow[:, 0] / (4 * w), + flow[:, 1] + delta_flow[:, 1] / (4 * h), + ), + dim=1, + ) + flow = flow.permute(0, 2, 3, 1) + certainty = certainty + delta_certainty + return flow, certainty + + def get_placeholder_flow(self, b, h, w, device): + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + ) + ) + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + return coarse_coords + + + def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None): + coarse_scales = self.embedding_decoder.scales() + all_scales = self.scales if not upsample else ["8", "4", "2", "1"] + sizes = {scale: f1[scale].shape[-2:] for scale in f1} + h, w = sizes[1] + b = f1[1].shape[0] + device = f1[1].device + coarsest_scale = int(all_scales[0]) + old_stuff = torch.zeros( + b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device + ) + dense_corresps = {} + if not upsample: + dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device) + dense_certainty = 0.0 + else: + dense_flow = F.interpolate( + dense_flow, + size=sizes[coarsest_scale], + 
align_corners=False, + mode="bilinear", + ) + dense_certainty = F.interpolate( + dense_certainty, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + for new_scale in all_scales: + ins = int(new_scale) + f1_s, f2_s = f1[ins], f2[ins] + if new_scale in self.proj: + f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s) + b, c, hs, ws = f1_s.shape + if ins in coarse_scales: + old_stuff = F.interpolate( + old_stuff, size=sizes[ins], mode="bilinear", align_corners=False + ) + new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow) + dense_flow, dense_certainty, old_stuff = self.embedding_decoder( + new_stuff, f1_s, old_stuff, new_scale + ) + + if new_scale in self.conv_refiner: + delta_certainty, displacement = self.conv_refiner[new_scale]( + f1_s, f2_s, dense_flow + ) + dense_flow = torch.stack( + ( + dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w), + dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h), + ), + dim=1, + ) + dense_certainty = ( + dense_certainty + delta_certainty + ) # predict both certainty and displacement + + dense_corresps[ins] = { + "dense_flow": dense_flow, + "dense_certainty": dense_certainty, + } + + if new_scale != "1": + dense_flow = F.interpolate( + dense_flow, + size=sizes[ins // 2], + align_corners=False, + mode="bilinear", + ) + + dense_certainty = F.interpolate( + dense_certainty, + size=sizes[ins // 2], + align_corners=False, + mode="bilinear", + ) + if self.detach: + dense_flow = dense_flow.detach() + dense_certainty = dense_certainty.detach() + return dense_corresps + + +class RegressionMatcher(nn.Module): + def __init__( + self, + encoder, + decoder, + h=384, + w=512, + use_contrastive_loss = False, + alpha = 1, + beta = 0, + sample_mode = "threshold", + upsample_preds = True, + symmetric = False, + name = None, + use_soft_mutual_nearest_neighbours = False, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.w_resized = w + self.h_resized = h + self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True) + self.use_contrastive_loss = use_contrastive_loss + self.alpha = alpha + self.beta = beta + self.sample_mode = sample_mode + self.upsample_preds = upsample_preds + self.symmetric = symmetric + self.name = name + self.sample_thresh = 0.05 + self.upsample_res = (1152, 1536) + if use_soft_mutual_nearest_neighbours: + assert symmetric, "MNS requires symmetric inference" + self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours + + def extract_backbone_features(self, batch, batched = True, upsample = True): + #TODO: only extract stride [1,2,4,8] for upsample = True + x_q = batch["query"] + x_s = batch["support"] + if batched: + X = torch.cat((x_q, x_s)) + feature_pyramid = self.encoder(X) + else: + feature_pyramid = self.encoder(x_q), self.encoder(x_s) + return feature_pyramid + + def sample( + self, + dense_matches, + dense_certainty, + num=10000, + ): + if "threshold" in self.sample_mode: + upper_thresh = self.sample_thresh + dense_certainty = dense_certainty.clone() + dense_certainty_ = dense_certainty.clone() + dense_certainty[dense_certainty > upper_thresh] = 1 + elif "pow" in self.sample_mode: + dense_certainty = dense_certainty**(1/3) + elif "naive" in self.sample_mode: + dense_certainty = torch.ones_like(dense_certainty) + matches, certainty = ( + dense_matches.reshape(-1, 4), + dense_certainty.reshape(-1), + ) + certainty_ = dense_certainty_.reshape(-1) + expansion_factor = 4 if "balanced" in self.sample_mode else 1 + if not 
certainty.sum(): certainty = certainty + 1e-8 + good_samples = torch.multinomial(certainty, + num_samples = min(expansion_factor*num, len(certainty)), + replacement=False) + good_matches, good_certainty = matches[good_samples], certainty[good_samples] + good_certainty_ = certainty_[good_samples] + good_certainty = good_certainty_ + if "balanced" not in self.sample_mode: + return good_matches, good_certainty + + density = kde(good_matches, std=0.1, device=dense_matches.device) + p = 1 / (density+1) + p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones + balanced_samples = torch.multinomial(p, + num_samples = min(num,len(good_certainty)), + replacement=False) + return good_matches[balanced_samples], good_certainty[balanced_samples] + + def forward(self, batch, batched = True): + feature_pyramid = self.extract_backbone_features(batch, batched=batched) + if batched: + f_q_pyramid = { + scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items() + } + f_s_pyramid = { + scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items() + } + else: + f_q_pyramid, f_s_pyramid = feature_pyramid + dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid) + if self.training and self.use_contrastive_loss: + return dense_corresps, (f_q_pyramid, f_s_pyramid) + else: + return dense_corresps + + def forward_symmetric(self, batch, upsample = False, batched = True): + feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched) + f_q_pyramid = feature_pyramid + f_s_pyramid = { + scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0])) + for scale, f_scale in feature_pyramid.items() + } + dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {})) + return dense_corresps + + def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B): + kpts_A, kpts_B = matches[...,:2], matches[...,2:] + kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1) + kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1) + return kpts_A, kpts_B + + def match( + self, + im1_path, + im2_path, + *args, + batched=False, + ): + assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). 
You can turn off upsample_preds by model.upsample_preds = False " + symmetric = self.symmetric + self.train(False) + with torch.no_grad(): + if not batched: + b = 1 + ws = self.w_resized + hs = self.h_resized + query = F.interpolate(im1_path, size=(hs, ws), mode='bilinear', align_corners=False) + support = F.interpolate(im2_path, size=(hs, ws), mode='bilinear', align_corners=False) + batch = {"query": query, "support": support} + else: + b, c, h, w = im1_path.shape + b, c, h2, w2 = im2_path.shape + assert w == w2 and h == h2, "For batched images we assume same size" + batch = {"query": im1_path, "support": im2_path} + hs, ws = self.h_resized, self.w_resized + finest_scale = 1 + # Run matcher + if symmetric: + dense_corresps = self.forward_symmetric(batch, batched = True) + else: + dense_corresps = self.forward(batch, batched = True) + + if self.upsample_preds: + hs, ws = self.upsample_res + low_res_certainty = F.interpolate( + dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + ) + cert_clamp = 0 + factor = 0.5 + low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp) + + if self.upsample_preds: + query = F.interpolate(im1_path, size=(hs, ws), mode='bilinear', align_corners=False) + support = F.interpolate(im2_path, size=(hs, ws), mode='bilinear', align_corners=False) + batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]} + if symmetric: + dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True) + else: + dense_corresps = self.forward(batch, batched = True, upsample=True) + query_to_support = dense_corresps[finest_scale]["dense_flow"] + dense_certainty = dense_corresps[finest_scale]["dense_certainty"] + + # Get certainty interpolation + dense_certainty = dense_certainty - low_res_certainty + query_to_support = query_to_support.permute( + 0, 2, 3, 1 + ) + # Create im1 meshgrid + query_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=im1_path.device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=im1_path.device), + ) + ) + query_coords = torch.stack((query_coords[1], query_coords[0])) + query_coords = query_coords[None].expand(b, 2, hs, ws) + dense_certainty = dense_certainty.sigmoid() # logits -> probs + query_coords = query_coords.permute(0, 2, 3, 1) + if (query_to_support.abs() > 1).sum() > 0 and True: + wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0 + dense_certainty[wrong[:,None]] = 0 + # remove black pixels + black_mask1 = (im1_path[0, 0] < 0.03125) & (im1_path[0, 1] < 0.03125) & (im1_path[0, 2] < 0.03125) + black_mask2 = (im2_path[0, 0] < 0.03125) & (im2_path[0, 1] < 0.03125) & (im2_path[0, 2] < 0.03125) + black_mask1 = F.interpolate(black_mask1.float()[None, None], size=tuple(dense_certainty.shape[-2:]), mode='nearest').bool() + black_mask2 = F.interpolate(black_mask2.float()[None, None], size=tuple(dense_certainty.shape[-2:]), mode='nearest').bool() + black_mask = torch.cat((black_mask1, black_mask2), dim=0) + dense_certainty[black_mask] = 0 + + query_to_support = torch.clamp(query_to_support, -1, 1) + if symmetric: + support_coords = query_coords + qts, stq = query_to_support.chunk(2) + q_warp = torch.cat((query_coords, qts), dim=-1) + s_warp = torch.cat((stq, support_coords), dim=-1) + warp = torch.cat((q_warp, s_warp),dim=2) + dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0] + else: + warp = torch.cat((query_coords, query_to_support), dim=-1) + if batched: + return ( + warp, + dense_certainty + ) + 
else: + return ( + warp[0], + dense_certainty[0], + ) diff --git a/imcui/third_party/gim/networks/dkm/models/encoders.py b/imcui/third_party/gim/networks/dkm/models/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..6515823a6a7b724fb309850925d42a2389d08f3e --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/models/encoders.py @@ -0,0 +1,148 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as tvm + +class ResNet18(nn.Module): + def __init__(self, pretrained=False) -> None: + super().__init__() + self.net = tvm.resnet18(pretrained=pretrained) + def forward(self, x): + self = self.net + x1 = x + x = self.conv1(x1) + x = self.bn1(x) + x2 = self.relu(x) + x = self.maxpool(x2) + x4 = self.layer1(x) + x8 = self.layer2(x4) + x16 = self.layer3(x8) + x32 = self.layer4(x16) + return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1} + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + +class ResNet50(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None: + super().__init__() + if dilation is None: + dilation = [False,False,False] + if anti_aliased: + pass + else: + if weights is not None: + self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation) + else: + self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation) + + del self.net.fc + self.high_res = high_res + self.freeze_bn = freeze_bn + def forward(self, x): + net = self.net + feats = {1:x} + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x + x = net.layer2(x) + feats[8] = x + x = net.layer3(x) + feats[16] = x + x = net.layer4(x) + feats[32] = x + return feats + + def train(self, mode=True): + super().train(mode) + if self.freeze_bn: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + + + + +class ResNet101(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None) -> None: + super().__init__() + if weights is not None: + self.net = tvm.resnet101(weights = weights) + else: + self.net = tvm.resnet101(pretrained=pretrained) + self.high_res = high_res + self.scale_factor = 1 if not high_res else 1.5 + def forward(self, x): + net = self.net + feats = {1:x} + sf = self.scale_factor + if self.high_res: + x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer2(x) + feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer3(x) + feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer4(x) + feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + return feats + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + + +class WideResNet50(nn.Module): + def __init__(self, pretrained=False, high_res = False, 
weights = None) -> None: + super().__init__() + if weights is not None: + self.net = tvm.wide_resnet50_2(weights = weights) + else: + self.net = tvm.wide_resnet50_2(pretrained=pretrained) + self.high_res = high_res + self.scale_factor = 1 if not high_res else 1.5 + def forward(self, x): + net = self.net + feats = {1:x} + sf = self.scale_factor + if self.high_res: + x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic") + x = net.conv1(x) + x = net.bn1(x) + x = net.relu(x) + feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.maxpool(x) + x = net.layer1(x) + feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer2(x) + feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer3(x) + feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + x = net.layer4(x) + feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear") + return feats + + def train(self, mode=True): + super().train(mode) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass \ No newline at end of file diff --git a/imcui/third_party/gim/networks/dkm/models/model_zoo/DKMv3.py b/imcui/third_party/gim/networks/dkm/models/model_zoo/DKMv3.py new file mode 100644 index 0000000000000000000000000000000000000000..ab527fa25c2fd39f755398a7d891e45e39fc8774 --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/models/model_zoo/DKMv3.py @@ -0,0 +1,145 @@ +from networks.dkm.models.dkm import * +from networks.dkm.models.encoders import * + + +def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", **kwargs): + gp_dim = 256 + dfn_dim = 384 + feat_dim = 256 + coordinate_decoder = DFN( + internal_dim=dfn_dim, + feat_input_modules=nn.ModuleDict( + { + "32": nn.Conv2d(512, feat_dim, 1, 1), + "16": nn.Conv2d(512, feat_dim, 1, 1), + } + ), + pred_input_modules=nn.ModuleDict( + { + "32": nn.Identity(), + "16": nn.Identity(), + } + ), + rrb_d_dict=nn.ModuleDict( + { + "32": RRB(gp_dim + feat_dim, dfn_dim), + "16": RRB(gp_dim + feat_dim, dfn_dim), + } + ), + cab_dict=nn.ModuleDict( + { + "32": CAB(2 * dfn_dim, dfn_dim), + "16": CAB(2 * dfn_dim, dfn_dim), + } + ), + rrb_u_dict=nn.ModuleDict( + { + "32": RRB(dfn_dim, dfn_dim), + "16": RRB(dfn_dim, dfn_dim), + } + ), + terminal_module=nn.ModuleDict( + { + "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0), + } + ), + ) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512+128+(2*7+1)**2, + 2 * 512+128+(2*7+1)**2, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + local_corr_radius = 7, + corr_in_other = True, + ), + "8": ConvRefiner( + 2 * 512+64+(2*3+1)**2, + 2 * 512+64+(2*3+1)**2, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + local_corr_radius = 3, + corr_in_other = True, + ), + "4": ConvRefiner( + 2 * 256+32+(2*2+1)**2, + 2 * 256+32+(2*2+1)**2, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + local_corr_radius = 2, + corr_in_other = True, + ), 
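+            # The two finest scales ("2" and "1") refine from features and
+            # the displacement embedding only; no local correlation volume
+            # is used there, hence the smaller input dimensions.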
+ "2": ConvRefiner( + 2 * 64+16, + 128+16, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=16, + ), + "1": ConvRefiner( + 2 * 3+6, + 24, + 3, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=6, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp32 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"32": gp32, "16": gp16}) + proj = nn.ModuleDict( + {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)} + ) + decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True) + + encoder = ResNet50(pretrained = False, high_res = False, freeze_bn=False) + matcher = RegressionMatcher(encoder, decoder, h=h, w=w, name = "DKMv3", sample_mode=sample_mode, symmetric = symmetric, **kwargs) + # res = matcher.load_state_dict(weights) + return matcher diff --git a/imcui/third_party/gim/networks/dkm/models/model_zoo/__init__.py b/imcui/third_party/gim/networks/dkm/models/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..532b94f7487a6a5a55e429a59261882416a16cfc --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/models/model_zoo/__init__.py @@ -0,0 +1,39 @@ +weight_urls = { + "DKMv3": { + "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth", + "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth", + }, +} +import torch +from .DKMv3 import DKMv3 + + +def DKMv3_outdoor(path_to_weights = None, device=None): + """ + Loads DKMv3 outdoor weights, uses internal resolution of (540, 720) by default + resolution can be changed by setting model.h_resized, model.w_resized later. + Additionally upsamples preds to fixed resolution of (864, 1152), + can be turned off by model.upsample_preds = False + """ + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if path_to_weights is not None: + weights = torch.load(path_to_weights, map_location=device) + else: + weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["outdoor"], + map_location=device) + return DKMv3(weights, 540, 720, upsample_preds = True, device=device) + +def DKMv3_indoor(path_to_weights = None, device=None): + """ + Loads DKMv3 indoor weights, uses internal resolution of (480, 640) by default + Resolution can be changed by setting model.h_resized, model.w_resized later. 
+ """ + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if path_to_weights is not None: + weights = torch.load(path_to_weights, map_location=device) + else: + weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["indoor"], + map_location=device) + return DKMv3(weights, 480, 640, upsample_preds = False, device=device) diff --git a/imcui/third_party/gim/networks/dkm/utils/__init__.py b/imcui/third_party/gim/networks/dkm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05367ac9521664992f587738caa231f32ae2e81c --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/utils/__init__.py @@ -0,0 +1,13 @@ +from .utils import ( + pose_auc, + get_pose, + compute_relative_pose, + compute_pose_error, + estimate_pose, + rotate_intrinsic, + get_tuple_transform_ops, + get_depth_tuple_transform_ops, + warp_kpts, + numpy_to_pil, + tensor_to_pil, +) diff --git a/imcui/third_party/gim/networks/dkm/utils/kde.py b/imcui/third_party/gim/networks/dkm/utils/kde.py new file mode 100644 index 0000000000000000000000000000000000000000..fa392455e70fda4c9c77c28bda76bcb7ef9045b0 --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/utils/kde.py @@ -0,0 +1,26 @@ +import torch +import torch.nn.functional as F +import numpy as np + +def fast_kde(x, std = 0.1, kernel_size = 9, dilation = 3, padding = 9//2, stride = 1): + raise NotImplementedError("WIP, use at your own risk.") + # Note: when doing symmetric matching this might not be very exact, since we only check neighbours on the grid + x = x.permute(0,3,1,2) + B,C,H,W = x.shape + K = kernel_size ** 2 + unfolded_x = F.unfold(x,kernel_size=kernel_size, dilation = dilation, padding = padding, stride = stride).reshape(B, C, K, H, W) + scores = (-(unfolded_x - x[:,:,None]).sum(dim=1)**2/(2*std**2)).exp() + density = scores.sum(dim=1) + return density + + +def kde(x, std = 0.1, device=None): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + # use a gaussian kernel to estimate density + x = x.to(device) + scores = (-torch.cdist(x,x)**2/(2*std**2)).exp() + density = scores.sum(dim=-1) + return density diff --git a/imcui/third_party/gim/networks/dkm/utils/local_correlation.py b/imcui/third_party/gim/networks/dkm/utils/local_correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c1c06291d0b760376a2b2162bcf49d6eb1303c --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/utils/local_correlation.py @@ -0,0 +1,40 @@ +import torch +import torch.nn.functional as F + + +def local_correlation( + feature0, + feature1, + local_radius, + padding_mode="zeros", + flow = None +): + device = feature0.device + b, c, h, w = feature0.size() + if flow is None: + # If flow is None, assume feature0 and feature1 are aligned + coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + )) + coords = torch.stack((coords[1], coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + else: + coords = flow.permute(0,2,3,1) # If using flow, sample around flow target. 
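+    # Each output pixel is correlated against a (2r+1) x (2r+1) window of
+    # feature1 sampled around `coords`, giving a cost volume of shape
+    # (b, (2r+1)**2, h, w) scaled by 1/sqrt(c).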
+ r = local_radius + local_window = torch.meshgrid( + ( + torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=device), + torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=device), + )) + local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[ + None + ].expand(b, 2*r+1, 2*r+1, 2).reshape(b, (2*r+1)**2, 2) + coords = (coords[:,:,:,None]+local_window[:,None,None]).reshape(b,h,w*(2*r+1)**2,2) + window_feature = F.grid_sample( + feature1, coords, padding_mode=padding_mode, align_corners=False + )[...,None].reshape(b,c,h,w,(2*r+1)**2) + corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature)/(c**.5) + return corr diff --git a/imcui/third_party/gim/networks/dkm/utils/transforms.py b/imcui/third_party/gim/networks/dkm/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..754d853fda4cbcf89d2111bed4f44b0ca84f0518 --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/utils/transforms.py @@ -0,0 +1,104 @@ +from typing import Dict +import numpy as np +import torch +import kornia.augmentation as K +from kornia.geometry.transform import warp_perspective + +# Adapted from Kornia +class GeometricSequential: + def __init__(self, *transforms, align_corners=True) -> None: + self.transforms = transforms + self.align_corners = align_corners + + def __call__(self, x, mode="bilinear"): + b, c, h, w = x.shape + M = torch.eye(3, device=x.device)[None].expand(b, 3, 3) + for t in self.transforms: + if np.random.rand() < t.p: + M = M.matmul( + t.compute_transformation(x, t.generate_parameters((b, c, h, w))) + ) + return ( + warp_perspective( + x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners + ), + M, + ) + + def apply_transform(self, x, M, mode="bilinear"): + b, c, h, w = x.shape + return warp_perspective( + x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode + ) + + +class RandomPerspective(K.RandomPerspective): + def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]: + distortion_scale = torch.as_tensor( + self.distortion_scale, device=self._device, dtype=self._dtype + ) + return self.random_perspective_generator( + batch_shape[0], + batch_shape[-2], + batch_shape[-1], + distortion_scale, + self.same_on_batch, + self.device, + self.dtype, + ) + + def random_perspective_generator( + self, + batch_size: int, + height: int, + width: int, + distortion_scale: torch.Tensor, + same_on_batch: bool = False, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ) -> Dict[str, torch.Tensor]: + r"""Get parameters for ``perspective`` for a random perspective transform. + + Args: + batch_size (int): the tensor batch size. + height (int) : height of the image. + width (int): width of the image. + distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1. + same_on_batch (bool): apply the same transformation across the batch. Default: False. + device (torch.device): the device on which the random numbers will be generated. Default: cpu. + dtype (torch.dtype): the data type of the generated random numbers. Default: float32. + + Returns: + params Dict[str, torch.Tensor]: parameters to be passed for transformation. + - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2). + - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2). + + Note: + The generated random numbers are not reproducible across different devices and dtypes. 
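+
+        Example:
+            An illustrative sketch; ``aug`` is assumed to be a ``RandomPerspective`` instance.
+
+            params = aug.random_perspective_generator(1, 480, 640, torch.tensor(0.5))
+            # params["start_points"] and params["end_points"] both have shape (1, 4, 2)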
+ """ + if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1): + raise AssertionError( + f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}." + ) + if not ( + type(height) is int and height > 0 and type(width) is int and width > 0 + ): + raise AssertionError( + f"'height' and 'width' must be integers. Got {height}, {width}." + ) + + start_points: torch.Tensor = torch.tensor( + [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]], + device=distortion_scale.device, + dtype=distortion_scale.dtype, + ).expand(batch_size, -1, -1) + + # generate random offset not larger than half of the image + fx = distortion_scale * width / 2 + fy = distortion_scale * height / 2 + + factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2) + offset = (torch.rand_like(start_points) - 0.5) * 2 + end_points = start_points + factor * offset + + return dict(start_points=start_points, end_points=end_points) diff --git a/imcui/third_party/gim/networks/dkm/utils/utils.py b/imcui/third_party/gim/networks/dkm/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed50774dcc690e5afbdf65a9c7e87bc0a6c4706 --- /dev/null +++ b/imcui/third_party/gim/networks/dkm/utils/utils.py @@ -0,0 +1,341 @@ +import numpy as np +import cv2 +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import torch.nn.functional as F +from PIL import Image + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py +# --- GEOMETRY --- +def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): + if len(kpts0) < 5: + return None + K0inv = np.linalg.inv(K0[:2,:2]) + K1inv = np.linalg.inv(K1[:2,:2]) + + kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T + kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T + + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC + ) + + ret = None + if E is not None: + best_num_inliers = 0 + + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + best_num_inliers = n + ret = (R, t, mask.ravel() > 0) + return ret + + +def rotate_intrinsic(K, n): + base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + rot = np.linalg.matrix_power(base_rot, n) + return rot @ K + + +def rotate_pose_inplane(i_T_w, rot): + rotation_matrices = [ + np.array( + [ + [np.cos(r), -np.sin(r), 0.0, 0.0], + [np.sin(r), np.cos(r), 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] + ] + return np.dot(rotation_matrices[rot], i_T_w) + + +def scale_intrinsics(K, scales): + scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0]) + return np.dot(scales, K) + + +def to_homogeneous(points): + return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) + + +def angle_error_mat(R1, R2): + cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 + cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds + return np.rad2deg(np.abs(np.arccos(cos))) + + +def angle_error_vec(v1, v2): + n = np.linalg.norm(v1) * np.linalg.norm(v2) + return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) + + +def compute_pose_error(T_0to1, R, t): + R_gt = T_0to1[:3, :3] + t_gt = T_0to1[:3, 3] + error_t = 
angle_error_vec(t.squeeze(), t_gt) + error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation + error_R = angle_error_mat(R, R_gt) + return error_t, error_R + + +def pose_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 1) / len(errors) + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.trapz(r, x=e) / t) + return aucs + + +# From Patch2Pix https://github.com/GrumpyZhou/patch2pix +def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR)) + return TupleCompose(ops) + + +def get_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize)) + if normalize: + ops.append(TupleToTensorScaled()) + # ops.append( + # TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + # ) # Imagenet mean/std + else: + if unscale: + ops.append(TupleToTensorUnscaled()) + else: + ops.append(TupleToTensorScaled()) + return TupleCompose(ops) + + +class ToTensorScaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" + + def __call__(self, im): + if not isinstance(im, torch.Tensor): + im = np.array(im, dtype=np.float32).transpose((2, 0, 1)) + im /= 255.0 + return torch.from_numpy(im) + else: + return im + + def __repr__(self): + return "ToTensorScaled(./255)" + + +class TupleToTensorScaled(object): + def __init__(self): + self.to_tensor = ToTensorScaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorScaled(./255)" + + +class ToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __call__(self, im): + return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1))) + + def __repr__(self): + return "ToTensorUnscaled()" + + +class TupleToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __init__(self): + self.to_tensor = ToTensorUnscaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorUnscaled()" + + +class TupleResize(object): + def __init__(self, size, mode=InterpolationMode.BICUBIC): + self.size = size + self.resize = transforms.Resize(size, mode) + + def __call__(self, im_tuple): + return [self.resize(im) for im in im_tuple] + + def __repr__(self): + return "TupleResize(size={})".format(self.size) + + +class TupleNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + self.normalize = transforms.Normalize(mean=mean, std=std) + + def __call__(self, im_tuple): + return [self.normalize(im) for im in im_tuple] + + def __repr__(self): + return "TupleNormalize(mean={}, std={})".format(self.mean, self.std) + + +class TupleCompose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, im_tuple): + for t in self.transforms: + im_tuple = t(im_tuple) + return im_tuple + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + 
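+
+# Minimal usage sketch of the tuple transforms above (assumes im0 and im1 are PIL RGB images):
+# composes a resize followed by to-tensor scaling and returns a list of two CHW float
+# tensors in [0, 1].
+def _example_tuple_transform_usage(im0, im1, resize=(480, 640)):
+    ops = get_tuple_transform_ops(resize=resize, normalize=True)
+    return ops((im0, im1))
+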
+@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here + Args: + kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1) + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + ( + n, + h, + w, + ) = depth0.shape + kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode="bilinear")[ + :, 0, :, 0 + ] + kpts0 = torch.stack( + (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + # Sample depth, get calculable_mask on depth != 0 + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) + kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + kpts0_cam = kpts0_n + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) + w_kpts0 = torch.stack( + (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 + ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] + # w_kpts0[~covisible_mask, :] = -5 # xd + + w_kpts0_depth = F.grid_sample( + depth1[:, None], w_kpts0[:, :, None], mode="bilinear" + )[:, 0, :, 0] + consistent_mask = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() < 0.05 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 + + +imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(device) +imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(device) + + +def numpy_to_pil(x: np.ndarray): + """ + Args: + x: Assumed to be of shape (h,w,c) + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if x.max() <= 1.01: + x *= 255 + x = x.astype(np.uint8) + return Image.fromarray(x) + + +def tensor_to_pil(x, unnormalize=False): + if unnormalize: + x = x * imagenet_std[:, None, None] + imagenet_mean[:, None, None] + x = x.detach().permute(1, 2, 0).cpu().numpy() + x = np.clip(x, 0.0, 1.0) + return numpy_to_pil(x) + + +def to_cuda(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(device) + return batch + + +def to_cpu(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cpu() + return batch + + +def get_pose(calib): + w, h = np.array(calib["imsize"])[0] + return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w + + +def compute_relative_pose(R1, t1, R2, t2): + rots = R2 @ (R1.T) + trans = -rots @ t1 + t2 + return rots, trans diff --git a/imcui/third_party/gim/networks/lightglue/__init__.py 
b/imcui/third_party/gim/networks/lightglue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd8fb3b1e3b54f80fdf70688fb4e4705305a723 --- /dev/null +++ b/imcui/third_party/gim/networks/lightglue/__init__.py @@ -0,0 +1,17 @@ +import logging + +# from .utils.experiments import load_experiment # noqa: F401 + +formatter = logging.Formatter( + fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S" +) +handler = logging.StreamHandler() +handler.setFormatter(formatter) +handler.setLevel(logging.INFO) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(handler) +logger.propagate = False + +__module_name__ = __name__ diff --git a/imcui/third_party/gim/networks/lightglue/matching.py b/imcui/third_party/gim/networks/lightglue/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..bf718592915d6ed96782543ae4586241815a1298 --- /dev/null +++ b/imcui/third_party/gim/networks/lightglue/matching.py @@ -0,0 +1,50 @@ +import torch + +from .superpoint import SuperPoint +from .models.matchers.lightglue import LightGlue + + +class Matching(torch.nn.Module): + """ Image Matching Frontend (SuperPoint + SuperGlue) """ + + # noinspection PyDefaultArgument + def __init__(self, config={}): + super().__init__() + self.detector = SuperPoint({ + 'max_num_keypoints': 2048, + 'force_num_keypoints': True, + 'detection_threshold': 0.0, + 'nms_radius': 3, + 'trainable': False, + }) + self.model = LightGlue({ + 'filter_threshold': 0.1, + 'flash': False, + 'checkpointed': True, + }) + + def forward(self, data): + """ Run SuperPoint (optionally) and SuperGlue + SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input + Args: + data: dictionary with minimal keys: ['image0', 'image1'] + """ + pred = {} + + pred.update({k + '0': v for k, v in self.detector({ + "image": data["gray0"], + "image_size": data["size0"], + }).items()}) + pred.update({k + '1': v for k, v in self.detector({ + "image": data["gray1"], + "image_size": data["size1"], + }).items()}) + + pred.update(self.model({ + **pred, **{ + 'resize0': data['size0'], + 'resize1': data['size1'] + } + })) + + return pred diff --git a/imcui/third_party/gim/networks/lightglue/models/__init__.py b/imcui/third_party/gim/networks/lightglue/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d1a05c66bbc22a711cb968be00985a31a3dfd5 --- /dev/null +++ b/imcui/third_party/gim/networks/lightglue/models/__init__.py @@ -0,0 +1,30 @@ +import importlib.util + +from ..utils.tools import get_class +from .base_model import BaseModel + + +def get_model(name): + import_paths = [ + name, + f"{__name__}.{name}", + f"{__name__}.extractors.{name}", # backward compatibility + f"{__name__}.matchers.{name}", # backward compatibility + ] + for path in import_paths: + try: + spec = importlib.util.find_spec(path) + except ModuleNotFoundError: + spec = None + if spec is not None: + try: + return get_class(path, BaseModel) + except AssertionError: + mod = __import__(path, fromlist=[""]) + try: + return mod.__main_model__ + except AttributeError as exc: + print(exc) + continue + + raise RuntimeError(f'Model {name} not found in any of [{" ".join(import_paths)}]') diff --git a/imcui/third_party/gim/networks/lightglue/models/base_model.py b/imcui/third_party/gim/networks/lightglue/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f66288b9f724468c4409171b9c374c794ae9c9 --- /dev/null +++ 
b/imcui/third_party/gim/networks/lightglue/models/base_model.py @@ -0,0 +1,157 @@ +""" +Base class for trainable models. +""" + +from abc import ABCMeta, abstractmethod +from copy import copy + +import omegaconf +from omegaconf import OmegaConf +from torch import nn + + +class MetaModel(ABCMeta): + def __prepare__(name, bases, **kwds): + total_conf = OmegaConf.create() + for base in bases: + for key in ("base_default_conf", "default_conf"): + update = getattr(base, key, {}) + if isinstance(update, dict): + update = OmegaConf.create(update) + total_conf = OmegaConf.merge(total_conf, update) + return dict(base_default_conf=total_conf) + + +class BaseModel(nn.Module, metaclass=MetaModel): + """ + What the child model is expect to declare: + default_conf: dictionary of the default configuration of the model. + It recursively updates the default_conf of all parent classes, and + it is updated by the user-provided configuration passed to __init__. + Configurations can be nested. + + required_data_keys: list of expected keys in the input data dictionary. + + strict_conf (optional): boolean. If false, BaseModel does not raise + an error when the user provides an unknown configuration entry. + + _init(self, conf): initialization method, where conf is the final + configuration object (also accessible with `self.conf`). Accessing + unknown configuration entries will raise an error. + + _forward(self, data): method that returns a dictionary of batched + prediction tensors based on a dictionary of batched input data tensors. + + loss(self, pred, data): method that returns a dictionary of losses, + computed from model predictions and input data. Each loss is a batch + of scalars, i.e. a torch.Tensor of shape (B,). + The total loss to be optimized has the key `'total'`. + + metrics(self, pred, data): method that returns a dictionary of metrics, + each as a batch of scalars. + """ + + default_conf = { + "name": None, + "trainable": True, # if false: do not optimize this model parameters + "freeze_batch_normalization": False, # use test-time statistics + "timeit": False, # time forward pass + } + required_data_keys = [] + strict_conf = False + + are_weights_initialized = False + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + default_conf = OmegaConf.merge( + self.base_default_conf, OmegaConf.create(self.default_conf) + ) + if self.strict_conf: + OmegaConf.set_struct(default_conf, True) + + # fixme: backward compatibility + if "pad" in conf and "pad" not in default_conf: # backward compat. 
+ with omegaconf.read_write(conf): + with omegaconf.open_dict(conf): + conf["interpolation"] = {"pad": conf.pop("pad")} + + if isinstance(conf, dict): + conf = OmegaConf.create(conf) + self.conf = conf = OmegaConf.merge(default_conf, conf) + OmegaConf.set_readonly(conf, True) + OmegaConf.set_struct(conf, True) + self.required_data_keys = copy(self.required_data_keys) + self._init(conf) + + if not conf.trainable: + for p in self.parameters(): + p.requires_grad = False + + def train(self, mode=True): + super().train(mode) + + def freeze_bn(module): + if isinstance(module, nn.modules.batchnorm._BatchNorm): + module.eval() + + if self.conf.freeze_batch_normalization: + self.apply(freeze_bn) + + return self + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + + def recursive_key_check(expected, given): + for key in expected: + assert key in given, f"Missing key {key} in data" + if isinstance(expected, dict): + recursive_key_check(expected[key], given[key]) + + recursive_key_check(self.required_data_keys, data) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def loss(self, pred, data): + """To be implemented by the child class.""" + raise NotImplementedError + + def load_state_dict(self, *args, **kwargs): + """Load the state dict of the model, and set the model to initialized.""" + ret = super().load_state_dict(*args, **kwargs) + self.set_initialized() + return ret + + def is_initialized(self): + """Recursively check if the model is initialized, i.e. weights are loaded""" + is_initialized = True # initialize to true and perform recursive and + for _, w in self.named_children(): + if isinstance(w, BaseModel): + # if children is BaseModel, we perform recursive check + is_initialized = is_initialized and w.is_initialized() + else: + # else, we check if self is initialized or the children has no params + n_params = len(list(w.parameters())) + is_initialized = is_initialized and ( + n_params == 0 or self.are_weights_initialized + ) + return is_initialized + + def set_initialized(self, to: bool = True): + """Recursively set the initialization state.""" + self.are_weights_initialized = to + for _, w in self.named_parameters(): + if isinstance(w, BaseModel): + w.set_initialized(to) diff --git a/imcui/third_party/gim/networks/lightglue/models/matchers/__init__.py b/imcui/third_party/gim/networks/lightglue/models/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/gim/networks/lightglue/models/matchers/lightglue.py b/imcui/third_party/gim/networks/lightglue/models/matchers/lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..364194e8a6829c124e9a1959b8c224cb9119f211 --- /dev/null +++ b/imcui/third_party/gim/networks/lightglue/models/matchers/lightglue.py @@ -0,0 +1,632 @@ +import warnings +from pathlib import Path +from typing import Callable, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf +from torch import nn +from torch.utils.checkpoint import checkpoint + +# from ...settings import DATA_PATH +# from ..utils.losses import NLLLoss +# from ..utils.metrics import matcher_metrics + +FLASH_AVAILABLE = hasattr(F, 
"scaled_dot_product_attention") + +torch.backends.cudnn.deterministic = True + + +@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) +def normalize_keypoints( + kpts: torch.Tensor, size: Optional[torch.Tensor] = None +) -> torch.Tensor: + if size is None: + size = 1 + kpts.max(-2).values - kpts.min(-2).values + elif not isinstance(size, torch.Tensor): + size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype) + size = size.to(kpts) + shift = size / 2 + scale = size.max(-1).values / 2 + kpts = (kpts - shift[..., None, :]) / scale[..., None, None] + return kpts + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """encode position vector""" + projected = self.Wr(x) + cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class TokenConfidence(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid()) + self.loss_fn = nn.BCEWithLogitsLoss(reduction="none") + + def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): + """get confidence tokens""" + return ( + self.token(desc0.detach()).squeeze(-1), + self.token(desc1.detach()).squeeze(-1), + ) + + def loss(self, desc0, desc1, la_now, la_final): + logit0 = self.token[0](desc0.detach()).squeeze(-1) + logit1 = self.token[0](desc1.detach()).squeeze(-1) + la_now, la_final = la_now.detach(), la_final.detach() + correct0 = ( + la_final[:, :-1, :].max(-1).indices == la_now[:, :-1, :].max(-1).indices + ) + correct1 = ( + la_final[:, :, :-1].max(-2).indices == la_now[:, :, :-1].max(-2).indices + ) + return ( + self.loss_fn(logit0, correct0.float()).mean(-1) + + self.loss_fn(logit1, correct1.float()).mean(-1) + ) / 2.0 + + +class Attention(nn.Module): + def __init__(self, allow_flash: bool) -> None: + super().__init__() + if allow_flash and not FLASH_AVAILABLE: + warnings.warn( + "FlashAttention is not available. 
For optimal speed, " + "consider installing torch >= 2.0 or flash-attn.", + stacklevel=2, + ) + self.enable_flash = allow_flash and FLASH_AVAILABLE + + if FLASH_AVAILABLE: + torch.backends.cuda.enable_flash_sdp(allow_flash) + + def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.enable_flash and q.device.type == "cuda": + # use torch 2.0 scaled_dot_product_attention with flash + if FLASH_AVAILABLE: + args = [x.half().contiguous() for x in [q, k, v]] + v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype) + return v if mask is None else v.nan_to_num() + elif FLASH_AVAILABLE: + args = [x.contiguous() for x in [q, k, v]] + v = F.scaled_dot_product_attention(*args, attn_mask=mask) + return v if mask is None else v.nan_to_num() + else: + s = q.shape[-1] ** -0.5 + sim = torch.einsum("...id,...jd->...ij", q, k) * s + if mask is not None: + sim.masked_fill(~mask, -float("inf")) + attn = F.softmax(sim, -1) + return torch.einsum("...ij,...jd->...id", attn, v) + + +class SelfBlock(nn.Module): + def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0 + self.head_dim = self.embed_dim // num_heads + self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) + self.inner_attn = Attention(flash) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.ffn = nn.Sequential( + nn.Linear(2 * embed_dim, 2 * embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(2 * embed_dim, embed_dim), + ) + + def forward( + self, + x: torch.Tensor, + encoding: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv = self.Wqkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + context = self.inner_attn(q, k, v, mask=mask) + message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2)) + return x + self.ffn(torch.cat([x, message], -1)) + + +class CrossBlock(nn.Module): + def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: + super().__init__() + self.heads = num_heads + dim_head = embed_dim // num_heads + self.scale = dim_head**-0.5 + inner_dim = dim_head * num_heads + self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias) + self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias) + self.ffn = nn.Sequential( + nn.Linear(2 * embed_dim, 2 * embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(2 * embed_dim, embed_dim), + ) + if flash and FLASH_AVAILABLE: + self.flash = Attention(True) + else: + self.flash = None + + def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): + return func(x0), func(x1) + + def forward( + self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> List[torch.Tensor]: + qk0, qk1 = self.map_(self.to_qk, x0, x1) + v0, v1 = self.map_(self.to_v, x0, x1) + qk0, qk1, v0, v1 = map( + lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2), + (qk0, qk1, v0, v1), + ) + if self.flash is not None and qk0.device.type == "cuda": + m0 = self.flash(qk0, qk1, v1, mask) + m1 = self.flash( + qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None + ) + else: + 
qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5 + sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1) + if mask is not None: + sim = sim.masked_fill(~mask, -float("inf")) + attn01 = F.softmax(sim, dim=-1) + attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) + m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1) + m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0) + if mask is not None: + m0, m1 = m0.nan_to_num(), m1.nan_to_num() + m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1) + m0, m1 = self.map_(self.to_out, m0, m1) + x0 = x0 + self.ffn(torch.cat([x0, m0], -1)) + x1 = x1 + self.ffn(torch.cat([x1, m1], -1)) + return x0, x1 + + +class TransformerLayer(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.self_attn = SelfBlock(*args, **kwargs) + self.cross_attn = CrossBlock(*args, **kwargs) + + def forward( + self, + desc0, + desc1, + encoding0, + encoding1, + mask0: Optional[torch.Tensor] = None, + mask1: Optional[torch.Tensor] = None, + ): + if mask0 is not None and mask1 is not None: + return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1) + else: + desc0 = self.self_attn(desc0, encoding0) + desc1 = self.self_attn(desc1, encoding1) + return self.cross_attn(desc0, desc1) + + # This part is compiled and allows padding inputs + def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1): + mask = mask0 & mask1.transpose(-1, -2) + mask0 = mask0 & mask0.transpose(-1, -2) + mask1 = mask1 & mask1.transpose(-1, -2) + desc0 = self.self_attn(desc0, encoding0, mask0) + desc1 = self.self_attn(desc1, encoding1, mask1) + return self.cross_attn(desc0, desc1, mask) + + +def sigmoid_log_double_softmax( + sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + b, m, n = sim.shape + certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2) + scores0 = F.log_softmax(sim, 2) + scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = sim.new_full((b, m + 1, n + 1), 0) + scores[:, :m, :n] = scores0 + scores1 + certainties + scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1)) + scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1)) + return scores + + +class MatchAssignment(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + self.matchability = nn.Linear(dim, 1, bias=True) + self.final_proj = nn.Linear(dim, dim, bias=True) + + def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): + """build assignment matrix from descriptors""" + mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) + _, _, d = mdesc0.shape + mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25 + sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1) + z0 = self.matchability(desc0) + z1 = self.matchability(desc1) + scores = sigmoid_log_double_softmax(sim, z0, z1) + return scores, sim + + def get_matchability(self, desc: torch.Tensor): + return torch.sigmoid(self.matchability(desc)).squeeze(-1) + + +def filter_matches(scores: torch.Tensor, th: float): + """obtain matches from a log assignment matrix [Bx M+1 x N+1]""" + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + m0, m1 = max0.indices, max1.indices + indices0 = torch.arange(m0.shape[1], device=m0.device)[None] + indices1 = torch.arange(m1.shape[1], device=m1.device)[None] + mutual0 = indices0 == m1.gather(1, m0) + mutual1 = indices1 == m0.gather(1, m1) + max0_exp = 
max0.values.exp() + zero = max0_exp.new_tensor(0) + mscores0 = torch.where(mutual0, max0_exp, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero) + valid0 = mutual0 & (mscores0 > th) + valid1 = mutual1 & valid0.gather(1, m1) + m0 = torch.where(valid0, m0, -1) + m1 = torch.where(valid1, m1, -1) + return m0, m1, mscores0, mscores1 + + +class LightGlue(nn.Module): + default_conf = { + "name": "lightglue", # just for interfacing + "input_dim": 256, # input descriptor dimension (autoselected from weights) + "add_scale_ori": False, + "descriptor_dim": 256, + "n_layers": 9, + "num_heads": 4, + "flash": False, # enable FlashAttention if available. + "mp": False, # enable mixed precision + "depth_confidence": -1, # early stopping, disable with -1 + "width_confidence": -1, # point pruning, disable with -1 + "filter_threshold": 0.0, # match threshold + "checkpointed": False, + "weights": "superpoint_lightglue", # either a path or the name of pretrained weights (disk, ...) + "weights_from_version": "v0.1_arxiv", + "loss": { + "gamma": 1.0, + "fn": "nll", + "nll_balancing": 0.5, + }, + } + + required_data_keys = ["keypoints0", "keypoints1", "descriptors0", "descriptors1"] + + url = "https://github.com/cvg/LightGlue/releases/download/{}/{}.pth" + + def __init__(self, conf) -> None: + super().__init__() + self.conf = conf = OmegaConf.merge(self.default_conf, conf) + if conf.input_dim != conf.descriptor_dim: + self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True) + else: + self.input_proj = nn.Identity() + + head_dim = conf.descriptor_dim // conf.num_heads + self.posenc = LearnableFourierPositionalEncoding( + 2 + 2 * conf.add_scale_ori, head_dim, head_dim + ) + + h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim + + self.transformers = nn.ModuleList( + [TransformerLayer(d, h, conf.flash) for _ in range(n)] + ) + + self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)]) + self.token_confidence = nn.ModuleList( + [TokenConfidence(d) for _ in range(n - 1)] + ) + + # self.loss_fn = NLLLoss(conf.loss) + + # state_dict = None + # if conf.weights is not None: + # # weights can be either a path or an existing file from official LG + # if Path(conf.weights).exists(): + # state_dict = torch.load(conf.weights, map_location="cpu") + # elif (Path(DATA_PATH) / conf.weights).exists(): + # state_dict = torch.load( + # str(DATA_PATH / conf.weights), map_location="cpu" + # ) + # elif (Path('weights') / (conf.weights + '.pth')).exists(): + # state_dict = torch.load( + # str(Path('weights') / (conf.weights + '.pth')), map_location="cpu" + # ) + # print(f"Readed weights from {Path('weights') / (conf.weights + '.pth')}") + # else: + # fname = ( + # f"{conf.weights}_{conf.weights_from_version}".replace(".", "-") + # + ".pth" + # ) + # state_dict = torch.hub.load_state_dict_from_url( + # self.url.format(conf.weights_from_version, conf.weights), + # file_name=fname, + # ) + # + # if state_dict: + # # rename old state dict entries + # for i in range(self.conf.n_layers): + # pattern = f"self_attn.{i}", f"transformers.{i}.self_attn" + # state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + # pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn" + # state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + # self.load_state_dict(state_dict, strict=False) + # print(f"Loaded weights from {conf.weights}") + + def compile(self, mode="reduce-overhead"): + if self.conf.width_confidence != -1: + warnings.warn( + "Point pruning is 
partially disabled for compiled forward.", + stacklevel=2, + ) + + for i in range(self.conf.n_layers): + self.transformers[i] = torch.compile( + self.transformers[i], mode=mode, fullgraph=True + ) + + def forward(self, data: dict) -> dict: + for key in self.required_data_keys: + assert key in data, f"Missing key {key} in data" + + kpts0, kpts1 = data["keypoints0"], data["keypoints1"] + b, m, _ = kpts0.shape + b, n, _ = kpts1.shape + device = kpts0.device + # if "view0" in data.keys() and "view1" in data.keys(): + size0 = data["image_size0"][:, [1, 0]] if "image_size0" in data.keys() else data["resize0"][:, [1, 0]] + size1 = data["image_size1"][:, [1, 0]] if "image_size1" in data.keys() else data["resize1"][:, [1, 0]] + kpts0 = normalize_keypoints(kpts0, size0).clone() + kpts1 = normalize_keypoints(kpts1, size1).clone() + + if self.conf.add_scale_ori: + sc0, o0 = data["scales0"], data["oris0"] + sc1, o1 = data["scales1"], data["oris1"] + kpts0 = torch.cat( + [ + kpts0, + sc0 if sc0.dim() == 3 else sc0[..., None], + o0 if o0.dim() == 3 else o0[..., None], + ], + -1, + ) + kpts1 = torch.cat( + [ + kpts1, + sc1 if sc1.dim() == 3 else sc1[..., None], + o1 if o1.dim() == 3 else o1[..., None], + ], + -1, + ) + + desc0 = data["descriptors0"].contiguous() + desc1 = data["descriptors1"].contiguous() + + assert desc0.shape[-1] == self.conf.input_dim + assert desc1.shape[-1] == self.conf.input_dim + if torch.is_autocast_enabled(): + desc0 = desc0.half() + desc1 = desc1.half() + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + # cache positional embeddings + encoding0 = self.posenc(kpts0) + encoding1 = self.posenc(kpts1) + + # GNN + final_proj + assignment + do_early_stop = self.conf.depth_confidence > 0 and not self.training + do_point_pruning = self.conf.width_confidence > 0 and not self.training + + all_desc0, all_desc1 = [], [] + + if do_point_pruning: + ind0 = torch.arange(0, m, device=device)[None] + ind1 = torch.arange(0, n, device=device)[None] + # We store the index of the layer at which pruning is detected. 
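+            # prune0/prune1 start at 1 and are incremented once per layer a keypoint
+            # survives, so after the loop they record how long each point was kept.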
+ prune0 = torch.ones_like(ind0) + prune1 = torch.ones_like(ind1) + token0, token1 = None, None + for i in range(self.conf.n_layers): + if self.conf.checkpointed and self.training: + desc0, desc1 = checkpoint( + self.transformers[i], desc0, desc1, encoding0, encoding1 + ) + else: + desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1) + if self.training or i == self.conf.n_layers - 1: + all_desc0.append(desc0) + all_desc1.append(desc1) + continue # no early stopping or adaptive width at last layer + + # only for eval + if do_early_stop: + assert b == 1 + token0, token1 = self.token_confidence[i](desc0, desc1) + if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n): + break + if do_point_pruning: + assert b == 1 + scores0 = self.log_assignment[i].get_matchability(desc0) + prunemask0 = self.get_pruning_mask(token0, scores0, i) + keep0 = torch.where(prunemask0)[1] + ind0 = ind0.index_select(1, keep0) + desc0 = desc0.index_select(1, keep0) + encoding0 = encoding0.index_select(-2, keep0) + prune0[:, ind0] += 1 + scores1 = self.log_assignment[i].get_matchability(desc1) + prunemask1 = self.get_pruning_mask(token1, scores1, i) + keep1 = torch.where(prunemask1)[1] + ind1 = ind1.index_select(1, keep1) + desc1 = desc1.index_select(1, keep1) + encoding1 = encoding1.index_select(-2, keep1) + prune1[:, ind1] += 1 + + desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] + scores, _ = self.log_assignment[i](desc0, desc1) + m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold) + matches, mscores = [], [] + for k in range(b): + if self.training: break + valid = m0[k] > -1 + m_indices_0 = torch.where(valid)[0] + m_indices_1 = m0[k][valid] + if do_point_pruning: + m_indices_0 = ind0[k, m_indices_0] + m_indices_1 = ind1[k, m_indices_1] + matches.append(torch.stack([m_indices_0, m_indices_1], -1)) + mscores.append(mscores0[k][valid]) + + if do_point_pruning: + m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype) + m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype) + m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0))) + m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0))) + mscores0_ = torch.zeros((b, m), device=mscores0.device) + mscores1_ = torch.zeros((b, n), device=mscores1.device) + mscores0_[:, ind0] = mscores0 + mscores1_[:, ind1] = mscores1 + m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_ + else: + prune0 = torch.ones_like(mscores0) * self.conf.n_layers + prune1 = torch.ones_like(mscores1) * self.conf.n_layers + + pred = { + "matches0": m0, + "matches1": m1, + "matching_scores0": mscores0, + "matching_scores1": mscores1, + "ref_descriptors0": torch.stack(all_desc0, 1), + "ref_descriptors1": torch.stack(all_desc1, 1), + "log_assignment": scores, + "stop": i + 1, + "matches": matches, + "scores": mscores, + "prune0": prune0, + "prune1": prune1, + } + + return pred + + def confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers) + return np.clip(threshold, 0, 1) + + def get_pruning_mask( + self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int + ) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.conf.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. 
+ keep |= confidences <= self.confidence_thresholds[layer_index] + return keep + + def check_if_stop( + self, + confidences0: torch.Tensor, + confidences1: torch.Tensor, + layer_index: int, + num_points: int, + ) -> torch.Tensor: + """evaluate stopping condition""" + confidences = torch.cat([confidences0, confidences1], -1) + threshold = self.confidence_thresholds[layer_index] + ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points + return ratio_confident > self.conf.depth_confidence + + def pruning_min_kpts(self, device: torch.device): + if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda": + return self.pruning_keypoint_thresholds["flash"] + else: + return self.pruning_keypoint_thresholds[device.type] + + def loss(self, pred, data): + def loss_params(pred, i): + la, _ = self.log_assignment[i]( + pred["ref_descriptors0"][:, i], pred["ref_descriptors1"][:, i] + ) + return { + "log_assignment": la, + } + + sum_weights = 1.0 + nll, gt_weights, loss_metrics = self.loss_fn(loss_params(pred, -1), data) + N = pred["ref_descriptors0"].shape[1] + losses = {"total": nll, "last": nll.clone().detach(), **loss_metrics} + + if self.training: + losses["confidence"] = 0.0 + + # B = pred['log_assignment'].shape[0] + losses["row_norm"] = pred["log_assignment"].exp()[:, :-1].sum(2).mean(1) + for i in range(N - 1): + params_i = loss_params(pred, i) + nll, _, _ = self.loss_fn(params_i, data, weights=gt_weights) + + if self.conf.loss.gamma > 0.0: + weight = self.conf.loss.gamma ** (N - i - 1) + else: + weight = i + 1 + sum_weights += weight + losses["total"] = losses["total"] + nll * weight + + losses["confidence"] += self.token_confidence[i].loss( + pred["ref_descriptors0"][:, i], + pred["ref_descriptors1"][:, i], + params_i["log_assignment"], + pred["log_assignment"], + ) / (N - 1) + + del params_i + losses["total"] /= sum_weights + + # confidences + if self.training: + losses["total"] = losses["total"] + losses["confidence"] + + if not self.training: + # add metrics + metrics = matcher_metrics(pred, data) + else: + metrics = {} + return losses, metrics + + +__main_model__ = LightGlue diff --git a/imcui/third_party/gim/networks/lightglue/models/utils/__init__.py b/imcui/third_party/gim/networks/lightglue/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/gim/networks/lightglue/models/utils/misc.py b/imcui/third_party/gim/networks/lightglue/models/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e86d1add0e23a042963d878e484f0c582ff8b41c --- /dev/null +++ b/imcui/third_party/gim/networks/lightglue/models/utils/misc.py @@ -0,0 +1,70 @@ +import math +from typing import List, Optional, Tuple + +import torch + + +def to_sequence(map): + return map.flatten(-2).transpose(-1, -2) + + +def to_map(sequence): + n = sequence.shape[-2] + e = math.isqrt(n) + assert e * e == n + assert e * e == n + sequence.transpose(-1, -2).unflatten(-1, [e, e]) + + +def pad_to_length( + x, + length: int, + pad_dim: int = -2, + mode: str = "zeros", # zeros, ones, random, random_c + bounds: Tuple[int] = (None, None), +): + shape = list(x.shape) + d = x.shape[pad_dim] + assert d <= length + if d == length: + return x + shape[pad_dim] = length - d + + low, high = bounds + + if mode == "zeros": + xn = torch.zeros(*shape, device=x.device, dtype=x.dtype) + elif mode == "ones": + xn = torch.ones(*shape, device=x.device, dtype=x.dtype) + elif mode == "random": + low 
= low if low is not None else x.min() + high = high if high is not None else x.max() + xn = torch.empty(*shape, device=x.device).uniform_(low, high) + elif mode == "random_c": + low, high = bounds # we use the bounds as fallback for empty seq. + xn = torch.cat( + [ + torch.empty(*shape[:-1], 1, device=x.device).uniform_( + x[..., i].min() if d > 0 else low, + x[..., i].max() if d > 0 else high, + ) + for i in range(shape[-1]) + ], + dim=-1, + ) + else: + raise ValueError(mode) + return torch.cat([x, xn], dim=pad_dim) + + +def pad_and_stack( + sequences: List[torch.Tensor], + length: Optional[int] = None, + pad_dim: int = -2, + **kwargs, +): + if length is None: + length = max([x.shape[pad_dim] for x in sequences]) + + y = torch.stack([pad_to_length(x, length, pad_dim, **kwargs) for x in sequences], 0) + return y diff --git a/imcui/third_party/gim/networks/lightglue/superpoint.py b/imcui/third_party/gim/networks/lightglue/superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8e93591f5d64f345b07545e91108c110256a6f32 --- /dev/null +++ b/imcui/third_party/gim/networks/lightglue/superpoint.py @@ -0,0 +1,360 @@ +""" +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. +# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +Described in: + SuperPoint: Self-Supervised Interest Point Detection and Description, + Daniel DeTone, Tomasz Malisiewicz, Andrew Rabinovich, CVPRW 2018. 
+ +Original code: github.com/MagicLeapResearch/SuperPointPretrainedNetwork + +Adapted by Philipp Lindenberger (Phil26AT) +""" +import os.path + +import torch +from torch import nn + +from networks.lightglue.models.base_model import BaseModel +from networks.lightglue.models.utils.misc import pad_and_stack + + +def simple_nms(scores, radius): + """Perform non maximum suppression on the heatmap using max-pooling. + This method does not suppress contiguous points that have the same score. + Args: + scores: the score heatmap of size `(B, H, W)`. + radius: an integer scalar, the radius of the NMS window. + """ + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=radius * 2 + 1, stride=1, padding=radius + ) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def top_k_keypoints(keypoints, scores, k): + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0, sorted=True) + return keypoints[indices], scores + + +def sample_k_keypoints(keypoints, scores, k): + if k >= len(keypoints): + return keypoints, scores + indices = torch.multinomial(scores, k, replacement=False) + return keypoints[indices], scores[indices] + + +def soft_argmax_refinement(keypoints, scores, radius: int): + width = 2 * radius + 1 + sum_ = torch.nn.functional.avg_pool2d( + scores[:, None], width, 1, radius, divisor_override=1 + ) + ar = torch.arange(-radius, radius + 1).to(scores) + kernel_x = ar[None].expand(width, -1)[None, None] + dx = torch.nn.functional.conv2d(scores[:, None], kernel_x, padding=radius) + dy = torch.nn.functional.conv2d( + scores[:, None], kernel_x.transpose(2, 3), padding=radius + ) + dydx = torch.stack([dy[:, 0], dx[:, 0]], -1) / sum_[:, 0, :, :, None] + refined_keypoints = [] + for i, kpts in enumerate(keypoints): + delta = dydx[i][tuple(kpts.t())] + refined_keypoints.append(kpts.float() + delta) + return refined_keypoints + + +# Legacy (broken) sampling of the descriptors +def sample_descriptors(keypoints, descriptors, s): + b, c, h, w = descriptors.shape + keypoints = keypoints - s / 2 + 0.5 + keypoints /= torch.tensor( + [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to( + keypoints + )[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + args = {"align_corners": True} if torch.__version__ >= "1.3" else {} + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args + ) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1 + ) + return descriptors + + +# The original keypoint sampling is incorrect. We patch it here but +# keep the original one above for legacy. 
+def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8): + """Interpolate descriptors at keypoint locations""" + b, c, h, w = descriptors.shape + keypoints = keypoints / (keypoints.new_tensor([w, h]) * s) + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False + ) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1 + ) + return descriptors + + +class SuperPoint(BaseModel): + default_conf = { + "has_detector": True, + "has_descriptor": True, + "descriptor_dim": 256, + # Inference + "sparse_outputs": True, + "dense_outputs": False, + "nms_radius": 4, + "refinement_radius": 0, + "detection_threshold": 0.005, + "max_num_keypoints": -1, + "max_num_keypoints_val": None, + "force_num_keypoints": False, + "randomize_keypoints_training": False, + "remove_borders": 4, + "legacy_sampling": True, # True to use the old broken sampling + } + required_data_keys = ["image"] + detection_noise = 2.0 + + # checkpoint_url = "https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superpoint_v1.pth" # noqa: E501 + + def _init(self, conf): + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 + + self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) + self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) + + if conf.has_detector: + self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) + for param in self.convPa.parameters(): + param.requires_grad = False + for param in self.convPb.parameters(): + param.requires_grad = False + + if conf.has_descriptor: + self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convDb = nn.Conv2d( + c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0 + ) + + # self.load_state_dict(torch.load(os.path.join('weights', 'superpoint_v1.pth'))) + + def _forward(self, data): + image = data["image"] + data["image_size"] = torch.tensor(image.shape[-2:][::-1])[None] + if image.shape[1] == 3: # RGB + scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) + image = (image * scale).sum(1, keepdim=True) + + # Shared Encoder + x = self.relu(self.conv1a(image)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + pred = {} + if self.conf.has_detector: + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + b, c, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + pred["keypoint_scores"] = dense_scores = scores + if 
self.conf.has_descriptor: + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + dense_desc = self.convDb(cDa) + dense_desc = torch.nn.functional.normalize(dense_desc, p=2, dim=1) + pred["descriptors"] = dense_desc + + if self.conf.sparse_outputs: + assert self.conf.has_detector and self.conf.has_descriptor + + scores = simple_nms(scores, self.conf.nms_radius) + + # Discard keypoints near the image borders + if self.conf.remove_borders: + scores[:, : self.conf.remove_borders] = -1 + scores[:, :, : self.conf.remove_borders] = -1 + if "image_size" in data: + for i in range(scores.shape[0]): + w, h = data["image_size"][i] + scores[i, int(h.item()) - self.conf.remove_borders :] = -1 + scores[i, :, int(w.item()) - self.conf.remove_borders :] = -1 + else: + scores[:, -self.conf.remove_borders :] = -1 + scores[:, :, -self.conf.remove_borders :] = -1 + + # Extract keypoints + best_kp = torch.where(scores > self.conf.detection_threshold) + scores = scores[best_kp] + + # Separate into batches + keypoints = [ + torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b) + ] + scores = [scores[best_kp[0] == i] for i in range(b)] + + # Keep the k keypoints with highest score + max_kps = self.conf.max_num_keypoints + + # for val we allow different + if not self.training and self.conf.max_num_keypoints_val is not None: + max_kps = self.conf.max_num_keypoints_val + + # Keep the k keypoints with highest score + if max_kps > 0: + if self.conf.randomize_keypoints_training and self.training: + # instead of selecting top-k, sample k by score weights + keypoints, scores = list( + zip( + *[ + sample_k_keypoints(k, s, max_kps) + for k, s in zip(keypoints, scores) + ] + ) + ) + else: + keypoints, scores = list( + zip( + *[ + top_k_keypoints(k, s, max_kps) + for k, s in zip(keypoints, scores) + ] + ) + ) + keypoints, scores = list(keypoints), list(scores) + + if self.conf["refinement_radius"] > 0: + keypoints = soft_argmax_refinement( + keypoints, dense_scores, self.conf["refinement_radius"] + ) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + if self.conf.force_num_keypoints: + keypoints = pad_and_stack( + keypoints, + max_kps, + -2, + mode="random_c", + bounds=( + 0, + data.get("image_size", torch.tensor(image.shape[-2:])) + .min() + .item(), + ), + ) + scores = pad_and_stack(scores, max_kps, -1, mode="zeros") + else: + keypoints = torch.stack(keypoints, 0) + scores = torch.stack(scores, 0) + + # Extract descriptors + if (len(keypoints) == 1) or self.conf.force_num_keypoints: + # Batch sampling of the descriptors + if self.conf.legacy_sampling: + desc = sample_descriptors(keypoints, dense_desc, 8) + else: + desc = sample_descriptors_fix_sampling(keypoints, dense_desc, 8) + else: + if self.conf.legacy_sampling: + desc = [ + sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, dense_desc) + ] + else: + desc = [ + sample_descriptors_fix_sampling(k[None], d[None], 8)[0] + for k, d in zip(keypoints, dense_desc) + ] + + pred = { + "keypoints": keypoints + 0.5, + "descriptors": desc.transpose(-1, -2), + } + + if self.conf.dense_outputs: + pred["dense_descriptors"] = dense_desc + + return pred + + def loss(self, pred, data): + raise NotImplementedError + + def metrics(self, pred, data): + raise NotImplementedError diff --git a/imcui/third_party/gim/networks/lightglue/utils/__init__.py b/imcui/third_party/gim/networks/lightglue/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 --- /dev/null +++ b/imcui/third_party/gim/networks/lightglue/utils/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun diff --git a/imcui/third_party/gim/networks/lightglue/utils/tools.py b/imcui/third_party/gim/networks/lightglue/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..6a27f4a491e1675557b992401208bbe4c355edd2 --- /dev/null +++ b/imcui/third_party/gim/networks/lightglue/utils/tools.py @@ -0,0 +1,269 @@ +""" +Various handy Python and PyTorch utils. + +Author: Paul-Edouard Sarlin (skydes) +""" + +import os +import random +import time +from collections.abc import Iterable +from contextlib import contextmanager + +import numpy as np +import torch + + +class AverageMetric: + def __init__(self): + self._sum = 0 + self._num_examples = 0 + + def update(self, tensor): + assert tensor.dim() == 1 + tensor = tensor[~torch.isnan(tensor)] + self._sum += tensor.sum().item() + self._num_examples += len(tensor) + + def compute(self): + if self._num_examples == 0: + return np.nan + else: + return self._sum / self._num_examples + + +# same as AverageMetric, but tracks all elements +class FAverageMetric: + def __init__(self): + self._sum = 0 + self._num_examples = 0 + self._elements = [] + + def update(self, tensor): + self._elements += tensor.cpu().numpy().tolist() + assert tensor.dim() == 1 + tensor = tensor[~torch.isnan(tensor)] + self._sum += tensor.sum().item() + self._num_examples += len(tensor) + + def compute(self): + if self._num_examples == 0: + return np.nan + else: + return self._sum / self._num_examples + + +class MedianMetric: + def __init__(self): + self._elements = [] + + def update(self, tensor): + assert tensor.dim() == 1 + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if len(self._elements) == 0: + return np.nan + else: + return np.nanmedian(self._elements) + + +class PRMetric: + def __init__(self): + self.labels = [] + self.predictions = [] + + @torch.no_grad() + def update(self, labels, predictions, mask=None): + assert labels.shape == predictions.shape + self.labels += ( + (labels[mask] if mask is not None else labels).cpu().numpy().tolist() + ) + self.predictions += ( + (predictions[mask] if mask is not None else predictions) + .cpu() + .numpy() + .tolist() + ) + + @torch.no_grad() + def compute(self): + return np.array(self.labels), np.array(self.predictions) + + def reset(self): + self.labels = [] + self.predictions = [] + + +class QuantileMetric: + def __init__(self, q=0.05): + self._elements = [] + self.q = q + + def update(self, tensor): + assert tensor.dim() == 1 + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if len(self._elements) == 0: + return np.nan + else: + return np.nanquantile(self._elements, self.q) + + +class RecallMetric: + def __init__(self, ths, elements=[]): + self._elements = elements + self.ths = ths + + def update(self, tensor): + assert tensor.dim() == 1 + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if isinstance(self.ths, Iterable): + return [self.compute_(th) for th in self.ths] + else: + return self.compute_(self.ths[0]) + + def compute_(self, th): + if len(self._elements) == 0: + return np.nan + else: + s = (np.array(self._elements) < th).sum() + return s / len(self._elements) + + +def cal_error_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 
1) / len(errors) + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.round((np.trapz(r, x=e) / t), 4)) + return aucs + + +class AUCMetric: + def __init__(self, thresholds, elements=None): + self._elements = elements + self.thresholds = thresholds + if not isinstance(thresholds, list): + self.thresholds = [thresholds] + + def update(self, tensor): + assert tensor.dim() == 1 + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if len(self._elements) == 0: + return np.nan + else: + return cal_error_auc(self._elements, self.thresholds) + + +class Timer(object): + """A simpler timer context object. + Usage: + ``` + > with Timer('mytimer'): + > # some computations + [mytimer] Elapsed: X + ``` + """ + + def __init__(self, name=None): + self.name = name + + def __enter__(self): + self.tstart = time.time() + return self + + def __exit__(self, type, value, traceback): + self.duration = time.time() - self.tstart + if self.name is not None: + print("[%s] Elapsed: %s" % (self.name, self.duration)) + + +def get_class(mod_path, BaseClass): + """Get the class object which inherits from BaseClass and is defined in + the module named mod_name, child of base_path. + """ + import inspect + + mod = __import__(mod_path, fromlist=[""]) + classes = inspect.getmembers(mod, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == mod_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseClass)] + assert len(classes) == 1, classes + return classes[0][1] + + +def set_num_threads(nt): + """Force numpy and other libraries to use a limited number of threads.""" + try: + import mkl + except ImportError: + pass + else: + mkl.set_num_threads(nt) + torch.set_num_threads(1) + os.environ["IPC_ENABLE"] = "1" + for o in [ + "OPENBLAS_NUM_THREADS", + "NUMEXPR_NUM_THREADS", + "OMP_NUM_THREADS", + "MKL_NUM_THREADS", + ]: + os.environ[o] = str(nt) + + +def set_seed(seed): + random.seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_random_state(with_cuda): + pth_state = torch.get_rng_state() + np_state = np.random.get_state() + py_state = random.getstate() + if torch.cuda.is_available() and with_cuda: + cuda_state = torch.cuda.get_rng_state_all() + else: + cuda_state = None + return pth_state, np_state, py_state, cuda_state + + +def set_random_state(state): + pth_state, np_state, py_state, cuda_state = state + torch.set_rng_state(pth_state) + np.random.set_state(np_state) + random.setstate(py_state) + if ( + cuda_state is not None + and torch.cuda.is_available() + and len(cuda_state) == torch.cuda.device_count() + ): + torch.cuda.set_rng_state_all(cuda_state) + + +@contextmanager +def fork_rng(seed=None, with_cuda=True): + state = get_random_state(with_cuda) + if seed is not None: + set_seed(seed) + try: + yield + finally: + set_random_state(state) diff --git a/imcui/third_party/gim/networks/loftr/__init__.py b/imcui/third_party/gim/networks/loftr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun 
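Note on the utilities added in `lightglue/utils/tools.py` above: the metric classes accumulate per-pair values with `update()` and reduce them with `compute()`, while `fork_rng` temporarily isolates the Python/NumPy/PyTorch RNG state. Below is a minimal, hypothetical usage sketch (not part of this diff): the import path and the dummy error values are assumptions, and `AUCMetric` is constructed with an explicit `elements=[]` because its default of `None` cannot be extended by `update()`.

```python
# Hypothetical usage sketch for the helpers defined in tools.py above.
# The import path and the dummy error values are assumptions for illustration.
import torch

from imcui.third_party.gim.networks.lightglue.utils.tools import (
    AUCMetric,
    AverageMetric,
    Timer,
    fork_rng,
)

# Explicit elements=[] is required: the default (None) cannot be extended by update().
auc = AUCMetric(thresholds=[5, 10, 20], elements=[])  # example thresholds only
mean_err = AverageMetric()

with Timer("dummy-eval"):  # prints "[dummy-eval] Elapsed: ..." on exit
    for _ in range(3):
        with fork_rng(seed=0):  # seed inside the block, restore RNG state afterwards
            errors = torch.rand(8) * 30.0  # stand-in for per-pair pose errors (1-D tensor)
        auc.update(errors)       # both metrics expect 1-D tensors
        mean_err.update(errors)

print("AUC@5/10/20:", auc.compute())   # list of three AUC values in [0, 1]
print("mean error:", mean_err.compute())
```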
diff --git a/imcui/third_party/gim/networks/loftr/backbone/__init__.py b/imcui/third_party/gim/networks/loftr/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1040aba694eeda5828ac7232e52a87ead0179a94 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/backbone/__init__.py @@ -0,0 +1,11 @@ +from .resnet import ResNetFPN_8_2 + + +def build_backbone(config): + if config['backbone_type'] == 'ResNetFPN': + if config['resolution'] == (8, 2): + return ResNetFPN_8_2(config['resnetfpn']) + elif config['resolution'] == (16, 4): + return ResNetFPN_16_4(config['resnetfpn']) + else: + raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") diff --git a/imcui/third_party/gim/networks/loftr/backbone/resnet.py b/imcui/third_party/gim/networks/loftr/backbone/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..526e38f10853ee6255f342b0faf57b67ab30a3f4 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/backbone/resnet.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from typing import Type, Callable, Union, List, Optional + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
+ + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + # self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + # dilate=replace_stride_with_dilation[2]) + # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + # self.fc = nn.Linear(512 * block.expansion, num_classes) + # + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + # nn.init.constant_(m.weight, 1) + # nn.init.constant_(m.bias, 0) + # + # # Zero-initialize the last BN in each residual branch, + # # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + # if zero_init_residual: + # for m in self.modules(): + # if isinstance(m, Bottleneck): + # nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + # elif isinstance(m, BasicBlock): + # nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)] + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + # x = self.conv1(x) # (2, 64, 320, 320) + # x = self.bn1(x) # (2, 64, 320, 320) + # x1 = self.relu(x) # (2, 64, 320, 320) + # x2 = self.maxpool(x1) # (2, 64, 160, 160) + + # x2 = self.layer1(x1) # (2, 64, 160, 160) + # x3 = self.layer2(x2) # (2, 128, 80, 80) + # x4 = self.layer3(x3) # (2, 256, 40, 40) + # x = self.layer4(x) # (2, 512, 20, 20) + + # x = self.avgpool(x) # (2, 512, 1, 1) + # x = torch.flatten(x, 1) # (2, 512) + # x = self.fc(x) # (2, 1000) + + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + return x1, x2, x3 + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('layer4.'): state_dict.pop(k) + if k.startswith('fc.'): state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) + + +class ResNetFPN_8_2(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + # initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + # self.in_planes = initial_dim + + # Networks + # self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + # self.bn1 = nn.BatchNorm2d(initial_dim) + # self.relu = nn.ReLU(inplace=True) + + # self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + # self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + # self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + self.encode = ResNet(Bottleneck, [3, 4, 6, 3]) # resnet50 + + # 3. 
FPN upsample + self.layer3_outconv = conv1x1(block_dims[5], block_dims[3]) + self.layer2_outconv = conv1x1(block_dims[4], block_dims[3]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[3], block_dims[3]), + nn.BatchNorm2d(block_dims[3]), + nn.LeakyReLU(), + conv3x3(block_dims[3], block_dims[2]), + ) + self.layer1_outconv = conv1x1(block_dims[3], block_dims[2]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + # x0 = self.relu(self.bn1(self.conv1(x))) + # x1 = self.layer1(x0) # 1/2 + # x2 = self.layer2(x1) # 1/4 + # x3 = self.layer3(x2) # 1/8 + + # x1: (2, 64, 320, 320) + # x2: (2, 128, 160, 160) + # x3: (2, 256, 80, 80) + x1, x2, x3 = self.encode(x) + + # FPN + x3_out = self.layer3_outconv(x3) # (2, 256, 80, 80) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) # (2, 256, 160, 160) + x2_out = self.layer2_outconv(x2) # (2, 256, 160, 160) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) # (2, 196, 160, 160) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) # (2, 196, 320, 320) + x1_out = self.layer1_outconv(x1) # (2, 196, 320, 320) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] + + +if __name__ == '__main__': + # Original form + # config = dict(initial_dim=128, block_dims=[128, 196, 256]) + # model = ResNetFPN_8_2(config) + # # output (list): + # # 0: (2, 256, 80, 80) + # # 1: (2, 128, 320, 320) + # output = model(torch.randn(2, 1, 640, 640)) + + # model = ResNet(BasicBlock, [2, 2, 2, 2]) + # # weights = torch.load('resnet18(5c106cde).ckpt', map_location='cpu') + # # model.load_state_dict(weights) + # output = model(torch.randn(2, 3, 640, 640)) + + config = dict(initial_dim=128, block_dims=[64, 128, 196, 256]) + model = ResNetFPN_8_2(config) + # output (list): + # 0: (2, 256, 80, 80) + # 1: (2, 128, 320, 320) + output = model(torch.randn(2, 3, 640, 640)) diff --git a/imcui/third_party/gim/networks/loftr/backbone/resnet_fpn.py b/imcui/third_party/gim/networks/loftr/backbone/resnet_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..18e4caf34f065aa46e05913fdccb9a93403148fc --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/backbone/resnet_fpn.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.conv2 = conv3x3(planes, planes) + self.bn1 = 
nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + conv1x1(in_planes, planes, stride=stride), + nn.BatchNorm2d(planes) + ) + + def forward(self, x): + y = x + y = self.relu(self.bn1(self.conv1(y))) + y = self.bn2(self.conv2(y)) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResNetFPN_8_2(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. FPN upsample + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + x3_out = self.layer3_outconv(x3) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] + + +class ResNetFPN_16_4(nn.Module): + """ + ResNet+FPN, output resolution are 1/16 and 1/4. + Each block has 2 layers. 
+ """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(3, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 + + # 3. FPN upsample + self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) + self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) + self.layer3_outconv2 = nn.Sequential( + conv3x3(block_dims[3], block_dims[3]), + nn.BatchNorm2d(block_dims[3]), + nn.LeakyReLU(), + conv3x3(block_dims[3], block_dims[2]), + ) + + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + x4 = self.layer4(x3) # 1/16 + + # FPN + x4_out = self.layer4_outconv(x4) + + x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out = self.layer3_outconv(x3) + x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + return [x4_out, x2_out] diff --git a/imcui/third_party/gim/networks/loftr/config.py b/imcui/third_party/gim/networks/loftr/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ee03c9b6fea8430318b9be64fa15bd2268a04704 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/config.py @@ -0,0 +1,77 @@ +from yacs.config import CfgNode as CN + +_CN = CN() +_CN.TEMP_BUG_FIX = True + +############## ↓ LoFTR Pipeline ↓ ############## +_CN.LOFTR = CN() +_CN.LOFTR.WEIGHT = None + +############## ↓ LoFTR Pipeline ↓ ############## +_CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN' +_CN.LOFTR.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] +_CN.LOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.LOFTR.FINE_CONCAT_COARSE_FEAT = False + +# 1. LoFTR-backbone (local feature CNN) config +_CN.LOFTR.RESNETFPN = CN() +_CN.LOFTR.RESNETFPN.INITIAL_DIM = 128 +_CN.LOFTR.RESNETFPN.BLOCK_DIMS = [64, 128, 196, 256, 512, 1024] # s1, s2, s3 + +# 2. LoFTR-coarse module config +_CN.LOFTR.COARSE = CN() +_CN.LOFTR.COARSE.D_MODEL = 256 +_CN.LOFTR.COARSE.NHEAD = 8 +_CN.LOFTR.COARSE.LAYER_NAMES = 4 +_CN.LOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] + +# 3. 
Coarse-Matching config +_CN.LOFTR.MATCH_COARSE = CN() +_CN.LOFTR.MATCH_COARSE.THR = 0.2 +_CN.LOFTR.MATCH_COARSE.BORDER_RM = 2 +_CN.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] +_CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3 +_CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 +_CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False +_CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory +_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock +_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = False + +# 4. LoFTR-fine module config +_CN.LOFTR.FINE = CN() +_CN.LOFTR.FINE.D_MODEL = 128 +_CN.LOFTR.FINE.NHEAD = 8 +_CN.LOFTR.FINE.LAYER_NAMES = 1 +_CN.LOFTR.FINE.ATTENTION = 'linear' + +# 5. LoFTR Losses +# -- # coarse-level +_CN.LOFTR.LOSS = CN() +_CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy'] +_CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0 +# _CN.LOFTR.LOSS.SPARSE_SPVS = False +# -- - -- # focal loss (coarse) +_CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25 +_CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0 +_CN.LOFTR.LOSS.POS_WEIGHT = 1.0 +_CN.LOFTR.LOSS.NEG_WEIGHT = 1.0 +# _CN.LOFTR.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not. +# use `_CN.LOFTR.MATCH_COARSE.MATCH_TYPE` + +# -- # fine-level +_CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2'] +_CN.LOFTR.LOSS.FINE_WEIGHT = 1.0 +_CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window) + +# Overlap +_CN.LOFTR.LOSS.OVERLAP_WEIGHT = 20.0 +_CN.LOFTR.LOSS.OVERLAP_FOCAL_ALPHA = 0.25 +_CN.LOFTR.LOSS.OVERLAP_FOCAL_GAMMA = 5.0 + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _CN.clone() diff --git a/imcui/third_party/gim/networks/loftr/configs/__init__.py b/imcui/third_party/gim/networks/loftr/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/configs/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun diff --git a/imcui/third_party/gim/networks/loftr/configs/outdoor/__init__.py b/imcui/third_party/gim/networks/loftr/configs/outdoor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..846998ecc3b1961b957a1a98afa9f1a899079ee4 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/configs/outdoor/__init__.py @@ -0,0 +1,12 @@ +from networks.loftr.config import get_cfg_defaults as get_network_cfg +from trainer.config import get_cfg_defaults as get_trainer_cfg + +# network +network_cfg = get_network_cfg() +network_cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 + +# optimizer +trainer_cfg = get_trainer_cfg() +trainer_cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +trainer_cfg.TRAINER.WARMUP_RATIO = 0.1 +trainer_cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] diff --git a/imcui/third_party/gim/networks/loftr/loftr.py b/imcui/third_party/gim/networks/loftr/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe35a373581d1e1d08bd28271e6d598c4759273 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/loftr.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange + +from .backbone import build_backbone +from .utils.position_encoding import 
PositionEncodingSine +from .submodules import LocalFeatureTransformer, FinePreprocess +import warnings +from .utils.coarse_matching import CoarseMatching +warnings.simplefilter("ignore", UserWarning) +from .utils.fine_matching import FineMatching + + +class LoFTR(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + + # Modules + self.backbone = build_backbone(config) + self.pos_encoding = PositionEncodingSine( + config['coarse']['d_model'], + temp_bug_fix=False) + self.loftr_coarse = LocalFeatureTransformer(config['coarse']) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_preprocess = FinePreprocess(config) + self.loftr_fine = LocalFeatureTransformer(config["fine"]) + self.fine_matching = FineMatching() + + """ + outdoor_ds.ckpt: {OrderedDict: 211} + backbone: {OrderedDict: 107} + loftr_coarse: {OrderedDict: 80} + loftr_fine: {OrderedDict: 20} + fine_preprocess: {OrderedDict: 4} + """ + if config['weight'] is not None: + weights = torch.load(config['weight'], map_location='cpu') + self.load_state_dict(weights) + # print(config['weight'] + ' load success.') + + def forward(self, data): + """ + Update: + data (dict): { + 'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + # 1. Local Feature CNN + data.update({ + 'bs': data['image0'].size(0), + 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] + }) + + if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + feats_c, feats_f = self.backbone(torch.cat([data['color0'], data['color1']], dim=0)) # h == h0 == h1, w == w0 == w1feats_c: (bs*2, 256, h//8, w//8), feats_f: (bs*2, 128, h//2, w//2) + (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs']) # feat_c0, feat_c1: (bs, 256, h//8, w//8), feat_f0, feat_f1: (bs, 128, h//2, w//2) + else: # handle different input shapes + (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['color0']), self.backbone(data['color1']) + + data.update({ + 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], + 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] + }) + + # 2. coarse-level loftr module + b, c, h0, w0 = feat_c0.size() + _, _, h1, w1 = feat_c1.size() + # add featmap with positional encoding, then flatten it to sequence [N, HW, C] + feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c') + feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c') + + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) + feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) + + # 3. match coarse-level + self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) + + # 4. fine-level refinement + feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data) + if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted + feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold) + + # 5. 
match fine-level + self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) diff --git a/imcui/third_party/gim/networks/loftr/misc.py b/imcui/third_party/gim/networks/loftr/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..61cd57bf1e4e5aacab58e42e9277a4ad12990dc9 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/misc.py @@ -0,0 +1,100 @@ +import os +import contextlib +import joblib +from typing import Union +from loguru import _Logger, logger +from itertools import chain + +import torch +from yacs.config import CfgNode as CN +from pytorch_lightning.utilities import rank_zero_only + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +def upper_config(dict_cfg): + if not isinstance(dict_cfg, dict): + return dict_cfg + return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} + + +def log_on(condition, message, level): + if condition: + assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] + logger.log(level, message) + + +def get_rank_zero_only_logger(logger: _Logger): + if rank_zero_only.rank == 0: + return logger + else: + for _level in logger._core.levels.keys(): + level = _level.lower() + setattr(logger, level, + lambda x: None) + logger._log = lambda x: None + return logger + + +def setup_gpus(gpus: Union[str, int]) -> int: + """ A temporary fix for pytorch-lighting 1.3.x """ + gpus = str(gpus) + gpu_ids = [] + + if ',' not in gpus: + n_gpus = int(gpus) + return n_gpus if n_gpus != -1 else torch.cuda.device_count() + else: + gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] + + # setup environment variables + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + if visible_devices is None: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') + else: + logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') + return len(gpu_ids) + + +def flattenList(x): + return list(chain(*x)) + + +@contextlib.contextmanager +def tqdm_joblib(tqdm_object): + """Context manager to patch joblib to report into tqdm progress bar given as argument + + Usage: + with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: + Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) + + When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) + ret_vals = Parallel(n_jobs=args.world_size)( + delayed(lambda x: _compute_cov_score(pid, *x))(param) + for param in tqdm(combinations(image_ids, 2), + desc=f'Computing cov_score of [{pid}]', + total=len(image_ids)*(len(image_ids)-1)/2)) + Src: https://stackoverflow.com/a/58936697 + """ + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + 
return super().__call__(*args, **kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() diff --git a/imcui/third_party/gim/networks/loftr/submodules/__init__.py b/imcui/third_party/gim/networks/loftr/submodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca51db4f50a0c4f3dcd795e74b83e633ab2e990a --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/submodules/__init__.py @@ -0,0 +1,2 @@ +from .transformer import LocalFeatureTransformer +from .fine_preprocess import FinePreprocess diff --git a/imcui/third_party/gim/networks/loftr/submodules/attentions.py b/imcui/third_party/gim/networks/loftr/submodules/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..b73c5a6a6a722a44c0b68f70cb77c0988b8a5fb3 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/submodules/attentions.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) + + # Compute the attention and the weighted average + softmax_temp = 1. 
/ queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() diff --git a/imcui/third_party/gim/networks/loftr/submodules/fine_preprocess.py b/imcui/third_party/gim/networks/loftr/submodules/fine_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb8eefd362240a9901a335f0e6e07770ff04567 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/submodules/fine_preprocess.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange, repeat + + +class FinePreprocess(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.cat_c_feat = config['fine_concat_coarse_feat'] + self.W = self.config['fine_window_size'] + + d_model_c = self.config['coarse']['d_model'] + d_model_f = self.config['fine']['d_model'] + self.d_model_f = d_model_f + if self.cat_c_feat: + self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) + self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") + + def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): + W = self.W + stride = data['hw0_f'][0] // data['hw0_c'][0] + + data.update({'W': W}) + if data['b_ids'].shape[0] == 0: + feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + return feat0, feat1 + + # 1. unfold(crop) all local windows + feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + + # 2. 
select only the predicted matches + feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + + # option: use coarse-level loftr feature as context: concat and linear + if self.cat_c_feat: + feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], + feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] + feat_cf_win = self.merge_feat(torch.cat([ + torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] + repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] + ], -1)) + feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) + + return feat_f0_unfold, feat_f1_unfold diff --git a/imcui/third_party/gim/networks/loftr/submodules/transformer.py b/imcui/third_party/gim/networks/loftr/submodules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e70cafddc912901a04d2491bf6f9e9dbaaf4e793 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/submodules/transformer.py @@ -0,0 +1,101 @@ +import copy +import torch +import torch.nn as nn +from .attentions import LinearAttention, FullAttention + + +class LoFTREncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + attention='linear'): + super(LoFTREncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = ['self', 'cross'] * config['layer_names'] + encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, 
mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" + + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + return feat0, feat1 diff --git a/imcui/third_party/gim/networks/loftr/utils/__init__.py b/imcui/third_party/gim/networks/loftr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/utils/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun diff --git a/imcui/third_party/gim/networks/loftr/utils/coarse_matching.py b/imcui/third_party/gim/networks/loftr/utils/coarse_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..8f225ed3dcb6becd229302ece53d8cc8b43e42f0 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/utils/coarse_matching.py @@ -0,0 +1,259 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange + +INF = 1e9 + + +def mask_border(m, b: int, v): + """ Mask borders with value + Args: + m (torch.Tensor): [N, H0, W0, H1, W1] + b (int) + v (m.dtype) + """ + if b <= 0: + return + + m[:, :b] = v + m[:, :, :b] = v + m[:, :, :, :b] = v + m[:, :, :, :, :b] = v + m[:, -b:] = v + m[:, :, -b:] = v + m[:, :, :, -b:] = v + m[:, :, :, :, -b:] = v + + +def mask_border_with_padding(m, bd, v, p_m0, p_m1): + if bd <= 0: + return + + m[:, :bd] = v + m[:, :, :bd] = v + m[:, :, :, :bd] = v + m[:, :, :, :, :bd] = v + + h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() + h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() + for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): + m[b_idx, h0 - bd:] = v + m[b_idx, :, w0 - bd:] = v + m[b_idx, :, :, h1 - bd:] = v + m[b_idx, :, :, :, w1 - bd:] = v + + +def compute_max_candidates(p_m0, p_m1): + """Compute the max candidates of all pairs within a batch + + Args: + p_m0, p_m1 (torch.Tensor): padded masks + """ + h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] + h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] + max_cand = torch.sum( + torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + return max_cand + + +class CoarseMatching(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # general config + self.thr = config['thr'] + self.border_rm = config['border_rm'] + # -- # for trainig fine-level LoFTR + self.train_coarse_percent = config['train_coarse_percent'] + self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + + # we provide 2 options for differentiable matching + self.match_type = config['match_type'] + if self.match_type == 'dual_softmax': + self.temperature = config['dsmax_temperature'] + elif self.match_type == 'sinkhorn': + try: + from .superglue import log_optimal_transport + except ImportError: + raise ImportError("download superglue.py first!") + self.log_optimal_transport = log_optimal_transport + self.bin_score = nn.Parameter( + torch.tensor(config['skh_init_bin_score'], requires_grad=True)) + 
self.skh_iters = config['skh_iters'] + self.skh_prefilter = config['skh_prefilter'] + else: + raise NotImplementedError() + + def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): + """ + Args: + feat_c0 (torch.Tensor): [N, L, C] + feat_c1 (torch.Tensor): [N, S, C] + data (dict) + mask_c0 (torch.Tensor): [N, L] (optional) + mask_c1 (torch.Tensor): [N, S] (optional) + Update: + data (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + NOTE: M' != M during training. + """ + # noinspection PyArgumentList + N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) + + # normalize + feat_c0, feat_c1 = map(lambda feat: feat/feat.shape[-1]**.5, [feat_c0, feat_c1]) + + conf_matrix = None + if self.match_type == 'dual_softmax': + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)/self.temperature + if mask_c0 is not None: + sim_matrix.masked_fill_(~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF) + conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) + + elif self.match_type == 'sinkhorn': + # sinkhorn, dustbin included + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) + if mask_c0 is not None: + sim_matrix[:, :L, :S].masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + + # build uniform prior & use sinkhorn + log_assign_matrix = self.log_optimal_transport( + sim_matrix, self.bin_score, self.skh_iters) + assign_matrix = log_assign_matrix.exp() + conf_matrix = assign_matrix[:, :-1, :-1] + + # filter prediction with dustbin score (only in evaluation mode) + if not self.training and self.skh_prefilter: + filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L] + filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S] + conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0 + conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0 + + if self.config['sparse_spvs']: + data.update({'conf_matrix_with_bin': assign_matrix.clone()}) + + data.update({'conf_matrix': conf_matrix}) + + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match(conf_matrix, data)) + + @torch.no_grad() + def get_coarse_match(self, conf_matrix, data): + """ + Args: + conf_matrix (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + _device = conf_matrix.device + # 1. confidence thresholding + mask = conf_matrix > self.thr + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + **axes_lengths) + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # 2. mutual nearest + mask = mask \ + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + + # 3. 
find all valid coarse matches + # this only works when at most one `True` in each row + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + mconf = conf_matrix[b_ids, i_ids, j_ids] + + # 4. Random sampling of training samples for fine-level LoFTR + # (optional) pad samples with gt coarse-level matches + if self.training: + # NOTE: + # The sampling is performed across all pairs in a batch without manually balancing + # #samples for fine-level increases w.r.t. batch_size + if 'mask0' not in data: + num_candidates_max = mask.size(0) * max( + mask.size(1), mask.size(2)) + else: + num_candidates_max = compute_max_candidates( + data['mask0'], data['mask1']) + num_matches_train = int(num_candidates_max * + self.train_coarse_percent) + num_matches_pred = len(b_ids) + assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + + # pred_indices is to select from prediction + if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: + pred_indices = torch.arange(num_matches_pred, device=_device) + else: + pred_indices = torch.randint( + num_matches_pred, + (num_matches_train - self.train_pad_num_gt_min, ), + device=_device) + + # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) + gt_pad_indices = torch.randint( + len(data['spv_b_ids']), + (max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min), ), + device=_device) + mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + + b_ids, i_ids, j_ids, mconf = map( + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], + dim=0), + *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], + [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + + # These matches select patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # 4. 
Update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], + dim=1) * scale1 + + # These matches is the current prediction (for visualization) + coarse_matches.update({ + 'gt_mask': mconf == 0, + 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c[mconf != 0], + 'mkpts1_c': mkpts1_c[mconf != 0], + 'mconf': mconf[mconf != 0] + }) + + return coarse_matches diff --git a/imcui/third_party/gim/networks/loftr/utils/fine_matching.py b/imcui/third_party/gim/networks/loftr/utils/fine_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a0bf6096963df69e088ed826ea334d3114c67c --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/utils/fine_matching.py @@ -0,0 +1,74 @@ +import math +import torch +import torch.nn as nn + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + + +class FineMatching(nn.Module): + """FineMatching with s2d paradigm""" + + def __init__(self): + super().__init__() + + def forward(self, feat_f0, feat_f1, data): + """ + Args: + feat_f0 (torch.Tensor): [M, WW, C] + feat_f1 (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + M, WW, C = feat_f0.shape + W = int(math.sqrt(WW)) + scale = data['hw0_i'][0] / data['hw0_f'][0] + self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale + + # corner case: if no coarse matches found + if M == 0: + assert self.training is False, "M is always >0, when training, see coarse_matching.py" + # logger.warning('No matches found in coarse-level.') + data.update({ + 'expec_f': torch.empty(0, 3, device=feat_f0.device), + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] + sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) + softmax_temp = 1. 
/ C**.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + + # compute std over + var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability + + # for fine-level supervision + data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) + + # compute absolute kpt coords + self.get_fine_match(coords_normalized, data) + + @torch.no_grad() + def get_fine_match(self, coords_normed, data): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + + # mkpts0_f and mkpts1_f + mkpts0_f = data['mkpts0_c'] + scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale + mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] + + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f + }) diff --git a/imcui/third_party/gim/networks/loftr/utils/position_encoding.py b/imcui/third_party/gim/networks/loftr/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..8a835f145d12d9024da341ca0cd53ad6ec9412d8 --- /dev/null +++ b/imcui/third_party/gim/networks/loftr/utils/position_encoding.py @@ -0,0 +1,43 @@ +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), + the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact + on the final performance. For now, we keep both impls for backward compatability. + We will remove the buggy impl after re-training all variants of our released models. 
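A minimal sketch of the frequencies and channel layout the constructor below produces, assuming d_model = 256 purely as an example (the actual dimension comes from the model config):

```python
import math
import torch

d_model = 256                                  # example value only; set by the model config in practice
w = torch.exp(torch.arange(0, d_model // 2, 2).float()
              * (-math.log(10000.0) / (d_model // 2)))   # [d_model // 4] frequencies (temp_bug_fix=True branch)

# Per spatial location (x, y), the buffer built below is laid out as
#   pe[0::4] = sin(x * w),  pe[1::4] = cos(x * w),
#   pe[2::4] = sin(y * w),  pe[3::4] = cos(y * w)
# so adding it to a [N, C, H, W] feature map injects both coordinates.
```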
+ """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + + if temp_bug_fix: + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + else: # a buggy implementation (for backward compatability only) + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, :x.size(2), :x.size(3)] diff --git a/imcui/third_party/gim/networks/mit_semseg/__init__.py b/imcui/third_party/gim/networks/mit_semseg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ccf2d50ca12d8706b47e69d172f66ccd19f3dd --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/__init__.py @@ -0,0 +1,5 @@ +""" +MIT CSAIL Semantic Segmentation +""" + +__version__ = '1.0.0' diff --git a/imcui/third_party/gim/networks/mit_semseg/config/__init__.py b/imcui/third_party/gim/networks/mit_semseg/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7cfcbb8ae15ef50207c000d4c838a5b68b9c43 --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/config/__init__.py @@ -0,0 +1 @@ +from .defaults import _C as cfg diff --git a/imcui/third_party/gim/networks/mit_semseg/config/defaults.py b/imcui/third_party/gim/networks/mit_semseg/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..83818ce04fce587eae76fef00e181b63541add6f --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/config/defaults.py @@ -0,0 +1,97 @@ +from yacs.config import CfgNode as CN + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() +_C.DIR = "ckpt/ade20k-resnet50dilated-ppm_deepsup" + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASET = CN() +_C.DATASET.root_dataset = "./data/" +_C.DATASET.list_train = "./data/training.odgt" +_C.DATASET.list_val = "./data/validation.odgt" +_C.DATASET.num_class = 150 +# multiscale train/test, size of short edge (int or tuple) +_C.DATASET.imgSizes = (300, 375, 450, 525, 600) +# maximum input image size of long edge +_C.DATASET.imgMaxSize = 1000 +# maxmimum downsampling rate of the network +_C.DATASET.padding_constant = 8 +# downsampling rate of the segmentation label +_C.DATASET.segm_downsampling_rate = 8 +# randomly horizontally flip images when train/test +_C.DATASET.random_flip = True + +# ----------------------------------------------------------------------------- +# Model +# ----------------------------------------------------------------------------- +_C.MODEL = CN() +# architecture of net_encoder +_C.MODEL.arch_encoder = "resnet50dilated" +# architecture of net_decoder +_C.MODEL.arch_decoder = "ppm_deepsup" +# weights to finetune net_encoder +_C.MODEL.weights_encoder = "" +# weights to 
finetune net_decoder +_C.MODEL.weights_decoder = "" +# number of feature channels between encoder and decoder +_C.MODEL.fc_dim = 2048 + +# ----------------------------------------------------------------------------- +# Training +# ----------------------------------------------------------------------------- +_C.TRAIN = CN() +_C.TRAIN.batch_size_per_gpu = 2 +# epochs to train for +_C.TRAIN.num_epoch = 20 +# epoch to start training. useful if continue from a checkpoint +_C.TRAIN.start_epoch = 0 +# iterations of each epoch (irrelevant to batch size) +_C.TRAIN.epoch_iters = 5000 + +_C.TRAIN.optim = "SGD" +_C.TRAIN.lr_encoder = 0.02 +_C.TRAIN.lr_decoder = 0.02 +# power in poly to drop LR +_C.TRAIN.lr_pow = 0.9 +# momentum for sgd, beta1 for adam +_C.TRAIN.beta1 = 0.9 +# weights regularizer +_C.TRAIN.weight_decay = 1e-4 +# the weighting of deep supervision loss +_C.TRAIN.deep_sup_scale = 0.4 +# fix bn params, only under finetuning +_C.TRAIN.fix_bn = False +# number of data loading workers +_C.TRAIN.workers = 16 + +# frequency to display +_C.TRAIN.disp_iter = 20 +# manual seed +_C.TRAIN.seed = 304 + +# ----------------------------------------------------------------------------- +# Validation +# ----------------------------------------------------------------------------- +_C.VAL = CN() +# currently only supports 1 +_C.VAL.batch_size = 1 +# output visualization during validation +_C.VAL.visualize = False +# the checkpoint to evaluate on +_C.VAL.checkpoint = "epoch_20.pth" + +# ----------------------------------------------------------------------------- +# Testing +# ----------------------------------------------------------------------------- +_C.TEST = CN() +# currently only supports 1 +_C.TEST.batch_size = 1 +# the checkpoint to test on +_C.TEST.checkpoint = "epoch_20.pth" +# folder to output visualization results +_C.TEST.result = "./" diff --git a/imcui/third_party/gim/networks/mit_semseg/dataset.py b/imcui/third_party/gim/networks/mit_semseg/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1657446301613b71c7b213accac67b650766d6ca --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/dataset.py @@ -0,0 +1,296 @@ +import os +import json +import torch +from torchvision import transforms +import numpy as np +from PIL import Image + + +def imresize(im, size, interp='bilinear'): + if interp == 'nearest': + resample = Image.NEAREST + elif interp == 'bilinear': + resample = Image.BILINEAR + elif interp == 'bicubic': + resample = Image.BICUBIC + else: + raise Exception('resample method undefined!') + + return im.resize(size, resample) + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, odgt, opt, **kwargs): + # parse options + self.imgSizes = opt.imgSizes + self.imgMaxSize = opt.imgMaxSize + # max down sampling rate of network to avoid rounding during conv or pooling + self.padding_constant = opt.padding_constant + + # parse the input list + self.parse_input_list(odgt, **kwargs) + + # mean and std + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1): + if isinstance(odgt, list): + self.list_sample = odgt + elif isinstance(odgt, str): + self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')] + + if max_sample > 0: + self.list_sample = self.list_sample[0:max_sample] + if start_idx >= 0 and end_idx >= 0: # divide file list + self.list_sample = self.list_sample[start_idx:end_idx] + + self.num_sample = 
len(self.list_sample) + assert self.num_sample > 0 + print('# samples: {}'.format(self.num_sample)) + + def img_transform(self, img): + # 0-255 to 0-1 + img = np.float32(np.array(img)) / 255. + img = img.transpose((2, 0, 1)) + img = self.normalize(torch.from_numpy(img.copy())) + return img + + def segm_transform(self, segm): + # to tensor, -1 to 149 + segm = torch.from_numpy(np.array(segm)).long() - 1 + return segm + + # Round x to the nearest multiple of p and x' >= x + def round2nearest_multiple(self, x, p): + return ((x - 1) // p + 1) * p + + +class TrainDataset(BaseDataset): + def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs): + super(TrainDataset, self).__init__(odgt, opt, **kwargs) + self.root_dataset = root_dataset + # down sampling rate of segm labe + self.segm_downsampling_rate = opt.segm_downsampling_rate + self.batch_per_gpu = batch_per_gpu + + # classify images into two classes: 1. h > w and 2. h <= w + self.batch_record_list = [[], []] + + # override dataset length when trainig with batch_per_gpu > 1 + self.cur_idx = 0 + self.if_shuffled = False + + def _get_sub_batch(self): + while True: + # get a sample record + this_sample = self.list_sample[self.cur_idx] + if this_sample['height'] > this_sample['width']: + self.batch_record_list[0].append(this_sample) # h > w, go to 1st class + else: + self.batch_record_list[1].append(this_sample) # h <= w, go to 2nd class + + # update current sample pointer + self.cur_idx += 1 + if self.cur_idx >= self.num_sample: + self.cur_idx = 0 + np.random.shuffle(self.list_sample) + + if len(self.batch_record_list[0]) == self.batch_per_gpu: + batch_records = self.batch_record_list[0] + self.batch_record_list[0] = [] + break + elif len(self.batch_record_list[1]) == self.batch_per_gpu: + batch_records = self.batch_record_list[1] + self.batch_record_list[1] = [] + break + return batch_records + + def __getitem__(self, index): + # NOTE: random shuffle for the first time. 
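A quick worked example of round2nearest_multiple above, since the batch-padding logic in __getitem__ and the resizing in the validation/test datasets both rely on it (the inputs are arbitrary):

```python
def round2nearest_multiple(x, p):
    # smallest multiple of p that is >= x, as in BaseDataset above
    return ((x - 1) // p + 1) * p

assert round2nearest_multiple(600, 8) == 600   # already a multiple of padding_constant=8
assert round2nearest_multiple(601, 8) == 608   # always rounds up, never down
assert round2nearest_multiple(450, 8) == 456   # e.g. a 450-px target edge is padded to 456
```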
shuffle in __init__ is useless + if not self.if_shuffled: + np.random.seed(index) + np.random.shuffle(self.list_sample) + self.if_shuffled = True + + # get sub-batch candidates + batch_records = self._get_sub_batch() + + # resize all images' short edges to the chosen size + if isinstance(self.imgSizes, list) or isinstance(self.imgSizes, tuple): + this_short_size = np.random.choice(self.imgSizes) + else: + this_short_size = self.imgSizes + + # calculate the BATCH's height and width + # since we concat more than one samples, the batch's h and w shall be larger than EACH sample + batch_widths = np.zeros(self.batch_per_gpu, np.int32) + batch_heights = np.zeros(self.batch_per_gpu, np.int32) + for i in range(self.batch_per_gpu): + img_height, img_width = batch_records[i]['height'], batch_records[i]['width'] + this_scale = min( + this_short_size / min(img_height, img_width), \ + self.imgMaxSize / max(img_height, img_width)) + batch_widths[i] = img_width * this_scale + batch_heights[i] = img_height * this_scale + + # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w' + batch_width = np.max(batch_widths) + batch_height = np.max(batch_heights) + batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant)) + batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant)) + + assert self.padding_constant >= self.segm_downsampling_rate, \ + 'padding constant must be equal or large than segm downsamping rate' + batch_images = torch.zeros( + self.batch_per_gpu, 3, batch_height, batch_width) + batch_segms = torch.zeros( + self.batch_per_gpu, + batch_height // self.segm_downsampling_rate, + batch_width // self.segm_downsampling_rate).long() + + for i in range(self.batch_per_gpu): + this_record = batch_records[i] + + # load image and label + image_path = os.path.join(self.root_dataset, this_record['fpath_img']) + segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) + + img = Image.open(image_path).convert('RGB') + segm = Image.open(segm_path) + assert(segm.mode == "L") + assert(img.size[0] == segm.size[0]) + assert(img.size[1] == segm.size[1]) + + # random_flip + if np.random.choice([0, 1]): + img = img.transpose(Image.FLIP_LEFT_RIGHT) + segm = segm.transpose(Image.FLIP_LEFT_RIGHT) + + # note that each sample within a mini batch has different scale param + img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear') + segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest') + + # further downsample seg label, need to avoid seg label misalignment + segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate) + segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate) + segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0) + segm_rounded.paste(segm, (0, 0)) + segm = imresize( + segm_rounded, + (segm_rounded.size[0] // self.segm_downsampling_rate, \ + segm_rounded.size[1] // self.segm_downsampling_rate), \ + interp='nearest') + + # image transform, to torch float tensor 3xHxW + img = self.img_transform(img) + + # segm transform, to torch long tensor HxW + segm = self.segm_transform(segm) + + # put into batch arrays + batch_images[i][:, :img.shape[1], :img.shape[2]] = img + batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm + + output = dict() + output['img_data'] = batch_images + output['seg_label'] = batch_segms + return output + + def __len__(self): + return int(1e10) # It's a 
fake length due to the trick that every loader maintains its own list + #return self.num_sampleclass + + +class ValDataset(BaseDataset): + def __init__(self, root_dataset, odgt, opt, **kwargs): + super(ValDataset, self).__init__(odgt, opt, **kwargs) + self.root_dataset = root_dataset + + def __getitem__(self, index): + this_record = self.list_sample[index] + # load image and label + image_path = os.path.join(self.root_dataset, this_record['fpath_img']) + segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) + img = Image.open(image_path).convert('RGB') + segm = Image.open(segm_path) + assert(segm.mode == "L") + assert(img.size[0] == segm.size[0]) + assert(img.size[1] == segm.size[1]) + + ori_width, ori_height = img.size + + img_resized_list = [] + for this_short_size in self.imgSizes: + # calculate target height and width + scale = min(this_short_size / float(min(ori_height, ori_width)), + self.imgMaxSize / float(max(ori_height, ori_width))) + target_height, target_width = int(ori_height * scale), int(ori_width * scale) + + # to avoid rounding in network + target_width = self.round2nearest_multiple(target_width, self.padding_constant) + target_height = self.round2nearest_multiple(target_height, self.padding_constant) + + # resize images + img_resized = imresize(img, (target_width, target_height), interp='bilinear') + + # image transform, to torch float tensor 3xHxW + img_resized = self.img_transform(img_resized) + img_resized = torch.unsqueeze(img_resized, 0) + img_resized_list.append(img_resized) + + # segm transform, to torch long tensor HxW + segm = self.segm_transform(segm) + batch_segms = torch.unsqueeze(segm, 0) + + output = dict() + output['img_ori'] = np.array(img) + output['img_data'] = [x.contiguous() for x in img_resized_list] + output['seg_label'] = batch_segms.contiguous() + output['info'] = this_record['fpath_img'] + return output + + def __len__(self): + return self.num_sample + + +class TestDataset(BaseDataset): + def __init__(self, odgt, opt, **kwargs): + super(TestDataset, self).__init__(odgt, opt, **kwargs) + + def __getitem__(self, index): + this_record = self.list_sample[index] + # load image + image_path = this_record['fpath_img'] + img = Image.open(image_path).convert('RGB') + + ori_width, ori_height = img.size + + img_resized_list = [] + for this_short_size in self.imgSizes: + # calculate target height and width + scale = min(this_short_size / float(min(ori_height, ori_width)), + self.imgMaxSize / float(max(ori_height, ori_width))) + target_height, target_width = int(ori_height * scale), int(ori_width * scale) + + # to avoid rounding in network + target_width = self.round2nearest_multiple(target_width, self.padding_constant) + target_height = self.round2nearest_multiple(target_height, self.padding_constant) + + # resize images + img_resized = imresize(img, (target_width, target_height), interp='bilinear') + + # image transform, to torch float tensor 3xHxW + img_resized = self.img_transform(img_resized) + img_resized = torch.unsqueeze(img_resized, 0) + img_resized_list.append(img_resized) + + output = dict() + output['img_ori'] = np.array(img) + output['img_data'] = [x.contiguous() for x in img_resized_list] + output['info'] = this_record['fpath_img'] + return output + + def __len__(self): + return self.num_sample diff --git a/imcui/third_party/gim/networks/mit_semseg/models/__init__.py b/imcui/third_party/gim/networks/mit_semseg/models/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..76b40a0a36bc2976f185dbdc344c5a7c09b65920 --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/models/__init__.py @@ -0,0 +1 @@ +from .models import ModelBuilder, SegmentationModule diff --git a/imcui/third_party/gim/networks/mit_semseg/models/hrnet.py b/imcui/third_party/gim/networks/mit_semseg/models/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..579f3c5e4979d5c3896a393e211925b6ff85c8e4 --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/models/hrnet.py @@ -0,0 +1,445 @@ +""" +This HRNet implementation is modified from the following repository: +https://github.com/HRNet/HRNet-Semantic-Segmentation +""" + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +from .utils import load_url +from ..lib.nn import SynchronizedBatchNorm2d + +BatchNorm2d = SynchronizedBatchNorm2d +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +__all__ = ['hrnetv2'] + + +model_urls = { + 'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = 
self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + 
width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=(height_output, width_output), + mode='bilinear', + align_corners=False) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HRNetV2(nn.Module): + def __init__(self, n_class, **kwargs): + super(HRNetV2, self).__init__() + extra = { + 'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'}, + 'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'}, + 'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'}, + 'FINAL_CONV_KERNEL': 1 + } + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(Bottleneck, 64, 64, 4) + + self.stage2_cfg = extra['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = extra['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = extra['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + 
transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x, return_feature_maps=False): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate( + x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x2 = F.interpolate( + x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x3 = F.interpolate( + x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + + x = torch.cat([x[0], x1, x2, x3], 1) + + # x = self.last_layer(x) + return [x] + + +def hrnetv2(pretrained=False, **kwargs): + model = HRNetV2(n_class=1000, **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['hrnetv2']), strict=False) + + return model diff --git a/imcui/third_party/gim/networks/mit_semseg/models/mobilenet.py b/imcui/third_party/gim/networks/mit_semseg/models/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0ddec4b1747dfe7b22ee61c78a4dd75187f645 --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/models/mobilenet.py @@ -0,0 +1,154 @@ +""" +This MobileNetV2 implementation is modified from the following repository: +https://github.com/tonylins/pytorch-mobilenet-v2 +""" + +import torch.nn as nn +import math +from .utils import load_url +from ..lib.nn import SynchronizedBatchNorm2d + +BatchNorm2d = 
SynchronizedBatchNorm2d + + +__all__ = ['mobilenetv2'] + + +model_urls = { + 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', +} + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = round(inp * expand_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + + if expand_ratio == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + BatchNorm2d(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, n_class=1000, input_size=224, width_mult=1.): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + interverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + assert input_size % 32 == 0 + input_channel = int(input_channel * width_mult) + self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel + self.features = [conv_bn(3, input_channel, 2)] + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = int(c * width_mult) + for i in range(n): + if i == 0: + self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) + else: + self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) + input_channel = output_channel + # building last several layers + self.features.append(conv_1x1_bn(input_channel, self.last_channel)) + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, n_class), + ) + + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x = x.mean(3).mean(2) + x = self.classifier(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +def mobilenetv2(pretrained=False, **kwargs): + """Constructs a MobileNet_V2 model. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = MobileNetV2(n_class=1000, **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) + return model diff --git a/imcui/third_party/gim/networks/mit_semseg/models/models.py b/imcui/third_party/gim/networks/mit_semseg/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a624d08d776e26b353d41d6a48f9feaa1852ed --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/models/models.py @@ -0,0 +1,586 @@ +import torch +import torch.nn as nn +from . import resnet, resnext, mobilenet, hrnet +from ..lib.nn import SynchronizedBatchNorm2d +BatchNorm2d = SynchronizedBatchNorm2d + + +class SegmentationModuleBase(nn.Module): + def __init__(self): + super(SegmentationModuleBase, self).__init__() + + def pixel_acc(self, pred, label): + _, preds = torch.max(pred, dim=1) + valid = (label >= 0).long() + acc_sum = torch.sum(valid * (preds == label).long()) + pixel_sum = torch.sum(valid) + acc = acc_sum.float() / (pixel_sum.float() + 1e-10) + return acc + + +class SegmentationModule(SegmentationModuleBase): + def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None): + super(SegmentationModule, self).__init__() + self.encoder = net_enc + self.decoder = net_dec + self.crit = crit + self.deep_sup_scale = deep_sup_scale + + def forward(self, feed_dict, *, segSize=None): + # training + if segSize is None: + if self.deep_sup_scale is not None: # use deep supervision technique + (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) + else: + pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) + + loss = self.crit(pred, feed_dict['seg_label']) + if self.deep_sup_scale is not None: + loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label']) + loss = loss + loss_deepsup * self.deep_sup_scale + + acc = self.pixel_acc(pred, feed_dict['seg_label']) + return loss, acc + # inference + else: + pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize) + return pred + + +class ModelBuilder: + # custom weights initialization + @staticmethod + def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) 
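The valid = (label >= 0) mask in pixel_acc above pairs with segm_transform in dataset.py, which shifts raw labels from 0..150 down to -1..149 so that -1 marks unlabeled pixels. A small sketch of the metric with made-up values:

```python
import torch

pred = torch.zeros(1, 150, 2, 2)
pred[:, 3] = 1.0                                # predict class 3 for every pixel
label = torch.tensor([[[-1, 3], [3, 7]]])       # one unlabeled pixel (-1)

_, preds = torch.max(pred, dim=1)
valid = (label >= 0).long()
acc_sum = torch.sum(valid * (preds == label).long())
pixel_sum = torch.sum(valid)
acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
# acc == 2/3: two of the three labeled pixels are class 3; the -1 pixel is ignored
```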
+ m.bias.data.fill_(1e-4) + #elif classname.find('Linear') != -1: + # m.weight.data.normal_(0.0, 0.0001) + + @staticmethod + def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''): + pretrained = True if len(weights) == 0 else False + arch = arch.lower() + if arch == 'mobilenetv2dilated': + orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained) + net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8) + elif arch == 'resnet18': + orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet18dilated': + orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet34': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet34dilated': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet50': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet50dilated': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet101': + orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet101dilated': + orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnext101': + orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained) + net_encoder = Resnet(orig_resnext) # we can still use class Resnet + elif arch == 'hrnetv2': + net_encoder = hrnet.__dict__['hrnetv2'](pretrained=pretrained) + else: + raise Exception('Architecture undefined!') + + # encoders are usually pretrained + # net_encoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + # print('Loading weights for net_encoder') + net_encoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_encoder + + @staticmethod + def build_decoder(arch='ppm_deepsup', + fc_dim=512, num_class=150, + weights='', use_softmax=False): + arch = arch.lower() + if arch == 'c1_deepsup': + net_decoder = C1DeepSup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax) + elif arch == 'c1': + net_decoder = C1( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax) + elif arch == 'ppm': + net_decoder = PPM( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax) + elif arch == 'ppm_deepsup': + net_decoder = PPMDeepsup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax) + elif arch == 'upernet_lite': + net_decoder = UPerNet( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + fpn_dim=256) + elif arch == 'upernet': + net_decoder = UPerNet( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + fpn_dim=512) + else: + raise Exception('Architecture undefined!') + + net_decoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + # print('Loading weights for net_decoder') + net_decoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_decoder + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + "3x3 convolution + BN + relu" + return 
nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=1, bias=False), + BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +class Resnet(nn.Module): + def __init__(self, orig_resnet): + super(Resnet, self).__init__() + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x); conv_out.append(x); + x = self.layer2(x); conv_out.append(x); + x = self.layer3(x); conv_out.append(x); + x = self.layer4(x); conv_out.append(x); + + if return_feature_maps: + return conv_out + return [x] + + +class ResnetDilated(nn.Module): + def __init__(self, orig_resnet, dilate_scale=8): + super(ResnetDilated, self).__init__() + from functools import partial + + if dilate_scale == 8: + orig_resnet.layer3.apply( + partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply( + partial(self._nostride_dilate, dilate=4)) + elif dilate_scale == 16: + orig_resnet.layer4.apply( + partial(self._nostride_dilate, dilate=2)) + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate//2, dilate//2) + m.padding = (dilate//2, dilate//2) + # other convoluions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x); conv_out.append(x); + x = self.layer2(x); conv_out.append(x); + x = self.layer3(x); conv_out.append(x); + x = self.layer4(x); conv_out.append(x); + + if return_feature_maps: + return conv_out + return [x] + + +class MobileNetV2Dilated(nn.Module): + def __init__(self, orig_net, dilate_scale=8): + super(MobileNetV2Dilated, self).__init__() + from functools import partial + + # take pretrained mobilenet features + self.features = orig_net.features[:-1] + + self.total_idx = len(self.features) + self.down_idx = [2, 4, 7, 14] + + if dilate_scale == 8: + for i in range(self.down_idx[-2], self.down_idx[-1]): + self.features[i].apply( + partial(self._nostride_dilate, dilate=2) + ) + for i in range(self.down_idx[-1], 
self.total_idx): + self.features[i].apply( + partial(self._nostride_dilate, dilate=4) + ) + elif dilate_scale == 16: + for i in range(self.down_idx[-1], self.total_idx): + self.features[i].apply( + partial(self._nostride_dilate, dilate=2) + ) + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate//2, dilate//2) + m.padding = (dilate//2, dilate//2) + # other convoluions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + if return_feature_maps: + conv_out = [] + for i in range(self.total_idx): + x = self.features[i](x) + if i in self.down_idx: + conv_out.append(x) + conv_out.append(x) + return conv_out + + else: + return [self.features(x)] + + +# last conv, deep supervision +class C1DeepSup(nn.Module): + def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + super(C1DeepSup, self).__init__() + self.use_softmax = use_softmax + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# last conv +class C1(nn.Module): + def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + super(C1, self).__init__() + self.use_softmax = use_softmax + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + + return x + + +# pyramid pooling +class PPM(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + use_softmax=False, pool_scales=(1, 2, 3, 6)): + super(PPM, self).__init__() + self.use_softmax = use_softmax + + self.ppm = [] + for scale in pool_scales: + self.ppm.append(nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm = nn.ModuleList(self.ppm) + + self.conv_last = nn.Sequential( + nn.Conv2d(fc_dim+len(pool_scales)*512, 512, + kernel_size=3, padding=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(512, num_class, kernel_size=1) + ) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', 
align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + x = self.conv_last(ppm_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + return x + + +# pyramid pooling, deep supervision +class PPMDeepsup(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + use_softmax=False, pool_scales=(1, 2, 3, 6)): + super(PPMDeepsup, self).__init__() + self.use_softmax = use_softmax + + self.ppm = [] + for scale in pool_scales: + self.ppm.append(nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm = nn.ModuleList(self.ppm) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + self.conv_last = nn.Sequential( + nn.Conv2d(fc_dim+len(pool_scales)*512, 512, + kernel_size=3, padding=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(512, num_class, kernel_size=1) + ) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.dropout_deepsup = nn.Dropout2d(0.1) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + x = self.conv_last(ppm_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.dropout_deepsup(_) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# upernet +class UPerNet(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + use_softmax=False, pool_scales=(1, 2, 3, 6), + fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256): + super(UPerNet, self).__init__() + self.use_softmax = use_softmax + + # PPM Module + self.ppm_pooling = [] + self.ppm_conv = [] + + for scale in pool_scales: + self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) + self.ppm_conv.append(nn.Sequential( + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm_pooling = nn.ModuleList(self.ppm_pooling) + self.ppm_conv = nn.ModuleList(self.ppm_conv) + self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1) + + # FPN Module + self.fpn_in = [] + for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer + self.fpn_in.append(nn.Sequential( + nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), + BatchNorm2d(fpn_dim), + nn.ReLU(inplace=True) + )) + self.fpn_in = nn.ModuleList(self.fpn_in) + + self.fpn_out = [] + for i in range(len(fpn_inplanes) - 1): # skip the top layer + self.fpn_out.append(nn.Sequential( + conv3x3_bn_relu(fpn_dim, fpn_dim, 1), + )) + self.fpn_out = nn.ModuleList(self.fpn_out) + + self.conv_last = nn.Sequential( + conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1), + nn.Conv2d(fpn_dim, num_class, kernel_size=1) + ) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale, pool_conv in 
zip(self.ppm_pooling, self.ppm_conv): + ppm_out.append(pool_conv(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False))) + ppm_out = torch.cat(ppm_out, 1) + f = self.ppm_last_conv(ppm_out) + + fpn_feature_list = [f] + for i in reversed(range(len(conv_out) - 1)): + conv_x = conv_out[i] + conv_x = self.fpn_in[i](conv_x) # lateral branch + + f = nn.functional.interpolate( + f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch + f = conv_x + f + + fpn_feature_list.append(self.fpn_out[i](f)) + + fpn_feature_list.reverse() # [P2 - P5] + output_size = fpn_feature_list[0].size()[2:] + fusion_list = [fpn_feature_list[0]] + for i in range(1, len(fpn_feature_list)): + fusion_list.append(nn.functional.interpolate( + fpn_feature_list[i], + output_size, + mode='bilinear', align_corners=False)) + fusion_out = torch.cat(fusion_list, 1) + x = self.conv_last(fusion_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + x = nn.functional.log_softmax(x, dim=1) + + return x diff --git a/imcui/third_party/gim/networks/mit_semseg/models/resnet.py b/imcui/third_party/gim/networks/mit_semseg/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b5cc981a925feba1db76bdb3e3b99e05472ab508 --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/models/resnet.py @@ -0,0 +1,216 @@ +import torch.nn as nn +import math +from .utils import load_url +from ..lib.nn import SynchronizedBatchNorm2d +BatchNorm2d = SynchronizedBatchNorm2d + + +__all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101'] # resnet101 is coming soon! 
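Taken together with the config defaults and ModelBuilder above, the intended wiring is roughly as follows. This is a hedged sketch only: the import path assumes the package is importable under its upstream name mit_semseg (in this vendored layout it lives at imcui.third_party.gim.networks.mit_semseg), the checkpoint paths are placeholders, and nn.NLLLoss(ignore_index=-1) is an assumption chosen to match the decoders' log_softmax outputs and the -1 "unlabeled" id:

```python
import torch
import torch.nn as nn
from mit_semseg.models import ModelBuilder, SegmentationModule

# values mirror config/defaults.py: resnet50dilated encoder, ppm_deepsup decoder,
# fc_dim=2048 channels between them, 150 ADE20K classes
net_encoder = ModelBuilder.build_encoder(
    arch='resnet50dilated', fc_dim=2048,
    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth')  # placeholder path
net_decoder = ModelBuilder.build_decoder(
    arch='ppm_deepsup', fc_dim=2048, num_class=150,
    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',  # placeholder path
    use_softmax=True)                      # inference mode: decoder returns softmax scores

crit = nn.NLLLoss(ignore_index=-1)         # assumed loss; pairs with log_softmax and label -1
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit).eval()

img = torch.randn(1, 3, 512, 512)          # normally produced by ValDataset / TestDataset above
with torch.no_grad():
    scores = segmentation_module({'img_data': img}, segSize=(512, 512))
pred = scores.argmax(dim=1)                # [1, 512, 512] class indices in 0..149
```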
+ + +model_urls = { + 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', + 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', + 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 128 + super(ResNet, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet18'])) + return model + +''' +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet34'])) + return model +''' + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet50']), strict=False) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet101']), strict=False) + return model + +# def resnet152(pretrained=False, **kwargs): +# """Constructs a ResNet-152 model. 
+# +# Args: +# pretrained (bool): If True, returns a model pre-trained on ImageNet +# """ +# model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) +# if pretrained: +# model.load_state_dict(load_url(model_urls['resnet152'])) +# return model diff --git a/imcui/third_party/gim/networks/mit_semseg/models/resnext.py b/imcui/third_party/gim/networks/mit_semseg/models/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e260cd24dcdcee552e3ff0acac0c2ac7bd3adc --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/models/resnext.py @@ -0,0 +1,163 @@ +import torch.nn as nn +import math +from .utils import load_url +from ..lib.nn import SynchronizedBatchNorm2d +BatchNorm2d = SynchronizedBatchNorm2d + + +__all__ = ['ResNeXt', 'resnext101'] # support resnext 101 + + +model_urls = { + #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', + 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class GroupBottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): + super(GroupBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 2) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNeXt(nn.Module): + + def __init__(self, block, layers, groups=32, num_classes=1000): + self.inplanes = 128 + super(ResNeXt, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) + self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) + self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) + self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(1024 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, groups=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, groups, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=groups)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +''' +def resnext50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnext50']), strict=False) + return model +''' + + +def resnext101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnext101']), strict=False) + return model + + +# def resnext152(pretrained=False, **kwargs): +# """Constructs a ResNeXt-152 model. +# +# Args: +# pretrained (bool): If True, returns a model pre-trained on Places +# """ +# model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) +# if pretrained: +# model.load_state_dict(load_url(model_urls['resnext152'])) +# return model diff --git a/imcui/third_party/gim/networks/mit_semseg/models/utils.py b/imcui/third_party/gim/networks/mit_semseg/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7301cbdbcc395adb110184f299fc47a3ce9a8716 --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/models/utils.py @@ -0,0 +1,18 @@ +import sys +import os +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve +import torch + + +def load_url(url, model_dir='./pretrained', map_location=None): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = url.split('/')[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + return torch.load(cached_file, map_location=map_location) diff --git a/imcui/third_party/gim/networks/mit_semseg/utils.py b/imcui/third_party/gim/networks/mit_semseg/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..600e91de91ba0fd93d29d0e03bd70652a65f3e92 --- /dev/null +++ b/imcui/third_party/gim/networks/mit_semseg/utils.py @@ -0,0 +1,200 @@ +import sys +import os +import logging +import re +import functools +import fnmatch +import numpy as np + + +def setup_logger(distributed_rank=0, filename="log.txt"): + logger = logging.getLogger("Logger") + logger.setLevel(logging.DEBUG) + # don't log results for the non-master process + if distributed_rank > 0: + return 
logger + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" + ch.setFormatter(logging.Formatter(fmt)) + logger.addHandler(ch) + + return logger + + +def find_recursive(root_dir, ext='.jpg'): + files = [] + for root, dirnames, filenames in os.walk(root_dir): + for filename in fnmatch.filter(filenames, '*' + ext): + files.append(os.path.join(root, filename)) + return files + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.initialized = False + self.val = None + self.avg = None + self.sum = None + self.count = None + + def initialize(self, val, weight): + self.val = val + self.avg = val + self.sum = val * weight + self.count = weight + self.initialized = True + + def update(self, val, weight=1): + if not self.initialized: + self.initialize(val, weight) + else: + self.add(val, weight) + + def add(self, val, weight): + self.val = val + self.sum += val * weight + self.count += weight + self.avg = self.sum / self.count + + def value(self): + return self.val + + def average(self): + return self.avg + + +def unique(ar, return_index=False, return_inverse=False, return_counts=False): + ar = np.asanyarray(ar).flatten() + + optional_indices = return_index or return_inverse + optional_returns = optional_indices or return_counts + + if ar.size == 0: + if not optional_returns: + ret = ar + else: + ret = (ar,) + if return_index: + ret += (np.empty(0, np.bool),) + if return_inverse: + ret += (np.empty(0, np.bool),) + if return_counts: + ret += (np.empty(0, np.intp),) + return ret + if optional_indices: + perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') + aux = ar[perm] + else: + ar.sort() + aux = ar + flag = np.concatenate(([True], aux[1:] != aux[:-1])) + + if not optional_returns: + ret = aux[flag] + else: + ret = (aux[flag],) + if return_index: + ret += (perm[flag],) + if return_inverse: + iflag = np.cumsum(flag) - 1 + inv_idx = np.empty(ar.shape, dtype=np.intp) + inv_idx[perm] = iflag + ret += (inv_idx,) + if return_counts: + idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) + ret += (np.diff(idx),) + return ret + + +def colorEncode(labelmap, colors, mode='RGB'): + labelmap = labelmap.astype('int') + labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), + dtype=np.uint8) + for label in unique(labelmap): + if label < 0: + continue + labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ + np.tile(colors[label], + (labelmap.shape[0], labelmap.shape[1], 1)) + + if mode == 'BGR': + return labelmap_rgb[:, :, ::-1] + else: + return labelmap_rgb + + +def accuracy(preds, label): + valid = (label >= 0) + acc_sum = (valid * (preds == label)).sum() + valid_sum = valid.sum() + acc = float(acc_sum) / (valid_sum + 1e-10) + return acc, valid_sum + + +def intersectionAndUnion(imPred, imLab, numClass): + imPred = np.asarray(imPred).copy() + imLab = np.asarray(imLab).copy() + + imPred += 1 + imLab += 1 + # Remove classes from unlabeled pixels in gt image. + # We should not penalize detections in unlabeled portions of the image. 
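+    # Zeroing predictions wherever the label is 0 drops those pixels from the
+    # intersection/union histograms below, which only count bins in [1, numClass].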
+ imPred = imPred * (imLab > 0) + + # Compute area intersection: + intersection = imPred * (imPred == imLab) + (area_intersection, _) = np.histogram( + intersection, bins=numClass, range=(1, numClass)) + + # Compute area union: + (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) + (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) + area_union = area_pred + area_lab - area_intersection + + return (area_intersection, area_union) + + +class NotSupportedCliException(Exception): + pass + + +def process_range(xpu, inp): + start, end = map(int, inp) + if start > end: + end, start = start, end + return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1)) + + +REGEX = [ + (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]), + (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]), + (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'), + functools.partial(process_range, 'gpu')), + (re.compile(r'^(\d+)-(\d+)$'), + functools.partial(process_range, 'gpu')), +] + + +def parse_devices(input_devices): + + """Parse user's devices input str to standard format. + e.g. [gpu0, gpu1, ...] + + """ + ret = [] + for d in input_devices.split(','): + for regex, func in REGEX: + m = regex.match(d.lower().strip()) + if m: + tmp = func(m.groups()) + # prevent duplicate + for x in tmp: + if x not in ret: + ret.append(x) + break + else: + raise NotSupportedCliException( + 'Can not recognize device: "{}"'.format(d)) + return ret diff --git a/imcui/third_party/gim/reconstruction.py b/imcui/third_party/gim/reconstruction.py new file mode 100644 index 0000000000000000000000000000000000000000..d8e1126e132eeef7881d185e088bb493c39672f3 --- /dev/null +++ b/imcui/third_party/gim/reconstruction.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import os +import torch +import warnings +import numpy as np + +from tqdm import tqdm +from os.path import join +from pathlib import Path +from argparse import ArgumentParser + +from hloc import pairs_from_exhaustive +from hloc import extract_features, match_features, match_dense, reconstruction + +from hloc.utils import segment +from hloc.utils.io import read_image +from hloc.match_dense import ImagePairDataset + +from networks.lightglue.superpoint import SuperPoint +from networks.lightglue.models.matchers.lightglue import LightGlue +from networks.mit_semseg.models import ModelBuilder, SegmentationModule + + +def segmentation(images, segment_root, matcher_conf): + # initial device + device = 'cuda' if torch.cuda.is_available() else 'cpu' + # initial segmentation mode + net_encoder = ModelBuilder.build_encoder( + arch='resnet50dilated', + fc_dim=2048, + weights='weights/encoder_epoch_20.pth') + net_decoder = ModelBuilder.build_decoder( + arch='ppm_deepsup', + fc_dim=2048, + num_class=150, + weights='weights/decoder_epoch_20.pth', + use_softmax=True) + crit = torch.nn.NLLLoss(ignore_index=-1) + segmentation_module = SegmentationModule(net_encoder, net_decoder, crit) + segmentation_module = segmentation_module.to(device).eval() + # initial data reader + dataset = ImagePairDataset(None, matcher_conf["preprocessing"], None) + # Segment images + image_list = sorted(os.listdir(images)) + with torch.no_grad(): + for img in tqdm(image_list): + segment_path = join(segment_root, '{}.npy'.format(img[:-4])) + if not os.path.exists(segment_path): + rgb = read_image(images / img, dataset.conf.grayscale) + mask = segment(rgb, 1920, device, segmentation_module) + np.save(segment_path, mask) + + +def main(scene_name, version): + # Setup + images = 
Path('inputs') / scene_name / 'images' + + outputs = Path('outputs') / scene_name / version + outputs.mkdir(parents=True, exist_ok=True) + os.environ['GIMRECONSTRUCTION'] = str(outputs) + + segment_root = Path('outputs') / scene_name / 'segment' + segment_root.mkdir(parents=True, exist_ok=True) + + sfm_dir = outputs / 'sparse' + mvs_path = outputs / 'dense' + database_path = sfm_dir / 'database.db' + image_pairs = outputs / 'pairs-near.txt' + + feature_conf = matcher_conf = None + + if version == 'gim_dkm': + feature_conf = None + matcher_conf = match_dense.confs[version] + elif version == 'gim_lightglue': + feature_conf = extract_features.confs['gim_superpoint'] + matcher_conf = match_features.confs[version] + + # Find image pairs via pair-wise image + exhaustive_pairs = pairs_from_exhaustive.main(image_pairs, image_list=images) + + segmentation(images, segment_root, matcher_conf) + + # Extract and match local features + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + if version == 'gim_dkm': + feature_path, match_path = match_dense.main(matcher_conf, image_pairs, + images, outputs) + elif version == 'gim_lightglue': + checkpoints_path = join('weights', 'gim_lightglue_100h.ckpt') + + detector = SuperPoint({ + 'max_num_keypoints': 2048, + 'force_num_keypoints': True, + 'detection_threshold': 0.0, + 'nms_radius': 3, + 'trainable': False, + }) + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict.pop(k) + if k.startswith('superpoint.'): + state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k) + detector.load_state_dict(state_dict) + + model = LightGlue({ + 'filter_threshold': 0.1, + 'flash': False, + 'checkpointed': True, + }) + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('superpoint.'): + state_dict.pop(k) + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + model.load_state_dict(state_dict) + + feature_path = extract_features.main(feature_conf, images, outputs, + model=detector) + match_path = match_features.main(matcher_conf, image_pairs, + feature_conf['output'], outputs, + model=model) + + # sparse reconstruction + reconstruction.main(sfm_dir, images, image_pairs, feature_path, match_path) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--scene_name', type=str) + parser.add_argument('--version', type=str, choices={'gim_dkm', 'gim_lightglue'}, + default='gim_dkm') + args = parser.parse_args() + + main(args.scene_name, args.version) diff --git a/imcui/third_party/gim/test.py b/imcui/third_party/gim/test.py new file mode 100644 index 0000000000000000000000000000000000000000..9082a3555f323d2c274a7a42e285d3497e18f061 --- /dev/null +++ b/imcui/third_party/gim/test.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import cv2 +import math +import uuid + +import pytorch_lightning as pl + +from pathlib import Path +from os.path import join, exists +from argparse import ArgumentParser +from yacs.config import CfgNode as CN +from pytorch_lightning.plugins import DDPPlugin +from pytorch_lightning.loggers import TensorBoardLogger + +import tools as com + +from trainer import Trainer +from networks.loftr.configs.outdoor import trainer_cfg, 
network_cfg +from networks.loftr.config import get_cfg_defaults as get_network_cfg +from trainer.config import get_cfg_defaults as get_trainer_cfg +from trainer.debug import get_cfg_defaults as get_debug_cfg + +from datasets.data import MultiSceneDataModule +from datasets import gl3d +from datasets import gtasfm +from datasets import multifov +from datasets import blendedmvs +from datasets import iclnuim +from datasets import scenenet +from datasets import eth3d +from datasets import kitti +from datasets import robotcar + +Benchmarks = dict( + GL3D = gl3d.cfg, + GTASfM = gtasfm.cfg, + MultiFoV = multifov.cfg, + BlendedMVS = blendedmvs.cfg, + ICLNUIM = iclnuim.cfg, + SceneNet = scenenet.cfg, + ETH3DO = eth3d.cfgO, + ETH3DI = eth3d.cfgI, + KITTI = kitti.cfg, + RobotcarNight = robotcar.night, + RobotcarSeason = robotcar.season, + RobotcarWeather = robotcar.weather, +) + +RANSACs = dict( + RANSAC = cv2.RANSAC, + FAST = cv2.USAC_FAST, + MAGSAC = cv2.USAC_MAGSAC, + PROSAC = cv2.USAC_PROSAC, + DEFAULT = cv2.USAC_DEFAULT, + ACCURATE = cv2.USAC_ACCURATE, + PARALLEL = cv2.USAC_PARALLEL, +) + +MODEL_ZOO = ['gim_dkm', 'gim_loftr', 'gim_lightglue', 'root_sift'] + + +if __name__ == '__main__': + # ------------ + # Hyperparameters + # ------------ + parser = ArgumentParser() + + # Project args + parser.add_argument('--trains', type=str, choices=set(Benchmarks), nargs='+', + default=[], + help=f'Train Datasets: {set(Benchmarks)}', ) + parser.add_argument('--valids', type=str, choices=set(Benchmarks), nargs='+', + default=[], + help=f'Valid Datasets: {set(Benchmarks)}', ) + parser.add_argument('--tests', type=str, choices=set(Benchmarks), + default=None, + help=f'Test Datasets: {set(Benchmarks)}', ) + parser.add_argument('--debug', action='store_true', + help='For debug mode') + + # Loader args + parser.add_argument('--batch_size', type=int, default=12, + help='input batch size for training and validation (default=2)') + parser.add_argument('--threads', type=int, default=3, + help='Number of threads (default: 3)') + + # Traner args + parser.add_argument('--gpus', type=int, default=1, + help='GPU numbers') + parser.add_argument('--num_nodes', type=int, default=1, + help='Cluster node numbers') + parser.add_argument('--max_epochs', type=int, default=30, + help='Traning epochs (default: 30)') + parser.add_argument("--git", type=str, default='xxxxxx', + help=f'Git ID',) + parser.add_argument("--weight", type=str, default=None, choices=MODEL_ZOO, + required=True, + help=f'Pretrained model weight',) + + # Hyper-parameters + parser.add_argument('--img_size', type=int, default=9999, + help='Image Size') + parser.add_argument('--lr', type=float, default=8e-3, + help='Learning rate') + + # Runtime args + parser.add_argument('--test', action='store_true', + help="Tesing") + parser.add_argument('--viz', action='store_true', + help="Tesing") + + parser.add_argument("--max_samples", type=int, default=None, + help=f'Max Samples in Testing',) + parser.add_argument("--min_score", type=float, default=0.0, + help='Min Score in Testing',) + parser.add_argument("--max_score", type=float, default=1.0, + help='Max Score in Testing',) + + parser.add_argument("--ransac_threshold", type=float, default=0.5, + help='RANSAC Threshold',) + parser.add_argument('--ransac', type=str, choices=set(RANSACs), default='MAGSAC', + help=f'RANSAC Methods: {set(RANSACs)}', ) + parser.add_argument("--version", type=str, default='AUC', + help=f'Model version',) + + args = parser.parse_args() + + # ------------ + # Project config + # ------------ + 
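+    # CLI args become the project config (pcfg); trainer, network (LoFTR) and
+    # per-benchmark dataset defaults are loaded separately and merged below.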
pcfg = CN(vars(args)) + tcfg = get_trainer_cfg() + ncfg = get_network_cfg() + dcfg = CN({x:Benchmarks.get(x, None) for x in set(args.trains + args.valids + [args.tests])}) + tcfg.merge_from_other_cfg(trainer_cfg) + if args.debug: tcfg.merge_from_other_cfg(get_debug_cfg()) + ncfg.merge_from_other_cfg(network_cfg) + dcfg.DF = ncfg.LOFTR.RESOLUTION[0] + + # load weight + ncfg.LOFTR.WEIGHT = join('weights', args.weight + '_' + args.version + '.ckpt') + if args.weight == 'root_sift': + ncfg.LOFTR.WEIGHT = None + + # ------------ + # Testing setting + # ------------ + if args.max_samples is not None and args.test: dcfg[args.tests]['DATASET']['TESTS']['MAX_SAMPLES'] = args.max_samples + if args.min_score is not None and args.test: dcfg[args.tests]['DATASET']['TESTS']['MIN_OVERLAP_SCORE'] = args.min_score + if args.max_score is not None and args.test: dcfg[args.tests]['DATASET']['TESTS']['MAX_OVERLAP_SCORE'] = args.max_score + # print(dcfg) + + # ------------ + # Update Trainer Config + # ------------ + TRAINER = tcfg.TRAINER + TRAINER.TRUE_BATCH_SIZE = args.gpus * args.batch_size + TRAINER.SCALING = _scaling = TRAINER.TRUE_BATCH_SIZE / TRAINER.CANONICAL_BS + TRAINER.CANONICAL_LR = args.lr + TRAINER.TRUE_LR = TRAINER.CANONICAL_LR * _scaling + TRAINER.WARMUP_STEP = math.floor(TRAINER.WARMUP_STEP / _scaling) + TRAINER.RANSAC_PIXEL_THR = args.ransac_threshold + TRAINER.POSE_ESTIMATION_METHOD = RANSACs[args.ransac] + + # ------------ + # W&B logger + # ------------ + # com.login(args.server) + wid = str(uuid.uuid1()).split('-')[0] + com.hint('ID = {}'.format(wid)) + logger = TensorBoardLogger('tensorboard', name='test', version='test') + + # ------------ + # reproducible + # ------------ + pl.seed_everything(TRAINER.SEED, workers=True) + + # ------------ + # data loader + # ------------ + dm = MultiSceneDataModule(args, dcfg) + + # ------------ + # model + # ------------ + trainer = Trainer(pcfg, tcfg, dcfg, ncfg) + + # ------------ + # training + # ------------ + fitter = pl.Trainer.from_argparse_args( + args, + # ddp + sync_batchnorm=True, + strategy=DDPPlugin(find_unused_parameters=False), + # reproducible + benchmark=True, + deterministic=False, + # logger + enable_checkpointing=False, + logger=logger, + log_every_n_steps=TRAINER.LOG_INTERVAL, + # prepare + weights_summary='top', + val_check_interval=TRAINER.VAL_CHECK_INTERVAL, + num_sanity_val_steps=TRAINER.NUM_SANITY_VAL_STEPS, + limit_train_batches=TRAINER.LIMIT_TRAIN_BATCHES, + limit_val_batches=TRAINER.LIMIT_VALID_BATCHES, + # faster training + # amp_level=TRAINER.AMP_LEVEL, + # amp_backend=TRAINER.AMP_BACKEND, + # precision=TRAINER.PRECISION, #https://github.com/PyTorchLightning/pytorch-lightning/issues/5558 + # better fine-tune + gradient_clip_val=TRAINER.GRADIENT_CLIP_VAL, + gradient_clip_algorithm=TRAINER.GRADIENT_CLIP_ALGORITHM, + ) + + # ------------ + # Fitting + # ------------ + if args.test: + scene = Path(dcfg[pcfg["tests"]]['DATASET']['TESTS']['LIST_PATH']).stem.split('_')[0] + path = f"dump/zeb/[T] {pcfg.weight} {scene:>15} {pcfg.version}.txt" + if exists(path): + print(f"{path} already exists") + exit(0) + elif not exists(str(Path(path).parent)): + Path(path).parent.mkdir(parents=True) + fitter.test(trainer, datamodule=dm) + else: + fitter.fit(trainer, datamodule=dm) diff --git a/imcui/third_party/gim/tools/__init__.py b/imcui/third_party/gim/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d82525793bf47937f3cb11c272489c9c084ca4c --- /dev/null +++ b/imcui/third_party/gim/tools/__init__.py @@ 
-0,0 +1,218 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import os +import time +import yaml +import torch +import random +import numpy as np + + +project_name = os.path.basename(os.getcwd()) + + +def make_reproducible(iscuda, seed=0): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if iscuda: + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # set True will make data load faster + # but, it will influence reproducible + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = True + + +def hint(msg): + timestamp = f'{time.strftime("%m/%d %H:%M:%S", time.localtime(time.time()))}' + print('\033[1m' + project_name + ' >> ' + timestamp + ' >> ' + '\033[0m' + msg) + + +def datainfo(infos, datalen, gpuid): + if gpuid != 0: return + # print informations about benchmarks + print('') + print(f'{" Benchmarks":14}|{" Sequence":20}|{" Count":8}') + print(f'{"-" * 45}') + for k0, v0 in infos.items(): + isfirst = True + for k1, v1 in v0.items(): + line = f' {k0:13}|' if isfirst else f'{" " * 14}|' + line += f' {k1:19}|' + line += f' {str(v1):7}' + print(line) + print(f'{"-" * 45}') + isfirst = False + print(f'{" " * 37}{str(datalen)}') + print(f'{"-" * 45}') + print('') + + +# noinspection PyTypeChecker +def mesh_positions(h: int, w: int): + gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w)) + gx, gy = gx.contiguous()[None, :], gy.contiguous()[None, :] + pos = torch.cat((gx.view(1, -1), gy.view(1, -1))) # [2, H*W] + return pos + + +def current_time(f=None): + """ + :param f: default for log, "f" for file name + :return: formatted time + """ + if f == "f": + return f'{time.strftime("%m.%d_%H.%M.%S", time.localtime(time.time()))}' + return f'{time.strftime("%m/%d %H:%M:%S", time.localtime(time.time()))}' + + +def mkdir(dir): + if not os.path.isdir(dir): + os.makedirs(dir, exist_ok=False) + + +def pdist(x, y=None): + """ + Pairwise Distance + Args: + x: [bs, n, 2] + y: [bs, n, 2] + Returns: [bs, n, n] value in euclidean *square* distance + """ + # B, n, two = x.shape + x = x.double() # [bs, n, 2] + + x_norm = (x ** 2).sum(-1, keepdim=True) # [bs, n, 1] + if y is not None: + y = y.double() + y_t = y.transpose(1, 2) # [bs, 2, n] + y_norm = (y ** 2).sum(-1, keepdim=True).transpose(1, 2) # [bs, 1, n] + else: + y_t = x.transpose(1, 2) # [bs, 2, n] + y_norm = x_norm.transpose(1, 2) # [bs, 1, n] + + dist = x_norm + y_norm - 2.0 * torch.matmul(x, y_t) # [bs, n, n] + return dist + + +mean = lambda lis: sum(lis) / len(lis) +eps = lambda x: x + 1e-8 + + +def load_configs(configs): + with open(configs, 'r') as stream: + try: + x = yaml.safe_load(stream) + except yaml.YAMLError as exc: + print(exc) + return x + + +def find_in_dir(run, dir): + runs = os.listdir(dir) + runs = [r for r in runs if run in r] + if len(runs) <= 0: + hint(f'Not exist run name contain : {run}') + exit(-1) + elif len(runs) >= 2: + hint(f'{len(runs)} runs name contain : {run}') + hint(f'I will return the first one : {runs[-1]}') + else: + hint(f'Success match {runs[-1]}') + return runs[-1] + + +def ckpt_in_dir(key, dir): + runs = os.listdir(dir) + runs = [r for r in runs if key in r] + if len(runs) <= 0: + hint(f'Not exist run name contain : {key}') + exit(-1) + elif len(runs) >= 2: + hint(f'{len(runs)} runs name contain : {key}') + hint(f'I will return the first one : {runs[-1]}') + else: + hint(f'Success match {runs[-1]}') + return runs[-1] + + +def kpts2grid(kpts, scale, size): + """ + change coordinates for keypoints from size0 to size1 + and format as grid which 
coordinates from [-1, 1] + Args: + kpts: (b, n, 2) - (x, y) + scale: (b, 2) - (w, h) - the keypoints working shape to unet working shape + size: (b, 2) - (h, w) - the unet working shape which is 'resize0/1' in data + Returns: new kpts: (b, 1, n, 2) - (x, y) in [-1, 1] + """ + # kpts coordinates in unet shape + kpts /= scale[:,None,:] + # kpts[:,:,0] - (b, n) + kpts[:, :, 0] *= 2 / (size[:, 1][:, None] - 1) + kpts[:, :, 1] *= 2 / (size[:, 0][:, None] - 1) + # make kpts from [0, 2] to [-1, 1] + kpts -= 1 + # assume all kpts in [-1, 1] + kpts = kpts.clamp(min=-1, max=1) # (b, n, 2) + # make kpts shape from (b, n, 2) to (b, 1, n, 2) + kpts = kpts[:,None] + + return kpts + + +def debug(x): + if 'DATASET' in list(x.keys()): + y = x.DATASET + y.TRAIN.LIST_PATH = y.TRAIN.LIST_PATH.replace('scene_list', 'scene_list_debug') + y.VALID.LIST_PATH = y.VALID.LIST_PATH.replace('scene_list', 'scene_list_debug') + return x + + +def summary_loss(loss_list): + n = 0 + sums = 0 + for loss in loss_list: + if (loss is not None) and (not torch.isnan(loss)): + sums += loss + n += 1 + sums = sums / n if n != 0 else None + return sums + + +def summary_metrics(dic, h1, h2): + print('') + + # Head + print(f'RunID {h1:9}', end='') + print(' | ', end='') + print(f'Version {h2:10}', end='') + + # Content + print(f'{"| ".join(f"{key:10}" for key in dic[0].keys())}') + for metric in dic: + print(f'{"-" * 12 * len(dic[0].keys())}') + print(f'{"| ".join(f"{metric[key]:<10.5f}" for key in metric.keys())}') + + print('') + + +def get_padding_size(image, h, w): + orig_width = image.shape[3] + orig_height = image.shape[2] + aspect_ratio = w / h + + new_width = max(orig_width, int(orig_height * aspect_ratio)) + new_height = max(orig_height, int(orig_width / aspect_ratio)) + + pad_height = new_height - orig_height + pad_width = new_width - orig_width + + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + + return orig_width, orig_height, pad_left, pad_right, pad_top, pad_bottom diff --git a/imcui/third_party/gim/tools/comm.py b/imcui/third_party/gim/tools/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167 --- /dev/null +++ b/imcui/third_party/gim/tools/comm.py @@ -0,0 +1,265 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +[Copied from detectron2] +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import functools +import logging +import numpy as np +import pickle +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. 
+ """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" + local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. 
+ + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group=group) == 1: + return [data] + rank = dist.get_rank(group=group) + + tensor = _serialize_to_tensor(data, group) + size_list, tensor = _pad_to_largest_tensor(tensor, group) + + # receiving Tensor from all ranks + if rank == dst: + max_size = max(size_list) + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list + ] + dist.gather(tensor, tensor_list, dst=dst, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + return data_list + else: + dist.gather(tensor, [], dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2 ** 31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + + Returns: + a dict with the same keys as input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/imcui/third_party/gim/tools/metrics.py b/imcui/third_party/gim/tools/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..f5bb311a00f0b92a6742a06c45800e3d73bd90ea --- /dev/null +++ b/imcui/third_party/gim/tools/metrics.py @@ -0,0 +1,214 @@ +import cv2 +import torch +import numpy as np +from collections import OrderedDict +from kornia.geometry.epipolar import numeric +from kornia.geometry.conversions import convert_points_to_homogeneous + + +# --- METRICS --- + +def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): + # angle error between 2 vectors + t_gt = T_0to1[:3, 3] + n = np.linalg.norm(t) * np.linalg.norm(t_gt) + t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) + t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity + if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging + t_err = 0 + + r = np.linalg.norm(t_gt) / np.linalg.norm(t) + t_err2 = np.linalg.norm((t*r - t_gt)) + + # angle error between 2 rotation matrices + R_gt = T_0to1[:3, :3] + cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 + cos = np.clip(cos, -1., 1.) # handle numercial errors + R_err = np.rad2deg(np.abs(np.arccos(cos))) + + return t_err, R_err, t_err2 + + +def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): + """Squared symmetric epipolar distance. + This can be seen as a biased estimation of the reprojection error. 
+ Args: + pts0 (torch.Tensor): [N, 2] + pts1 (torch.Tensor): [N, 2] + E (torch.Tensor): [3, 3] + K0: + K1: + """ + pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + pts0 = convert_points_to_homogeneous(pts0) + pts1 = convert_points_to_homogeneous(pts1) + + Ep0 = pts0 @ E.T # [N, 3] + p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] + Etp1 = pts1 @ E # [N, 3] + + d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N + return d + + +@torch.no_grad() +def compute_symmetrical_epipolar_errors(data): + """ + Update: + data (dict):{"epi_errs": [M]} + """ + Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) + E_mat = Tx @ data['T_0to1'][:, :3, :3] + + m_bids = data['m_bids'] + pts0 = data['mkpts0_f'] + pts1 = data['mkpts1_f'] + + epi_errs = [] + for bs in range(Tx.size(0)): + mask = m_bids == bs + epi_errs.append(symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs])) + epi_errs = torch.cat(epi_errs, dim=0) + + data.update({'epi_errs': epi_errs}) + + +def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): + if len(kpts0) < 5: + return None + # normalize keypoints + kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] + kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] + + # normalize ransac threshold + ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]]) + + # compute pose with cv2 + E, mask = cv2.findEssentialMat( + kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC) + if E is None: + # print("\nE is None while trying to recover pose.\n") + return None + + # recover pose from E + best_num_inliers = 0 + ret = None + for _E in np.split(E, len(E) / 3): + n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) + if n > best_num_inliers: + ret = (R, t[:, 0], mask.ravel() > 0) + best_num_inliers = n + + return ret + + +@torch.no_grad() +def compute_pose_errors(data, config): + """ + Update: + data (dict):{ + "R_errs" List[float]: [N] + "t_errs" List[float]: [N] + "inliers" List[np.ndarray]: [N] + } + """ + pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.25/0.5/0.75 + conf = config.TRAINER.RANSAC_CONF # 0.999999 + iters = config.TRAINER.RANSAC_MAX_ITERS # 100000 + method = config.TRAINER.POSE_ESTIMATION_METHOD + data.update({'R_errs': [], 't_errs': [], 'inliers': []}) + data.update({'Rot': [], 'Tns': []}) + data.update({'Rot1': [], 'Tns1': []}) + data.update({'t_errs2': []}) + + m_bids = data['m_bids'].cpu().numpy() + pts0 = data['mkpts0_f'].cpu().numpy() + pts1 = data['mkpts1_f'].cpu().numpy() + K0 = data['K0'].cpu().numpy() + K1 = data['K1'].cpu().numpy() + T_0to1 = data['T_0to1'].cpu().numpy() + # depth0 = data['depth0'].cpu() + # depth1 = data['depth1'].cpu() + + # weights = data['weights'] + + for bs in range(K0.shape[0]): + mask = m_bids == bs + ret1 = None + ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], 0.5, conf=0.99999) + # ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], method=method, thresh=pixel_thr, conf=conf, maxIters=iters) + # weight = weights[bs][-1].cpu().numpy() + # ret = estimate_pose_w_weight(pts0[mask], pts1[mask], weight, K0[bs], K1[bs], pixel_thr, conf=conf) + + if ret is None: + data['R_errs'].append(np.inf) + data['t_errs'].append(np.inf) + data['t_errs2'].append(np.inf) + data['inliers'].append(np.array([]).astype(bool)) + data['Rot'].append(np.eye(3)) + 
data['Tns'].append(np.zeros(3)) + else: + R, t, inliers = ret + t_err, R_err, t_err2 = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) + data['R_errs'].append(R_err) + data['t_errs'].append(t_err) + data['t_errs2'].append(t_err2) + data['inliers'].append(inliers) + data['Rot'].append(R) + data['Tns'].append(t) + + if ret1 is None: + data['Rot1'].append(np.eye(3)) + data['Tns1'].append(np.zeros(3)) + else: + # noinspection PyTupleAssignmentBalance + R1, t1, inliers = ret1 + data['Rot1'].append(R1) + data['Tns1'].append(t1) + + +def error_auc(errs, thres): + if isinstance(errs, list): errs = np.array(errs) + pass_ratio = [np.sum(errs < th) / len(errs) for th in thres] + # mAP = {f'AUC@{t}':np.mean(pass_ratio[:i+1]) for i, t in enumerate(thres)} + mAP = {f'AUC@{t}':pass_ratio[i] for i, t in enumerate(thres)} + return mAP + + +def epidist_prec(errors, thresholds, ret_dict=False): + precs = [] + for thr in thresholds: + prec_ = [] + for errs in errors: + correct_mask = errs < thr + prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) + precs.append(np.mean(prec_) if len(prec_) > 0 else 0) + if ret_dict: + return {f'Prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} + else: + return precs + + +def aggregate_metrics(metrics, epi_err_thr=5e-4, test=False): + """ Aggregate metrics for the whole dataset: + (This method should be called once per dataset) + 1. AUC of the pose error (angular) at the threshold [5, 10, 20] + 2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) + """ + # filter duplicates + unq_ids = OrderedDict((iden, i) for i, iden in enumerate(metrics['identifiers'])) + unq_ids = list(unq_ids.values()) + + # pose auc + angular_thresholds = [5, 10, 20] + pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] + aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) + + # matching precision + dist_thresholds = [epi_err_thr] + precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) + + metric = {**aucs, **precs} + metric = {**metric, **{'Num': len(unq_ids)}} if test else metric + return metric diff --git a/imcui/third_party/gim/tools/misc.py b/imcui/third_party/gim/tools/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..61cd57bf1e4e5aacab58e42e9277a4ad12990dc9 --- /dev/null +++ b/imcui/third_party/gim/tools/misc.py @@ -0,0 +1,100 @@ +import os +import contextlib +import joblib +from typing import Union +from loguru import _Logger, logger +from itertools import chain + +import torch +from yacs.config import CfgNode as CN +from pytorch_lightning.utilities import rank_zero_only + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +def upper_config(dict_cfg): + if not isinstance(dict_cfg, dict): + return dict_cfg + return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} + + +def log_on(condition, message, level): + if condition: + assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] + logger.log(level, message) + + +def get_rank_zero_only_logger(logger: _Logger): + if rank_zero_only.rank == 0: + return logger + else: + for _level in logger._core.levels.keys(): + level = _level.lower() + setattr(logger, level, + lambda x: None) + logger._log = lambda x: None + return logger + + +def setup_gpus(gpus: Union[str, int]) -> int: + """ A temporary fix for pytorch-lighting 1.3.x """ + 
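+    # Accepts either a GPU count (e.g. 4, or -1 for all available GPUs) or a
+    # comma-separated id list (e.g. "0,2"); for an id list, CUDA_VISIBLE_DEVICES is
+    # set if not already defined and the number of requested GPUs is returned.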
gpus = str(gpus) + gpu_ids = [] + + if ',' not in gpus: + n_gpus = int(gpus) + return n_gpus if n_gpus != -1 else torch.cuda.device_count() + else: + gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] + + # setup environment variables + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + if visible_devices is None: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') + else: + logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') + return len(gpu_ids) + + +def flattenList(x): + return list(chain(*x)) + + +@contextlib.contextmanager +def tqdm_joblib(tqdm_object): + """Context manager to patch joblib to report into tqdm progress bar given as argument + + Usage: + with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: + Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) + + When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) + ret_vals = Parallel(n_jobs=args.world_size)( + delayed(lambda x: _compute_cov_score(pid, *x))(param) + for param in tqdm(combinations(image_ids, 2), + desc=f'Computing cov_score of [{pid}]', + total=len(image_ids)*(len(image_ids)-1)/2)) + Src: https://stackoverflow.com/a/58936697 + """ + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() diff --git a/imcui/third_party/gim/trainer/__init__.py b/imcui/third_party/gim/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5af0cb07a7c451ac0d085fbe57d8f445c5f3c08b --- /dev/null +++ b/imcui/third_party/gim/trainer/__init__.py @@ -0,0 +1 @@ +from .lightning import Trainer \ No newline at end of file diff --git a/imcui/third_party/gim/trainer/config.py b/imcui/third_party/gim/trainer/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e8040b0972b8efe2f6c7d7beb0f1918b0544c8 --- /dev/null +++ b/imcui/third_party/gim/trainer/config.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +from yacs.config import CfgNode as CN + +_CN = CN() + +# ------------ +# Trainer +# ------------ +_CN.TRAINER = CN() +_CN.TRAINER.SEED = 3407 +_CN.TRAINER.NUM_SANITY_VAL_STEPS = -1 +_CN.TRAINER.LOG_INTERVAL = 20 +_CN.TRAINER.VAL_CHECK_INTERVAL = 1.0 # default 1.0, if we set 2.0 will val each 2 step +_CN.TRAINER.LIMIT_TRAIN_BATCHES = 1.0 # default 1.0 +_CN.TRAINER.LIMIT_VALID_BATCHES = 1.0 # default 1.0 will use all training batch +_CN.TRAINER.AMP_LEVEL = 'O1' # 'O1' for apex +_CN.TRAINER.AMP_BACKEND = 'apex' # 'O1' for apex +_CN.TRAINER.PRECISION = 16 # default 32 +_CN.TRAINER.GRADIENT_CLIP_VAL = 0.5 # default 0.0 +_CN.TRAINER.GRADIENT_CLIP_ALGORITHM = 'norm' # default 'norm' + +# optimizer +_CN.TRAINER.CANONICAL_BS = 64 +_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] +_CN.TRAINER.TRUE_LR = None # this will be 
calculated automatically at runtime +_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam +_CN.TRAINER.ADAMW_DECAY = 0.1 +# step-based warm-up +_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] +_CN.TRAINER.WARMUP_RATIO = 0. +_CN.TRAINER.WARMUP_STEP = 4800 +# learning rate scheduler +_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR] +_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] +_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR +_CN.TRAINER.MSLR_GAMMA = 0.5 +_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing +_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval +# geometric metrics and pose solver +_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) +_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] +_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] +_CN.TRAINER.RANSAC_PIXEL_THR = None +_CN.TRAINER.RANSAC_CONF = 0.999999 +_CN.TRAINER.RANSAC_MAX_ITERS = 100000 +_CN.TRAINER.USE_MAGSACPP = False + +# Related to Visualization +_CN.VISUAL = CN() +_CN.VISUAL.N_VAL_PAIRS_TO_PLOT = 10 +_CN.VISUAL.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence'] +_CN.VISUAL.PLOT_MATCHES_ALPHA = 'dynamic' + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _CN.clone() diff --git a/imcui/third_party/gim/trainer/debug.py b/imcui/third_party/gim/trainer/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..0952849a3780d5a136d41ea3af8edd2760a8183f --- /dev/null +++ b/imcui/third_party/gim/trainer/debug.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +from yacs.config import CfgNode as CN + +_CN = CN() + +# ------------ +# Trainer +# ------------ +_CN.TRAINER = CN() +_CN.TRAINER.NUM_SANITY_VAL_STEPS = 0 +_CN.TRAINER.LOG_INTERVAL = 1 +_CN.TRAINER.VAL_CHECK_INTERVAL = 1.0 # default 1.0, if we set 2.0 will val each 2 step +_CN.TRAINER.LIMIT_TRAIN_BATCHES = 10.0 # default 1.0 +_CN.TRAINER.LIMIT_VALID_BATCHES = 10.0 # default 1.0 will use all training batch + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _CN.clone() diff --git a/imcui/third_party/gim/trainer/lightning.py b/imcui/third_party/gim/trainer/lightning.py new file mode 100644 index 0000000000000000000000000000000000000000..c1be4464bebde2a85a6fd013044d84703edaa5c4 --- /dev/null +++ b/imcui/third_party/gim/trainer/lightning.py @@ -0,0 +1,267 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import cv2 +import torch +import numpy as np +import pytorch_lightning as pl + +from pathlib import Path +from collections import OrderedDict + +from tools.comm import all_gather +from tools.misc import lower_config, flattenList +from tools.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors + + +class Trainer(pl.LightningModule): + + def __init__(self, pcfg, tcfg, dcfg, ncfg): + super().__init__() + + self.save_hyperparameters() + self.pcfg = pcfg + self.tcfg = tcfg + self.ncfg = ncfg + ncfg = lower_config(ncfg) + + detector = model = None + if pcfg.weight == 'gim_dkm': + from networks.dkm.models.model_zoo.DKMv3 import DKMv3 + detector = None + model = DKMv3(None, 540, 720, upsample_preds=True) + model.h_resized = 
660 + model.w_resized = 880 + model.upsample_preds = True + model.upsample_res = (1152, 1536) + model.use_soft_mutual_nearest_neighbours = False + elif pcfg.weight == 'gim_loftr': + from networks.loftr.loftr import LoFTR as MODEL + detector = None + model = MODEL(ncfg['loftr']) + elif pcfg.weight == 'gim_lightglue': + from networks.lightglue.superpoint import SuperPoint + from networks.lightglue.models.matchers.lightglue import LightGlue + detector = SuperPoint({ + 'max_num_keypoints': 2048, + 'force_num_keypoints': True, + 'detection_threshold': 0.0, + 'nms_radius': 3, + 'trainable': False, + }) + model = LightGlue({ + 'filter_threshold': 0.1, + 'flash': False, + 'checkpointed': True, + }) + elif pcfg.weight == 'root_sift': + detector = None + model = None + + self.detector = detector + self.model = model + + checkpoints_path = ncfg['loftr']['weight'] + if ncfg['loftr']['weight'] is not None: + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + + if pcfg.weight == 'gim_dkm': + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + if 'encoder.net.fc' in k: + state_dict.pop(k) + elif pcfg.weight == 'gim_lightglue': + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict.pop(k) + if k.startswith('superpoint.'): + state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k) + self.detector.load_state_dict(state_dict) + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('superpoint.'): + state_dict.pop(k) + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + + self.model.load_state_dict(state_dict) + print('Load weights {} success'.format(ncfg['loftr']['weight'])) + + def compute_metrics(self, batch): + compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match + compute_pose_errors(batch, self.tcfg) # compute R_errs, t_errs, pose_errs for each pair + + rel_pair_names = list(zip(batch['scene_id'], *batch['pair_names'])) + bs = batch['image0'].size(0) + metrics = { + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)], + 'R_errs': batch['R_errs'], + 't_errs': batch['t_errs'], + 'inliers': batch['inliers'], + 'covisible0': batch['covisible0'], + 'covisible1': batch['covisible1'], + 'Rot': batch['Rot'], + 'Tns': batch['Tns'], + 'Rot1': batch['Rot1'], + 'Tns1': batch['Tns1'], + 't_errs2': batch['t_errs2'], + } + return metrics + + def inference(self, data): + if self.pcfg.weight == 'gim_dkm': + self.gim_dkm_inference(data) + elif self.pcfg.weight == 'gim_loftr': + self.gim_loftr_inference(data) + elif self.pcfg.weight == 'gim_lightglue': + self.gim_lightglue_inference(data) + elif self.pcfg.weight == 'root_sift': + self.root_sift_inference(data) + + def gim_dkm_inference(self, data): + dense_matches, dense_certainty = self.model.match(data['color0'], data['color1']) + sparse_matches, mconf = self.model.sample(dense_matches, dense_certainty, 5000) + hw0_i = data['color0'].shape[2:] + hw1_i = data['color1'].shape[2:] + height0, width0 = data['imsize0'][0] + height1, width1 = data['imsize1'][0] + kpts0 = sparse_matches[:, :2] + kpts0 = torch.stack((width0 * 
(kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1,) + kpts1 = sparse_matches[:, 2:] + kpts1 = torch.stack((width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1,) + + b_ids = torch.where(mconf[None])[0] + mask = mconf > 0 + + data.update({ + 'hw0_i': hw0_i, + 'hw1_i': hw1_i, + 'mkpts0_f': kpts0[mask], + 'mkpts1_f': kpts1[mask], + 'm_bids': b_ids, + 'mconf': mconf[mask], + }) + + def gim_loftr_inference(self, data): + self.model(data) + + def gim_lightglue_inference(self, data): + hw0_i = data['color0'].shape[2:] + hw1_i = data['color1'].shape[2:] + + pred = {} + pred.update({k+'0': v for k, v in self.detector({ + "image": data["image0"], + "image_size": data["resize0"][:, [1, 0]], + }).items()}) + pred.update({k+'1': v for k, v in self.detector({ + "image": data["image1"], + "image_size": data["resize1"][:, [1, 0]], + }).items()}) + pred.update(self.model({**pred, **data})) + + bs = data['image0'].size(0) + mkpts0_f = torch.cat([kp * s for kp, s in zip(pred['keypoints0'], data['scale0'][:, None])]) + mkpts1_f = torch.cat([kp * s for kp, s in zip(pred['keypoints1'], data['scale1'][:, None])]) + m_bids = torch.nonzero(pred['keypoints0'].sum(dim=2) > -1)[:, 0] + matches = pred['matches'] + mkpts0_f = torch.cat([mkpts0_f[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)]) + mkpts1_f = torch.cat([mkpts1_f[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)]) + m_bids = torch.cat([m_bids[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)]) + mconf = torch.cat(pred['scores']) + + data.update({ + 'hw0_i': hw0_i, + 'hw1_i': hw1_i, + 'mkpts0_f': mkpts0_f, + 'mkpts1_f': mkpts1_f, + 'm_bids': m_bids, + 'mconf': mconf, + }) + + def root_sift_inference(self, data): + # matching two images by sift + image0 = data['color0'].squeeze().permute(1, 2, 0).cpu().numpy() * 255 + image1 = data['color1'].squeeze().permute(1, 2, 0).cpu().numpy() * 255 + + image0 = cv2.cvtColor(image0.astype(np.uint8), cv2.COLOR_RGB2BGR) + image1 = cv2.cvtColor(image1.astype(np.uint8), cv2.COLOR_RGB2BGR) + + H0, W0 = image0.shape[:2] + H1, W1 = image1.shape[:2] + + sift0 = cv2.SIFT_create(nfeatures=H0*W0//64, contrastThreshold=1e-5) + sift1 = cv2.SIFT_create(nfeatures=H1*W1//64, contrastThreshold=1e-5) + + kpts0, desc0 = sift0.detectAndCompute(image0, None) + kpts1, desc1 = sift1.detectAndCompute(image1, None) + kpts0 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts0]) + kpts1 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts1]) + + kpts0, desc0, kpts1, desc1 = map(lambda x: torch.from_numpy(x).cuda().float(), [kpts0, desc0, kpts1, desc1]) + desc0, desc1 = map(lambda x: (x / x.sum(dim=1, keepdim=True)).sqrt(), [desc0, desc1]) + + matches = desc0 @ desc1.transpose(0, 1) + + mask = (matches == matches.max(dim=1, keepdim=True).values) & \ + (matches == matches.max(dim=0, keepdim=True).values) + valid, indices = mask.max(dim=1) + ratio = torch.topk(matches, k=2, dim=1).values + # noinspection PyUnresolvedReferences + ratio = (-2 * ratio + 2).sqrt() + ratio = (ratio[:, 0] / ratio[:, 1]) < 0.8 + valid = valid & ratio + + kpts0 = kpts0[valid] * data['scale0'] + kpts1 = kpts1[indices[valid]] * data['scale1'] + mconf = matches.max(dim=1).values[valid] + + b_ids = torch.where(valid[None])[0] + + data.update({ + 'hw0_i': data['image0'].shape[2:], + 'hw1_i': data['image1'].shape[2:], + 'mkpts0_f': kpts0, + 'mkpts1_f': kpts1, + 'm_bids': b_ids, + 'mconf': mconf, + }) + + def test_step(self, batch, batch_idx): + self.inference(batch) + metrics = self.compute_metrics(batch) + return 
{'Metrics': metrics} + + def test_epoch_end(self, outputs): + + metrics = [o['Metrics'] for o in outputs] + metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in metrics]))) for k in metrics[0]} + + unq_ids = list(OrderedDict((iden, i) for i, iden in enumerate(metrics['identifiers'])).values()) + ord_ids = sorted(unq_ids, key=lambda x:metrics['identifiers'][x]) + metrics = {k:[v[x] for x in ord_ids] for k,v in metrics.items()} + # ['identifiers', 'epi_errs', 'R_errs', 't_errs', 'inliers', + # 'covisible0', 'covisible1', 'Rot', 'Tns', 'Rot1', 'Tns1'] + output = '' + output += 'identifiers covisible0 covisible1 R_errs t_errs t_errs2 ' + output += 'Bef.Prec Bef.Num Aft.Prec Aft.Num\n' + eet = 5e-4 # epi_err_thr + mean = lambda x: sum(x) / max(len(x), 1) + for ids, epi, Rer, Ter, Ter2, inl, co0, co1 in zip( + metrics['identifiers'], metrics['epi_errs'], + metrics['R_errs'], metrics['t_errs'], metrics['t_errs2'], metrics['inliers'], + metrics['covisible0'], metrics['covisible1']): + bef = epi < eet + aft = epi[inl] < eet + output += f'{ids} {co0} {co1} {Rer} {Ter} {Ter2} ' + output += f'{mean(bef)} {sum(bef)} {mean(aft)} {sum(aft)}\n' + + scene = Path(self.hparams['dcfg'][self.pcfg["tests"]]['DATASET']['TESTS']['LIST_PATH']).stem.split('_')[0] + path = f"dump/zeb/[T] {self.pcfg.weight} {scene:>15} {self.pcfg.version}.txt" + with open(path, 'w') as file: + file.write(output) diff --git a/imcui/third_party/gim/video_preprocessor.py b/imcui/third_party/gim/video_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..b2749f350fad8ab6f16016fb09e02dd12f1f849e --- /dev/null +++ b/imcui/third_party/gim/video_preprocessor.py @@ -0,0 +1,751 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun +import os + +import cv2 +import csv +import math +import torch +import scipy.io +import warnings +import argparse +import numpy as np + +from os import mkdir +from tqdm import tqdm +from copy import deepcopy +from os.path import join, exists +from torch.utils.data import DataLoader + +from datasets.walk.video_streamer import VideoStreamer +from datasets.walk.video_loader import WALKDataset, collate_fn + +from networks.mit_semseg.models import ModelBuilder, SegmentationModule + +gray2tensor = lambda x: (torch.from_numpy(x).float() / 255)[None, None] +color2tensor = lambda x: (torch.from_numpy(x).float() / 255).permute(2, 0, 1)[None] + +warnings.simplefilter("ignore", category=UserWarning) + +methods = {'SIFT', 'GIM_GLUE', 'GIM_LOFTR', 'GIM_DKM'} + +PALETTE = scipy.io.loadmat('weights/color150.mat')['colors'] + +CLS_DICT = {} # {'person': 13, 'sky': 3} +with open('weights/object150_info.csv') as f: + reader = csv.reader(f) + next(reader) + for row in reader: + name = row[5].split(";")[0] + if name == 'screen': + name = '_'.join(row[5].split(";")[:2]) + CLS_DICT[name] = int(row[0]) - 1 + +exclude = ['person', 'sky', 'car'] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true') + parser.add_argument("--gpu", type=int, + default=0, help='-1 for CPU') + parser.add_argument("--range", type=int, nargs='+', + default=None, + help='Video Range for seconds') + parser.add_argument('--scene_name', type=str, + default=None, + help='Scene (video) name') + parser.add_argument('--method', type=str, choices=methods, + required=True, + help='Method name') + parser.add_argument('--resize', action='store_true', + help='whether resize') + parser.add_argument('--skip', type=int, + required=True, + help='Video skip frame: 1, 2, 3, ...') + 
parser.add_argument('--watermarker', type=int, nargs='+', + default=None, + help='Watermarker Rectangle Range') + opt = parser.parse_args() + + data_root = join('data', 'ZeroMatch') + video_name = opt.scene_name.strip() + video_path = join(data_root, 'video_1080p', video_name + '.mp4') + + # get real size of video + vcap = cv2.VideoCapture(video_path) + vwidth = vcap.get(3) # float `width` + vheight = vcap.get(4) # float `height` + fps = vcap.get(5) # float `fps` + end_range = math.floor(vcap.get(cv2.CAP_PROP_FRAME_COUNT) / fps - 300) + vcap.release() + + fps = math.ceil(fps) + opt.range = [300, end_range] if opt.range is None else opt.range + opt.range = [0, -1] if video_name == 'Od-rKbC30TM' else opt.range # for demo + + if fps <= 30: + skip = [10, 20, 40][opt.skip] + else: + skip = [20, 40, 80][opt.skip] + + dump_dir = join(data_root, 'pseudo', + 'WALK ' + opt.method + + ' [R] ' + '{}'.format('T' if opt.resize else 'F') + + ' [S] ' + '{:2}'.format(skip)) + if not exists(dump_dir): mkdir(dump_dir) + debug_dir = join('dump', video_name + ' ' + opt.method) + if opt.resize: debug_dir = debug_dir + ' Resize' + if opt.debug and (not exists(debug_dir)): mkdir(debug_dir) + + # start process video + gap = 10 if fps <= 30 else 20 + vs = VideoStreamer(basedir=video_path, resize=opt.resize, df=8, skip=gap, vrange=opt.range) + + # read the first frame + rgb = vs[vs.listing[0]] + width, height = rgb.shape[1], rgb.shape[0] + + # calculate ratio + vratio = np.array([vwidth / width, vheight / height])[None] + + # set dump name + scene_name = f'{video_name} ' + scene_name += f'WH {width:4} {height:4} ' + scene_name += f'RG {vs.range[0]:4} {vs.range[1]:4} ' + scene_name += f'SP {skip} ' + scene_name += f'{len(video_name)}' + + save_dir = join(dump_dir, scene_name) + + device = torch.device('cuda:{}'.format(opt.gpu)) if opt.gpu >= 0 else torch.device('cpu') + + # initialize segmentation model + net_encoder = ModelBuilder.build_encoder( + arch='resnet50dilated', + fc_dim=2048, + weights='weights/encoder_epoch_20.pth') + net_decoder = ModelBuilder.build_decoder( + arch='ppm_deepsup', + fc_dim=2048, + num_class=150, + weights='weights/decoder_epoch_20.pth', + use_softmax=True) + crit = torch.nn.NLLLoss(ignore_index=-1) + segmentation_module = SegmentationModule(net_encoder, net_decoder, crit).to(device).eval() + old_segment_root = join(data_root, 'segment', opt.scene_name) + new_segment_root = join(data_root, 'segment', opt.scene_name.strip()) + if not os.path.exists(new_segment_root): + if os.path.exists(old_segment_root): + os.rename(old_segment_root, new_segment_root) + else: + os.makedirs(new_segment_root, exist_ok=True) + segment_root = new_segment_root + + model, detectAndCompute = None, None + + if opt.method == 'SIFT': + model = cv2.SIFT_create(nfeatures=32400, contrastThreshold=1e-5) + detectAndCompute = model.detectAndCompute + + elif opt.method == 'GIM_DKM': + from networks.dkm.models.model_zoo.DKMv3 import DKMv3 + model = DKMv3(weights=None, h=672, w=896) + checkpoints_path = join('weights', 'gim_dkm_100h.ckpt') + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + if 'encoder.net.fc' in k: + state_dict.pop(k) + model.load_state_dict(state_dict) + model = model.eval().to(device) + + elif opt.method == 'GIM_LOFTR': + from networks.loftr.loftr import LoFTR + from networks.loftr.misc 
import lower_config + from networks.loftr.config import get_cfg_defaults + + cfg = get_cfg_defaults() + cfg.TEMP_BUG_FIX = True + cfg.LOFTR.WEIGHT = 'weights/gim_loftr_50h.ckpt' + cfg.LOFTR.FINE_CONCAT_COARSE_FEAT = False + cfg = lower_config(cfg) + model = LoFTR(cfg['loftr']) + model = model.to(device) + model = model.eval() + + elif opt.method == 'GIM_GLUE': + from networks.lightglue.matching import Matching + + model = Matching() + + checkpoints_path = join('weights', 'gim_lightglue_100h.ckpt') + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict.pop(k) + if k.startswith('superpoint.'): + state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k) + model.detector.load_state_dict(state_dict) + + state_dict = torch.load(checkpoints_path, map_location='cpu') + if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict'] + for k in list(state_dict.keys()): + if k.startswith('superpoint.'): + state_dict.pop(k) + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + model.model.load_state_dict(state_dict) + + model = model.to(device) + model = model.eval() + + cache_dir = None + if opt.resize: + cache_dir = join(data_root, 'pseudo', + 'WALK ' + 'GIM_DKM' + + ' [R] F' + + ' [S] ' + '{:2}'.format(skip), + scene_name) + + _w_ = width if opt.method == 'SIFT' or opt.method == 'GLUE' else 1600 # TODO: confirm DKM + _h_ = height if opt.method == 'SIFT' or opt.method == 'GLUE' else 900 # TODO: confirm DKM + + ids = list(zip(vs.listing[:-skip // gap], vs.listing[skip // gap:])) + + # start matching and make pseudo labels + nums = None + idxs = None + checkpoint = 0 + if not opt.debug: + if exists(join(save_dir, 'nums.npy')) and exists(join(save_dir, 'idxs.npy')): + with open(join(save_dir, 'nums.npy'), 'rb') as f: + nums = np.load(f) + with open(join(save_dir, 'idxs.npy'), 'rb') as f: + idxs = np.load(f) + assert len(nums) == len(idxs) == (len(os.listdir(save_dir)) - 2) + whole = [str(x) + '.npy' for x in np.array(ids)] + cache = [str(x) + '.npy' for x in idxs] + leave = list(set(whole) - set(cache)) + if len(leave): + leave = list(map(lambda x: int(x.rsplit('[')[-1].strip().split()[0]), leave)) + skip_id = np.array(sorted(leave)) + skip_id = (skip_id[1:] - skip_id[:-1]) // gap + len_id = len(skip_id) + if len_id == 0: exit(0) + skip_id = [i for i in range(len_id) if skip_id[i:].sum() == (len_id - i)] + if len(skip_id) == 0: exit(0) + skip_id = skip_id[0] + checkpoint = np.where(np.array(ids)[:, 0]==sorted(leave)[skip_id])[0][0] + if len(nums) + skip_id > checkpoint: exit(0) + assert checkpoint == len(nums) + skip_id + else: + exit(0) + else: + if not exists(save_dir): mkdir(save_dir) + nums = np.array([]) + idxs = np.array([]) + datasets = WALKDataset(data_root, vs=vs, ids=ids, checkpoint=checkpoint, opt=opt) + loader_params = {'batch_size': 1, 'shuffle': False, 'num_workers': 5, + 'pin_memory': True, 'drop_last': False} + loader = DataLoader(datasets, collate_fn=collate_fn, **loader_params) + for i, batch in enumerate(tqdm(loader, ncols=120, bar_format="{l_bar}{bar:3}{r_bar}", + desc='{:11} - [{:5}, {:2}{}]'.format(video_name[:40], opt.method, skip, '*' if opt.resize else ''), + total=len(loader), leave=False)): + idx = batch['idx'].item() + assert i == idx + idx0 = batch['idx0'].item() + idx1 = batch['idx1'].item() + assert idx0 == ids[idx+checkpoint][0] and idx1 == 
ids[idx+checkpoint][1] + + # cache loaded image + if not batch['rgb0_is_good'].item(): + img_path0 = batch['img_path0'][0] + if not os.path.exists(img_path0): + cv2.imwrite(img_path0, batch['rgb0'].squeeze(0).numpy()) + if not batch['rgb1_is_good'].item(): + img_path1 = batch['img_path1'][0] + if not os.path.exists(img_path1): + cv2.imwrite(img_path1, batch['rgb1'].squeeze(0).numpy()) + + current_id = np.array([idx0, idx1]) + save_name = '{}.npy'.format(str(current_id)) + save_path = join(save_dir, save_name) + if exists(save_path) and not opt.debug: continue + + rgb0 = batch['rgb0'].squeeze(0).numpy() + rgb1 = batch['rgb1'].squeeze(0).numpy() + _rgb0_, _rgb1_ = deepcopy(rgb0), deepcopy(rgb1) + + # get correspondeces in unresize image + pt0, pt1 = None, None + if opt.resize: + cache_path = join(cache_dir, save_name) + if not exists(cache_path): continue + with open(cache_path, 'rb') as f: + pts = np.load(f) + pt0, pt1 = pts[:, :2], pts[:, 2:] + + # process first frame image + xA0, xA1, yA0, yA1, hA, wA, wA_new, hA_new = None, None, None, None, None, None, None, None + if opt.resize: + # crop rgb0 + xA0 = math.floor(pt0[:, 0].min()) + xA1 = math.ceil(pt0[:, 0].max()) + yA0 = math.floor(pt0[:, 1].min()) + yA1 = math.ceil(pt0[:, 1].max()) + rgb0 = rgb0[yA0:yA1, xA0:xA1] + hA, wA = rgb0.shape[:2] + wA_new, hA_new = get_resized_wh(wA, hA, [_h_, _w_]) + wA_new, hA_new = get_divisible_wh(wA_new, hA_new, 8) + rgb0 = cv2.resize(rgb0, (wA_new, hA_new), interpolation=cv2.INTER_AREA) + + # go on + gray0 = cv2.cvtColor(rgb0, cv2.COLOR_RGB2GRAY) + # semantic segmentation + with torch.no_grad(): + seg_path0 = join(segment_root, '{}.npy'.format(idx0)) + if not os.path.exists(seg_path0): + mask0 = segment(_rgb0_, device, segmentation_module) + np.save(seg_path0, mask0) + else: + mask0 = np.load(seg_path0) + + # process next frame image + xB0, xB1, yB0, yB1, hB, wB, wB_new, hB_new = None, None, None, None, None, None, None, None + if opt.resize: + # crop rgb1 + xB0 = math.floor(pt1[:, 0].min()) + xB1 = math.ceil(pt1[:, 0].max()) + yB0 = math.floor(pt1[:, 1].min()) + yB1 = math.ceil(pt1[:, 1].max()) + rgb1 = rgb1[yB0:yB1, xB0:xB1] + hB, wB = rgb1.shape[:2] + wB_new, hB_new = get_resized_wh(wB, hB, [_h_, _w_]) + wB_new, hB_new = get_divisible_wh(wB_new, hB_new, 8) + rgb1 = cv2.resize(rgb1, (wB_new, hB_new), interpolation=cv2.INTER_AREA) + + # go on + gray1 = cv2.cvtColor(rgb1, cv2.COLOR_RGB2GRAY) + # semantic segmentation + with torch.no_grad(): + seg_path1 = join(segment_root, '{}.npy'.format(idx1)) + if not os.path.exists(seg_path1): + mask1 = segment(_rgb1_, device, segmentation_module) + np.save(seg_path1, mask1) + else: + mask1 = np.load(seg_path1) + + if mask0.shape[:2] != _rgb0_.shape[:2]: + mask0 = cv2.resize(mask0, _rgb0_.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) + + if mask1.shape != _rgb1_.shape[:2]: + mask1 = cv2.resize(mask1, _rgb1_.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) + + if opt.resize: + # resize mask0 + mask0 = mask0[yA0:yA1, xA0:xA1] + mask0 = cv2.resize(mask0, (wA_new, hA_new), interpolation=cv2.INTER_NEAREST) + # resize mask1 + mask1 = mask1[yB0:yB1, xB0:xB1] + mask1 = cv2.resize(mask1, (wB_new, hB_new), interpolation=cv2.INTER_NEAREST) + + data = None + if opt.method == 'SIFT': + + mask_0 = mask0 != CLS_DICT[exclude[0]] + mask_1 = mask1 != CLS_DICT[exclude[0]] + for cls in exclude[1:]: + mask_0 = mask_0 & (mask0 != CLS_DICT[cls]) + mask_1 = mask_1 & (mask1 != CLS_DICT[cls]) + mask_0 = mask_0.astype(np.uint8) + mask_1 = mask_1.astype(np.uint8) + + if mask_0.sum() == 0 
or mask_1.sum() == 0: continue + + # keypoint detection and description + kpts0, desc0 = detectAndCompute(rgb0, mask_0) + if desc0 is None or desc0.shape[0] < 8: continue + kpts0 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts0]) + kpts0, desc0 = map(lambda x: torch.from_numpy(x).to(device).float(), [kpts0, desc0]) + desc0 = (desc0 / desc0.sum(dim=1, keepdim=True)).sqrt() + + # keypoint detection and description + kpts1, desc1 = detectAndCompute(rgb1, mask_1) + if desc1 is None or desc1.shape[0] < 8: continue + kpts1 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts1]) + kpts1, desc1 = map(lambda x: torch.from_numpy(x).to(device).float(), [kpts1, desc1]) + desc1 = (desc1 / desc1.sum(dim=1, keepdim=True)).sqrt() + + # mutual nearest matching and ratio filter + matches = desc0 @ desc1.transpose(0, 1) + mask = (matches == matches.max(dim=1, keepdim=True).values) & \ + (matches == matches.max(dim=0, keepdim=True).values) + # noinspection PyUnresolvedReferences + valid, indices = mask.max(dim=1) + ratio = torch.topk(matches, k=2, dim=1).values + ratio = (-2 * ratio + 2).sqrt() + # ratio = (ratio[:, 0] / ratio[:, 1]) < opt.mt + ratio = (ratio[:, 0] / ratio[:, 1]) < 0.8 + valid = valid & ratio + + # get matched keypoints + mkpts0 = kpts0[valid] + mkpts1 = kpts1[indices[valid]] + b_ids = torch.where(valid[None])[0] + + data = dict( + m_bids = b_ids, + mkpts0_f = mkpts0, + mkpts1_f = mkpts1, + ) + + elif opt.method == 'GIM_DKM': + + mask_0 = mask0 != CLS_DICT[exclude[0]] + mask_1 = mask1 != CLS_DICT[exclude[0]] + for cls in exclude[1:]: + mask_0 = mask_0 & (mask0 != CLS_DICT[cls]) + mask_1 = mask_1 & (mask1 != CLS_DICT[cls]) + mask_0 = mask_0.astype(np.uint8) + mask_1 = mask_1.astype(np.uint8) + + if mask_0.sum() == 0 or mask_1.sum() == 0: continue + + img0 = rgb0 * mask_0[..., None] + img1 = rgb1 * mask_1[..., None] + + width0, height0 = img0.shape[1], img0.shape[0] + width1, height1 = img1.shape[1], img1.shape[0] + + with torch.no_grad(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + img0 = torch.from_numpy(img0).permute(2, 0, 1).to(device)[None] / 255 + img1 = torch.from_numpy(img1).permute(2, 0, 1).to(device)[None] / 255 + dense_matches, dense_certainty = model.match(img0, img1) + sparse_matches, mconf = model.sample(dense_matches, dense_certainty, 5000) + mkpts0 = sparse_matches[:, :2] + mkpts0 = torch.stack((width0 * (mkpts0[:, 0] + 1) / 2, + height0 * (mkpts0[:, 1] + 1) / 2), dim=-1) + mkpts1 = sparse_matches[:, 2:] + mkpts1 = torch.stack((width1 * (mkpts1[:, 0] + 1) / 2, + height1 * (mkpts1[:, 1] + 1) / 2), dim=-1) + m_bids = torch.zeros(sparse_matches.shape[0], dtype=torch.long, device=device) + + data = dict( + m_bids = m_bids, + mkpts0_f = mkpts0, + mkpts1_f = mkpts1, + ) + + elif opt.method == 'GIM_LOFTR': + + mask_0 = mask0 != CLS_DICT[exclude[0]] + mask_1 = mask1 != CLS_DICT[exclude[0]] + for cls in exclude[1:]: + mask_0 = mask_0 & (mask0 != CLS_DICT[cls]) + mask_1 = mask_1 & (mask1 != CLS_DICT[cls]) + mask_0 = mask_0.astype(np.uint8) + mask_1 = mask_1.astype(np.uint8) + + if mask_0.sum() == 0 or mask_1.sum() == 0: continue + + mask_0 = cv2.resize(mask_0, None, fx=1/8, fy=1/8, interpolation=cv2.INTER_NEAREST) + mask_1 = cv2.resize(mask_1, None, fx=1/8, fy=1/8, interpolation=cv2.INTER_NEAREST) + + data = dict( + image0=gray2tensor(gray0), + image1=gray2tensor(gray1), + color0=color2tensor(rgb0), + color1=color2tensor(rgb1), + mask0=torch.from_numpy(mask_0)[None], + mask1=torch.from_numpy(mask_1)[None], + ) + + with torch.no_grad(): + data = {k: v.to(device) if 
isinstance(v, torch.Tensor) else v for k, v + in data.items()} + model(data) + + elif opt.method == 'GIM_GLUE': + + mask_0 = mask0 != CLS_DICT[exclude[0]] + mask_1 = mask1 != CLS_DICT[exclude[0]] + for cls in exclude[1:]: + mask_0 = mask_0 & (mask0 != CLS_DICT[cls]) + mask_1 = mask_1 & (mask1 != CLS_DICT[cls]) + mask_0 = mask_0.astype(np.uint8) + mask_1 = mask_1.astype(np.uint8) + + if mask_0.sum() == 0 or mask_1.sum() == 0: continue + + size0 = torch.tensor(gray0.shape[-2:][::-1])[None] + size1 = torch.tensor(gray1.shape[-2:][::-1])[None] + data = dict( + gray0 = gray2tensor(gray0 * mask_0), + gray1 = gray2tensor(gray1 * mask_1), + size0 = size0, + size1 = size1, + ) + + with torch.no_grad(): + data = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v + in data.items()} + pred = model(data) + kpts0, kpts1 = pred['keypoints0'][0], pred['keypoints1'][0] + matches = pred['matches'][0] + if len(matches) == 0: continue + + mkpts0 = kpts0[matches[..., 0]] + mkpts1 = kpts1[matches[..., 1]] + m_bids = torch.zeros(matches[..., 0].size(), dtype=torch.long, device=device) + + data = dict( + m_bids = m_bids, + mkpts0_f = mkpts0, + mkpts1_f = mkpts1, + ) + + # auto remove watermarker + kpts0 = data['mkpts0_f'].clone() # (N, 2) + kpts1 = data['mkpts1_f'].clone() # (N, 2) + moved = ~((kpts0 - kpts1).abs() < 1).min(dim=1).values # (N) + data['m_bids'] = data['m_bids'][moved] + data['mkpts0_f'] = data['mkpts0_f'][moved] + data['mkpts1_f'] = data['mkpts1_f'][moved] + + robust_fitting(data) + if (data['inliers'] is None) or (sum(data['inliers'][0]) == 0): continue + + inliers = data['inliers'][0] + + if opt.debug: + data.update(dict( + # for debug visualization + mask0 = mask0, + mask1 = mask1, + gray0 = gray0, + gray1 = gray1, + color0 = rgb0, + color1 = rgb1, + hw0_i = rgb0.shape[:2], + hw1_i = rgb1.shape[:2], + dataset_name = ['WALK'], + scene_id = [video_name], + pair_id = [[idx0, idx1]], + imsize0=[[width, height]], + imsize1=[[width, height]], + )) + out = fast_make_matching_robust_fitting_figure(data) + cv2.imwrite(join(debug_dir, '{} {:8d} {:8d}.png'.format(scene_name, idx0, idx1)), + cv2.cvtColor(out, cv2.COLOR_RGB2BGR)) + continue + + if opt.resize: + mkpts0_f = (data['mkpts0_f'].cpu().numpy()[inliers] * np.array([[wA/wA_new, hA/hA_new]]) + np.array([[xA0, yA0]])) * vratio + mkpts1_f = (data['mkpts1_f'].cpu().numpy()[inliers] * np.array([[wB/wB_new, hB/hB_new]]) + np.array([[xB0, yB0]])) * vratio + else: + mkpts0_f = data['mkpts0_f'].cpu().numpy()[inliers] * vratio + mkpts1_f = data['mkpts1_f'].cpu().numpy()[inliers] * vratio + + pts = np.concatenate([mkpts0_f, mkpts1_f], axis=1).astype(np.float32) + nums = np.concatenate([nums, np.array([len(pts)])], axis=0) if len(nums) else np.array([len(pts)]) + idxs = np.concatenate([idxs, current_id[None]], axis=0) if len(idxs) else current_id[None] + + with open(save_path, 'wb') as f: + np.save(f, pts) + + with open(join(save_dir, 'nums.npy'), 'wb') as f: + np.save(f, nums) + + with open(join(save_dir, 'idxs.npy'), 'wb') as f: + np.save(f, idxs) + + +def robust_fitting(data, b_id=0): + m_bids = data['m_bids'].cpu().numpy() + kpts0 = data['mkpts0_f'].cpu().numpy() + kpts1 = data['mkpts1_f'].cpu().numpy() + + mask = m_bids == b_id + + # noinspection PyBroadException + try: + _, mask = cv2.findFundamentalMat(kpts0[mask], kpts1[mask], cv2.USAC_MAGSAC, ransacReprojThreshold=0.5, confidence=0.999999, maxIters=100000) + mask = (mask.ravel() > 0)[None] + except: + mask = None + + data.update(dict(inliers=mask)) + + +def get_resized_wh(w, h, resize): + 
nh, nw = resize + sh, sw = nh / h, nw / w + scale = min(sh, sw) + w_new, h_new = int(round(w*scale)), int(round(h*scale)) + return w_new, h_new + + +def get_divisible_wh(w, h, df=None): + if df is not None: + w_new = max((w // df), 1) * df + h_new = max((h // df), 1) * df + else: + w_new, h_new = w, h + return w_new, h_new + + +def read_deeplab_image(img, size=1920): + width, height = img.shape[1], img.shape[0] + + if max(width, height) > size: + if width > height: + img = cv2.resize(img, (size, int(size * height / width)), interpolation=cv2.INTER_AREA) + else: + img = cv2.resize(img, (int(size * width / height), size), interpolation=cv2.INTER_AREA) + + img = (torch.from_numpy(img).float() / 255).permute(2, 0, 1)[None] + + return img + + +def read_segmentation_image(img): + img = read_deeplab_image(img, size=720)[0] + img = img - torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) + img = img / torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) + return img + + +def segment(rgb, device, segmentation_module): + img_data = read_segmentation_image(rgb) + singleton_batch = {'img_data': img_data[None].to(device)} + output_size = img_data.shape[1:] + # Run the segmentation at the highest resolution. + scores = segmentation_module(singleton_batch, segSize=output_size) + # Get the predicted scores for each pixel + _, pred = torch.max(scores, dim=1) + return pred.cpu()[0].numpy().astype(np.uint8) + + +def getLabel(pair, idxs, nums, h5py_i, h5py_f): + """ + Args: + pair: [6965 6970] + idxs: (N, 2) + nums: (N,) + h5py_i: (M, 2) + h5py_f: (M, 2) + + Returns: pseudo_label (N, 4) + """ + i, j = np.where(idxs == pair) + if len(i) == 0: return None + assert (len(i) == len(j) == 2) and (i[0] == i[1]) and (j[0] == 0) and (j[1] == 1) + i = i[0] + nums = nums[:i+1] + idx0, idx1 = sum(nums[:-1]), sum(nums) + + mkpts0 = h5py_i[idx0:idx1] + mkpts1 = h5py_f[idx0:idx1] # (N, 2) + + return mkpts0, mkpts1 + + +def fast_make_matching_robust_fitting_figure(data, b_id=0): + b_mask = data['m_bids'] == b_id + + gray0 = data['gray0'] + gray1 = data['gray1'] + kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() + kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() + + margin = 2 + (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i'] + h, w = max(h0, h1), max(w0, w1) + H, W = margin * 5 + h * 4, margin * 3 + w * 2 + + # canvas + out = 255 * np.ones((H, W), np.uint8) + + wx = [margin, margin + w0, margin + w + margin, margin + w + margin + w1] + hx = lambda row: margin * row + h * (row-1) + out = np.stack([out] * 3, -1) + + sh = hx(row=1) + color0 = data['color0'] # (rH, rW, 3) + color1 = data['color1'] # (rH, rW, 3) + out[sh: sh + h0, wx[0]: wx[1]] = color0 + out[sh: sh + h1, wx[2]: wx[3]] = color1 + + sh = hx(row=2) + img0 = np.stack([gray0] * 3, -1) * 0 + for cls in exclude: img0[data['mask0'] == CLS_DICT[cls]] = PALETTE[CLS_DICT[cls]] + out[sh: sh + h0, wx[0]: wx[1]] = img0 + img1 = np.stack([gray1] * 3, -1) * 0 + for cls in exclude: img1[data['mask1'] == CLS_DICT[cls]] = PALETTE[CLS_DICT[cls]] + out[sh: sh + h1, wx[2]: wx[3]] = img1 + + # before outlier filtering + sh = hx(row=3) + mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int) + out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1) + out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1) + for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1): + # display line end-points as circles + c = (230, 216, 132) + cv2.circle(out, (x0, y0+sh), 3, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + w, y1+sh), 3, c, -1, lineType=cv2.LINE_AA) + + # after 
outlier filtering + if data['inliers'] is not None: + sh = hx(row=4) + inliers = data['inliers'][b_id] + mkpts0, mkpts1 = np.round(kpts0).astype(int)[inliers], np.round(kpts1).astype(int)[inliers] + out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1) + out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1) + for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1): + # display line end-points as circles + c = (230, 216, 132) + cv2.circle(out, (x0, y0+sh), 3, c, -1, lineType=cv2.LINE_AA) + cv2.circle(out, (x1 + margin + w, y1+sh), 3, c, -1, lineType=cv2.LINE_AA) + + # Big text. + text = [ + f' ', + f'#Matches {len(kpts0)}', + f'#Matches {sum(data["inliers"][b_id]) if data["inliers"] is not None else 0}', + ] + sc = min(H / 640., 1.0) + Ht = int(30 * sc) # text height + txt_color_fg = (255, 255, 255) # white + txt_color_bg = (0, 0, 0) # black + for i, t in enumerate(text): + cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX, 1.0 * sc, txt_color_bg, 2, cv2.LINE_AA) + cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX, 1.0 * sc, txt_color_fg, 1, cv2.LINE_AA) + + fingerprint = [ + 'Dataset: {}'.format(data['dataset_name'][b_id]), + 'Scene ID: {}'.format(data['scene_id'][b_id]), + 'Pair ID: {}'.format(data['pair_id'][b_id]), + 'Image sizes: {} - {}'.format(data['imsize0'][b_id], + data['imsize1'][b_id]), + ] + sc = min(H / 640., 1.0) + Ht = int(18 * sc) # text height + txt_color_fg = (255, 255, 255) # white + txt_color_bg = (0, 0, 0) # black + for i, t in enumerate(reversed(fingerprint)): + cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX, .5 * sc, txt_color_bg, 2, cv2.LINE_AA) + cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX, .5 * sc, txt_color_fg, 1, cv2.LINE_AA) + + return out + + +if __name__ == '__main__': + with torch.no_grad(): + main() diff --git a/imcui/third_party/lanet/__init__.py b/imcui/third_party/lanet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/lanet/augmentations.py b/imcui/third_party/lanet/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..ab8a551f4b0979e714b54818a74a4e49fe07b966 --- /dev/null +++ b/imcui/third_party/lanet/augmentations.py @@ -0,0 +1,342 @@ +# From https://github.com/TRI-ML/KP2D. + +# Copyright 2020 Toyota Research Institute. All rights reserved. + +import random +from math import pi + +import cv2 +import numpy as np +import torch +import torchvision +import torchvision.transforms as transforms +from PIL import Image + +from utils import image_grid + + +def filter_dict(dict, keywords): + """ + Returns only the keywords that are part of a dictionary + + Parameters + ---------- + dictionary : dict + Dictionary for filtering + keywords : list of str + Keywords that will be filtered + + Returns + ------- + keywords : list of str + List containing the keywords that are keys in dictionary + """ + return [key for key in keywords if key in dict] + + +def resize_sample(sample, image_shape, image_interpolation=Image.ANTIALIAS): + """ + Resizes a sample, which contains an input image. 
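+    A minimal usage sketch (hypothetical file name; assumes the sample dict
+    carries a PIL image under the 'image' key, as in data_loader.GetData):
+
+        sample = {'image': Image.open('frame.png')}
+        sample = resize_sample(sample, image_shape=(240, 320))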
+ + Parameters + ---------- + sample : dict + Dictionary with sample values (output from a dataset's __getitem__ method) + shape : tuple (H,W) + Output shape + image_interpolation : int + Interpolation mode + + Returns + ------- + sample : dict + Resized sample + """ + # image + image_transform = transforms.Resize(image_shape, interpolation=image_interpolation) + sample['image'] = image_transform(sample['image']) + return sample + +def spatial_augment_sample(sample): + """ Apply spatial augmentation to an image (flipping and random affine transformation).""" + augment_image = transforms.Compose([ + transforms.RandomVerticalFlip(p=0.5), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)) + + ]) + sample['image'] = augment_image(sample['image']) + + return sample + +def unnormalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): + """ Counterpart method of torchvision.transforms.Normalize.""" + for t, m, s in zip(tensor, mean, std): + t.div_(1 / s).sub_(-m) + return tensor + + +def sample_homography( + shape, perspective=True, scaling=True, rotation=True, translation=True, + n_scales=100, n_angles=100, scaling_amplitude=0.1, perspective_amplitude=0.4, + patch_ratio=0.8, max_angle=pi/4): + """ Sample a random homography that includes perspective, scale, translation and rotation operations.""" + + width = float(shape[1]) + hw_ratio = float(shape[0]) / float(shape[1]) + + pts1 = np.stack([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]], axis=0) + pts2 = pts1.copy() * patch_ratio + pts2[:,1] *= hw_ratio + + if perspective: + + perspective_amplitude_x = np.random.normal(0., perspective_amplitude/2, (2)) + perspective_amplitude_y = np.random.normal(0., hw_ratio * perspective_amplitude/2, (2)) + + perspective_amplitude_x = np.clip(perspective_amplitude_x, -perspective_amplitude/2, perspective_amplitude/2) + perspective_amplitude_y = np.clip(perspective_amplitude_y, hw_ratio * -perspective_amplitude/2, hw_ratio * perspective_amplitude/2) + + pts2[0,0] -= perspective_amplitude_x[1] + pts2[0,1] -= perspective_amplitude_y[1] + + pts2[1,0] -= perspective_amplitude_x[0] + pts2[1,1] += perspective_amplitude_y[1] + + pts2[2,0] += perspective_amplitude_x[1] + pts2[2,1] -= perspective_amplitude_y[0] + + pts2[3,0] += perspective_amplitude_x[0] + pts2[3,1] += perspective_amplitude_y[0] + + if scaling: + + random_scales = np.random.normal(1, scaling_amplitude/2, (n_scales)) + random_scales = np.clip(random_scales, 1-scaling_amplitude/2, 1+scaling_amplitude/2) + + scales = np.concatenate([[1.], random_scales], 0) + center = np.mean(pts2, axis=0, keepdims=True) + scaled = np.expand_dims(pts2 - center, axis=0) * np.expand_dims( + np.expand_dims(scales, 1), 1) + center + valid = np.arange(n_scales) # all scales are valid except scale=1 + idx = valid[np.random.randint(valid.shape[0])] + pts2 = scaled[idx] + + if translation: + t_min, t_max = np.min(pts2 - [-1., -hw_ratio], axis=0), np.min([1., hw_ratio] - pts2, axis=0) + pts2 += np.expand_dims(np.stack([np.random.uniform(-t_min[0], t_max[0]), + np.random.uniform(-t_min[1], t_max[1])]), + axis=0) + + if rotation: + angles = np.linspace(-max_angle, max_angle, n_angles) + angles = np.concatenate([[0.], angles], axis=0) + + center = np.mean(pts2, axis=0, keepdims=True) + rot_mat = np.reshape(np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), + np.cos(angles)], axis=1), [-1, 2, 2]) + rotated = np.matmul( + np.tile(np.expand_dims(pts2 - center, axis=0), [n_angles+1, 1, 1]), + rot_mat) + center + + 
valid = np.where(np.all((rotated >= [-1.,-hw_ratio]) & (rotated < [1.,hw_ratio]), + axis=(1, 2)))[0] + + idx = valid[np.random.randint(valid.shape[0])] + pts2 = rotated[idx] + + pts2[:,1] /= hw_ratio + + def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] + def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]] + + a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0) + p_mat = np.transpose(np.stack( + [[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)) + + homography = np.matmul(np.linalg.pinv(a_mat), p_mat).squeeze() + homography = np.concatenate([homography, [1.]]).reshape(3,3) + return homography + +def warp_homography(sources, homography): + """Warp features given a homography + + Parameters + ---------- + sources: torch.tensor (1,H,W,2) + Keypoint vector. + homography: torch.Tensor (3,3) + Homography. + + Returns + ------- + warped_sources: torch.tensor (1,H,W,2) + Warped feature vector. + """ + _, H, W, _ = sources.shape + warped_sources = sources.clone().squeeze() + warped_sources = warped_sources.view(-1,2) + warped_sources = torch.addmm(homography[:,2], warped_sources, homography[:,:2].t()) + warped_sources.mul_(1/warped_sources[:,2].unsqueeze(1)) + warped_sources = warped_sources[:,:2].contiguous().view(1,H,W,2) + return warped_sources + +def add_noise(img, mode="gaussian", percent=0.02): + """Add image noise + + Parameters + ---------- + image : np.array + Input image + mode: str + Type of noise, from ['gaussian','salt','pepper','s&p'] + percent: float + Percentage image points to add noise to. + Returns + ------- + image : np.array + Image plus noise. + """ + original_dtype = img.dtype + if mode == "gaussian": + mean = 0 + var = 0.1 + sigma = var * 0.5 + + if img.ndim == 2: + h, w = img.shape + gauss = np.random.normal(mean, sigma, (h, w)) + else: + h, w, c = img.shape + gauss = np.random.normal(mean, sigma, (h, w, c)) + + if img.dtype not in [np.float32, np.float64]: + gauss = gauss * np.iinfo(img.dtype).max + img = np.clip(img.astype(np.float) + gauss, 0, np.iinfo(img.dtype).max) + else: + img = np.clip(img.astype(np.float) + gauss, 0, 1) + + elif mode == "salt": + print(img.dtype) + s_vs_p = 1 + num_salt = np.ceil(percent * img.size * s_vs_p) + coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape]) + + if img.dtype in [np.float32, np.float64]: + img[coords] = 1 + else: + img[coords] = np.iinfo(img.dtype).max + print(img.dtype) + elif mode == "pepper": + s_vs_p = 0 + num_pepper = np.ceil(percent * img.size * (1.0 - s_vs_p)) + coords = tuple( + [np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape] + ) + img[coords] = 0 + + elif mode == "s&p": + s_vs_p = 0.5 + + # Salt mode + num_salt = np.ceil(percent * img.size * s_vs_p) + coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape]) + if img.dtype in [np.float32, np.float64]: + img[coords] = 1 + else: + img[coords] = np.iinfo(img.dtype).max + + # Pepper mode + num_pepper = np.ceil(percent * img.size * (1.0 - s_vs_p)) + coords = tuple( + [np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape] + ) + img[coords] = 0 + else: + raise ValueError("not support mode for {}".format(mode)) + + noisy = img.astype(original_dtype) + return noisy + + +def non_spatial_augmentation(img_warp_ori, jitter_paramters, color_order=[0,1,2], to_gray=False): + """ Apply non-spatial augmentation to an image (jittering, color swap, convert to gray scale, Gaussian blur).""" + + brightness, contrast, 
saturation, hue = jitter_paramters + color_augmentation = transforms.ColorJitter(brightness, contrast, saturation, hue) + ''' + augment_image = color_augmentation.get_params(brightness=[max(0, 1 - brightness), 1 + brightness], + contrast=[max(0, 1 - contrast), 1 + contrast], + saturation=[max(0, 1 - saturation), 1 + saturation], + hue=[-hue, hue]) + ''' + + B = img_warp_ori.shape[0] + img_warp = [] + kernel_sizes = [0,1,3,5] + for b in range(B): + img_warp_sub = img_warp_ori[b].cpu() + img_warp_sub = torchvision.transforms.functional.to_pil_image(img_warp_sub) + + img_warp_sub_np = np.array(img_warp_sub) + img_warp_sub_np = img_warp_sub_np[:,:,color_order] + + if np.random.rand() > 0.5: + img_warp_sub_np = add_noise(img_warp_sub_np) + + rand_index = np.random.randint(4) + kernel_size = kernel_sizes[rand_index] + if kernel_size >0: + img_warp_sub_np = cv2.GaussianBlur(img_warp_sub_np, (kernel_size, kernel_size), sigmaX=0) + + if to_gray: + img_warp_sub_np = cv2.cvtColor(img_warp_sub_np, cv2.COLOR_RGB2GRAY) + img_warp_sub_np = cv2.cvtColor(img_warp_sub_np, cv2.COLOR_GRAY2RGB) + + img_warp_sub = Image.fromarray(img_warp_sub_np) + img_warp_sub = color_augmentation(img_warp_sub) + + img_warp_sub = torchvision.transforms.functional.to_tensor(img_warp_sub).to(img_warp_ori.device) + + img_warp.append(img_warp_sub) + + img_warp = torch.stack(img_warp, dim=0) + return img_warp + +def ha_augment_sample(data, jitter_paramters=[0.5, 0.5, 0.2, 0.05], patch_ratio=0.7, scaling_amplitude=0.2, max_angle=pi/4): + """Apply Homography Adaptation image augmentation.""" + input_img = data['image'].unsqueeze(0) + _, _, H, W = input_img.shape + device = input_img.device + + homography = torch.from_numpy( + sample_homography([H, W], + patch_ratio=patch_ratio, + scaling_amplitude=scaling_amplitude, + max_angle=max_angle)).float().to(device) + homography_inv = torch.inverse(homography) + + source = image_grid(1, H, W, + dtype=input_img.dtype, + device=device, + ones=False, normalized=True).clone().permute(0, 2, 3, 1) + + target_warped = warp_homography(source, homography) + img_warp = torch.nn.functional.grid_sample(input_img, target_warped) + + color_order = [0,1,2] + if np.random.rand() > 0.5: + random.shuffle(color_order) + + to_gray = False + if np.random.rand() > 0.5: + to_gray = True + + input_img = non_spatial_augmentation(input_img, jitter_paramters=jitter_paramters, color_order=color_order, to_gray=to_gray) + img_warp = non_spatial_augmentation(img_warp, jitter_paramters=jitter_paramters, color_order=color_order, to_gray=to_gray) + + data['image'] = input_img.squeeze() + data['image_aug'] = img_warp.squeeze() + data['homography'] = homography + data['homography_inv'] = homography_inv + return data diff --git a/imcui/third_party/lanet/config.py b/imcui/third_party/lanet/config.py new file mode 100644 index 0000000000000000000000000000000000000000..89539ad9a747b9c2e4d9ef84a290ad2a5d7c9c45 --- /dev/null +++ b/imcui/third_party/lanet/config.py @@ -0,0 +1,79 @@ +import argparse + +arg_lists = [] +parser = argparse.ArgumentParser(description='LANet') + +def str2bool(v): + return v.lower() in ('true', '1') + +def add_argument_group(name): + arg = parser.add_argument_group(name) + arg_lists.append(arg) + return arg + +# train data params +traindata_arg = add_argument_group('Traindata Params') +traindata_arg.add_argument('--train_txt', type=str, default='', + help='Train set.') +traindata_arg.add_argument('--train_root', type=str, default='', + help='Where the train images are.') 
+traindata_arg.add_argument('--batch_size', type=int, default=8, + help='# of images in each batch of data') +traindata_arg.add_argument('--num_workers', type=int, default=4, + help='# of subprocesses to use for data loading') +traindata_arg.add_argument('--pin_memory', type=str2bool, default=True, + help='# of subprocesses to use for data loading') +traindata_arg.add_argument('--shuffle', type=str2bool, default=True, + help='Whether to shuffle the train and valid indices') +traindata_arg.add_argument('--image_shape', type=tuple, default=(240, 320), + help='') +traindata_arg.add_argument('--jittering', type=tuple, default=(0.5, 0.5, 0.2, 0.05), + help='') + +# data storage +storage_arg = add_argument_group('Storage') +storage_arg.add_argument('--ckpt_name', type=str, default='PointModel', + help='') + +# training params +train_arg = add_argument_group('Training Params') +train_arg.add_argument('--start_epoch', type=int, default=0, + help='') +train_arg.add_argument('--max_epoch', type=int, default=12, + help='') +train_arg.add_argument('--init_lr', type=float, default=3e-4, + help='Initial learning rate value.') +train_arg.add_argument('--lr_factor', type=float, default=0.5, + help='Reduce learning rate value.') +train_arg.add_argument('--momentum', type=float, default=0.9, + help='Nesterov momentum value.') +train_arg.add_argument('--display', type=int, default=50, + help='') + +# loss function params +loss_arg = add_argument_group('Loss function Params') +loss_arg.add_argument('--score_weight', type=float, default=1., + help='') +loss_arg.add_argument('--loc_weight', type=float, default=1., + help='') +loss_arg.add_argument('--desc_weight', type=float, default=4., + help='') +loss_arg.add_argument('--corres_weight', type=float, default=.5, + help='') +loss_arg.add_argument('--corres_threshold', type=int, default=4., + help='') + +# other params +misc_arg = add_argument_group('Misc.') +misc_arg.add_argument('--use_gpu', type=str2bool, default=True, + help="Whether to run on the GPU.") +misc_arg.add_argument('--gpu', type=int, default=0, + help="Which GPU to run on.") +misc_arg.add_argument('--seed', type=int, default=1001, + help='Seed to ensure reproducibility.') +misc_arg.add_argument('--ckpt_dir', type=str, default='./checkpoints', + help='Directory in which to save model checkpoints.') + +def get_config(): + config, unparsed = parser.parse_known_args() + return config, unparsed diff --git a/imcui/third_party/lanet/data_loader.py b/imcui/third_party/lanet/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..149f66e6c23989d9a6eb92ab658188db97e33a64 --- /dev/null +++ b/imcui/third_party/lanet/data_loader.py @@ -0,0 +1,86 @@ +from PIL import Image +from torch.utils.data import Dataset, DataLoader + +from augmentations import ha_augment_sample, resize_sample, spatial_augment_sample +from utils import to_tensor_sample + +def image_transforms(shape, jittering): + def train_transforms(sample): + sample = resize_sample(sample, image_shape=shape) + sample = spatial_augment_sample(sample) + sample = to_tensor_sample(sample) + sample = ha_augment_sample(sample, jitter_paramters=jittering) + return sample + + return {'train': train_transforms} + +class GetData(Dataset): + def __init__(self, config, transforms=None): + """ + Get the list containing all images and labels. 
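+        Each line of config.train_txt is split on whitespace and only the
+        first token is kept; that token is appended to config.train_root to
+        form the image path. A hypothetical train2017.txt could therefore
+        contain entries such as:
+
+            000000000139.jpg
+            000000000285.jpg
+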
+ """ + datafile = open(config.train_txt, 'r') + lines = datafile.readlines() + + dataset = [] + for line in lines: + line = line.rstrip() + data = line.split() + dataset.append(data[0]) + + self.config = config + self.dataset = dataset + self.root = config.train_root + + self.transforms = transforms + + def __getitem__(self, index): + """ + Return image'data and its label. + """ + img_path = self.dataset[index] + img_file = self.root + img_path + img = Image.open(img_file) + + # image.mode == 'L' means the image is in gray scale + if img.mode == 'L': + img_new = Image.new("RGB", img.size) + img_new.paste(img) + sample = {'image': img_new, 'idx': index} + else: + sample = {'image': img, 'idx': index} + + if self.transforms: + sample = self.transforms(sample) + + return sample + + def __len__(self): + """ + Return the number of all data. + """ + return len(self.dataset) + +def get_data_loader( + config, + transforms=None, + sampler=None, + drop_last=True, + ): + """ + Return batch data for training. + """ + transforms = image_transforms(shape=config.image_shape, jittering=config.jittering) + dataset = GetData(config, transforms=transforms['train']) + + train_loader = DataLoader( + dataset, + batch_size=config.batch_size, + shuffle=config.shuffle, + sampler=sampler, + num_workers=config.num_workers, + pin_memory=config.pin_memory, + drop_last=drop_last + ) + + return train_loader diff --git a/imcui/third_party/lanet/datasets/hp_loader.py b/imcui/third_party/lanet/datasets/hp_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..960a6403cd5fc004b2caef429b72acc32cf0c291 --- /dev/null +++ b/imcui/third_party/lanet/datasets/hp_loader.py @@ -0,0 +1,106 @@ +import torch +import cv2 +import numpy as np + +from torchvision import transforms +from torch.utils.data import Dataset +from pathlib import Path + + +class PatchesDataset(Dataset): + """ + HPatches dataset class. + # Note: output_shape = (output_width, output_height) + # Note: this returns Pytorch tensors, resized to output_shape (if specified) + # Note: the homography will be adjusted according to output_shape. + + Parameters + ---------- + root_dir : str + Path to the dataset + use_color : bool + Return color images or convert to grayscale. + data_transform : Function + Transformations applied to the sample + output_shape: tuple + If specified, the images and homographies will be resized to the desired shape. 
+ type: str + Dataset subset to return from ['i', 'v', 'all']: + i - illumination sequences + v - viewpoint sequences + all - all sequences + """ + def __init__(self, root_dir, use_color=True, data_transform=None, output_shape=None, type='all'): + super().__init__() + self.type = type + self.root_dir = root_dir + self.data_transform = data_transform + self.output_shape = output_shape + self.use_color = use_color + base_path = Path(root_dir) + folder_paths = [x for x in base_path.iterdir() if x.is_dir()] + image_paths = [] + warped_image_paths = [] + homographies = [] + for path in folder_paths: + if self.type == 'i' and path.stem[0] != 'i': + continue + if self.type == 'v' and path.stem[0] != 'v': + continue + num_images = 5 + file_ext = '.ppm' + for i in range(2, 2 + num_images): + image_paths.append(str(Path(path, "1" + file_ext))) + warped_image_paths.append(str(Path(path, str(i) + file_ext))) + homographies.append(np.loadtxt(str(Path(path, "H_1_" + str(i))))) + self.files = {'image_paths': image_paths, 'warped_image_paths': warped_image_paths, 'homography': homographies} + + def scale_homography(self, homography, original_scale, new_scale, pre): + scales = np.divide(new_scale, original_scale) + if pre: + s = np.diag(np.append(scales, 1.)) + homography = np.matmul(s, homography) + else: + sinv = np.diag(np.append(1. / scales, 1.)) + homography = np.matmul(homography, sinv) + return homography + + def __len__(self): + return len(self.files['image_paths']) + + def __getitem__(self, idx): + + def _read_image(path): + img = cv2.imread(path, cv2.IMREAD_COLOR) + if self.use_color: + return img + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + return gray + + image = _read_image(self.files['image_paths'][idx]) + + warped_image = _read_image(self.files['warped_image_paths'][idx]) + homography = np.array(self.files['homography'][idx]) + sample = {'image': image, 'warped_image': warped_image, 'homography': homography, 'index' : idx} + + # Apply transformations + if self.output_shape is not None: + sample['homography'] = self.scale_homography(sample['homography'], + sample['image'].shape[:2][::-1], + self.output_shape, + pre=False) + sample['homography'] = self.scale_homography(sample['homography'], + sample['warped_image'].shape[:2][::-1], + self.output_shape, + pre=True) + + for key in ['image', 'warped_image']: + sample[key] = cv2.resize(sample[key], self.output_shape) + if self.use_color is False: + sample[key] = np.expand_dims(sample[key], axis=2) + + transform = transforms.ToTensor() + + for key in ['image', 'warped_image']: + sample[key] = transform(sample[key]).type('torch.FloatTensor') + return sample diff --git a/imcui/third_party/lanet/datasets/prepare_coco.py b/imcui/third_party/lanet/datasets/prepare_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..96a3e94b53e5c916c1df2e1e322080abbde1f02e --- /dev/null +++ b/imcui/third_party/lanet/datasets/prepare_coco.py @@ -0,0 +1,26 @@ +import os +import argparse + +def prepare_coco(args): + train_file = open(os.path.join(args.saved_dir, args.saved_txt), 'w') + dirs = os.listdir(args.raw_dir) + + for file in dirs: + # Write training files + train_file.write('%s\n' % (file)) + + print('Data Preparation Finished.') + +if __name__ == '__main__': + arg_parser = argparse.ArgumentParser(description="coco prepareing.") + arg_parser.add_argument('--dataset', type=str, default='coco', + help='') + arg_parser.add_argument('--raw_dir', type=str, default='', + help='') + arg_parser.add_argument('--saved_dir', type=str, default='', + 
help='') + arg_parser.add_argument('--saved_txt', type=str, default='train2017.txt', + help='') + args = arg_parser.parse_args() + + prepare_coco(args) \ No newline at end of file diff --git a/imcui/third_party/lanet/evaluation/descriptor_evaluation.py b/imcui/third_party/lanet/evaluation/descriptor_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e1f84199d353ac5858641c8f68bc298f9d6413 --- /dev/null +++ b/imcui/third_party/lanet/evaluation/descriptor_evaluation.py @@ -0,0 +1,254 @@ +# Copyright 2020 Toyota Research Institute. All rights reserved. +# Adapted from: https://github.com/rpautrat/SuperPoint/blob/master/superpoint/evaluations/descriptor_evaluation.py + +import random +from glob import glob +from os import path as osp + +import cv2 +import numpy as np + +from utils import warp_keypoints + + +def select_k_best(points, descriptors, k): + """ Select the k most probable points (and strip their probability). + points has shape (num_points, 3) where the last coordinate is the probability. + + Parameters + ---------- + points: numpy.ndarray (N,3) + Keypoint vector, consisting of (x,y,probability). + descriptors: numpy.ndarray (N,256) + Keypoint descriptors. + k: int + Number of keypoints to select, based on probability. + Returns + ------- + + selected_points: numpy.ndarray (k,2) + k most probable keypoints. + selected_descriptors: numpy.ndarray (k,256) + Descriptors corresponding to the k most probable keypoints. + """ + sorted_prob = points[points[:, 2].argsort(), :2] + sorted_desc = descriptors[points[:, 2].argsort(), :] + start = min(k, points.shape[0]) + selected_points = sorted_prob[-start:, :] + selected_descriptors = sorted_desc[-start:, :] + return selected_points, selected_descriptors + + +def keep_shared_points(keypoints, descriptors, H, shape, keep_k_points=1000): + """ + Compute a list of keypoints from the map, filter the list of points by keeping + only the points that once mapped by H are still inside the shape of the map + and keep at most 'keep_k_points' keypoints in the image. + + Parameters + ---------- + keypoints: numpy.ndarray (N,3) + Keypoint vector, consisting of (x,y,probability). + descriptors: numpy.ndarray (N,256) + Keypoint descriptors. + H: numpy.ndarray (3,3) + Homography. + shape: tuple + Image shape. + keep_k_points: int + Number of keypoints to select, based on probability. + + Returns + ------- + selected_points: numpy.ndarray (k,2) + k most probable keypoints. + selected_descriptors: numpy.ndarray (k,256) + Descriptors corresponding to the k most probable keypoints. + """ + + def keep_true_keypoints(points, descriptors, H, shape): + """ Keep only the points whose warped coordinates by H are still inside shape. """ + warped_points = warp_keypoints(points[:, [1, 0]], H) + warped_points[:, [0, 1]] = warped_points[:, [1, 0]] + mask = (warped_points[:, 0] >= 0) & (warped_points[:, 0] < shape[0]) &\ + (warped_points[:, 1] >= 0) & (warped_points[:, 1] < shape[1]) + return points[mask, :], descriptors[mask, :] + + selected_keypoints, selected_descriptors = keep_true_keypoints(keypoints, descriptors, H, shape) + selected_keypoints, selected_descriptors = select_k_best(selected_keypoints, selected_descriptors, keep_k_points) + return selected_keypoints, selected_descriptors + + +def compute_matching_score(data, keep_k_points=1000): + """ + Compute the matching score between two sets of keypoints with associated descriptors. 
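+    In outline (mirroring the implementation below): descriptors are matched
+    by nearest neighbour in both directions with crossCheck=False, the matched
+    keypoints of each direction are warped with the ground-truth homography
+    (or its inverse), and a match counts as correct when its reprojection
+    error is under 3 pixels and the warped point lies inside the image. The
+    final score is the mean of the two directional ratios,
+    ms = (score1 + score2) / 2.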
+ + Parameters + ---------- + data: dict + Input dictionary containing: + image_shape: tuple (H,W) + Original image shape. + homography: numpy.ndarray (3,3) + Ground truth homography. + prob: numpy.ndarray (N,3) + Keypoint vector, consisting of (x,y,probability). + warped_prob: numpy.ndarray (N,3) + Warped keypoint vector, consisting of (x,y,probability). + desc: numpy.ndarray (N,256) + Keypoint descriptors. + warped_desc: numpy.ndarray (N,256) + Warped keypoint descriptors. + keep_k_points: int + Number of keypoints to select, based on probability. + + Returns + ------- + ms: float + Matching score. + """ + shape = data['image_shape'] + real_H = data['homography'] + + # Filter out predictions + keypoints = data['prob'][:, :2].T + keypoints = keypoints[::-1] + prob = data['prob'][:, 2] + keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1) + + warped_keypoints = data['warped_prob'][:, :2].T + warped_keypoints = warped_keypoints[::-1] + warped_prob = data['warped_prob'][:, 2] + warped_keypoints = np.stack([warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1) + + desc = data['desc'] + warped_desc = data['warped_desc'] + + # Keeps all points for the next frame. The matching for caculating M.Score shouldnt use only in view points. + keypoints, desc = select_k_best(keypoints, desc, keep_k_points) + warped_keypoints, warped_desc = select_k_best(warped_keypoints, warped_desc, keep_k_points) + + # Match the keypoints with the warped_keypoints with nearest neighbor search + # This part needs to be done with crossCheck=False. + # All the matched pairs need to be evaluated without any selection. + bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=False) + + matches = bf.match(desc, warped_desc) + matches_idx = np.array([m.queryIdx for m in matches]) + m_keypoints = keypoints[matches_idx, :] + matches_idx = np.array([m.trainIdx for m in matches]) + m_warped_keypoints = warped_keypoints[matches_idx, :] + + true_warped_keypoints = warp_keypoints(m_warped_keypoints[:, [1, 0]], np.linalg.inv(real_H))[:,::-1] + vis_warped = np.all((true_warped_keypoints >= 0) & (true_warped_keypoints <= (np.array(shape)-1)), axis=-1) + norm1 = np.linalg.norm(true_warped_keypoints - m_keypoints, axis=-1) + + correct1 = (norm1 < 3) + count1 = np.sum(correct1 * vis_warped) + score1 = count1 / np.maximum(np.sum(vis_warped), 1.0) + + matches = bf.match(warped_desc, desc) + matches_idx = np.array([m.queryIdx for m in matches]) + m_warped_keypoints = warped_keypoints[matches_idx, :] + matches_idx = np.array([m.trainIdx for m in matches]) + m_keypoints = keypoints[matches_idx, :] + + true_keypoints = warp_keypoints(m_keypoints[:, [1, 0]], real_H)[:,::-1] + vis = np.all((true_keypoints >= 0) & (true_keypoints <= (np.array(shape)-1)), axis=-1) + norm2 = np.linalg.norm(true_keypoints - m_warped_keypoints, axis=-1) + + correct2 = (norm2 < 3) + count2 = np.sum(correct2 * vis) + score2 = count2 / np.maximum(np.sum(vis), 1.0) + + ms = (score1 + score2) / 2 + + return ms + +def compute_homography(data, keep_k_points=1000): + """ + Compute the homography between 2 sets of Keypoints and descriptors inside data. + Use the homography to compute the correctness metrics (1,3,5). + + Parameters + ---------- + data: dict + Input dictionary containing: + image_shape: tuple (H,W) + Original image shape. + homography: numpy.ndarray (3,3) + Ground truth homography. + prob: numpy.ndarray (N,3) + Keypoint vector, consisting of (x,y,probability). 
+ warped_prob: numpy.ndarray (N,3) + Warped keypoint vector, consisting of (x,y,probability). + desc: numpy.ndarray (N,256) + Keypoint descriptors. + warped_desc: numpy.ndarray (N,256) + Warped keypoint descriptors. + keep_k_points: int + Number of keypoints to select, based on probability. + + Returns + ------- + correctness1: float + correctness1 metric. + correctness3: float + correctness3 metric. + correctness5: float + correctness5 metric. + """ + shape = data['image_shape'] + real_H = data['homography'] + + # Filter out predictions + keypoints = data['prob'][:, :2].T + keypoints = keypoints[::-1] + prob = data['prob'][:, 2] + keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1) + + warped_keypoints = data['warped_prob'][:, :2].T + warped_keypoints = warped_keypoints[::-1] + warped_prob = data['warped_prob'][:, 2] + warped_keypoints = np.stack([warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1) + + desc = data['desc'] + warped_desc = data['warped_desc'] + + # Keeps only the points shared between the two views + keypoints, desc = keep_shared_points(keypoints, desc, real_H, shape, keep_k_points) + warped_keypoints, warped_desc = keep_shared_points(warped_keypoints, warped_desc, np.linalg.inv(real_H), shape, + keep_k_points) + + bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) + matches = bf.match(desc, warped_desc) + matches_idx = np.array([m.queryIdx for m in matches]) + m_keypoints = keypoints[matches_idx, :] + matches_idx = np.array([m.trainIdx for m in matches]) + m_warped_keypoints = warped_keypoints[matches_idx, :] + + # Estimate the homography between the matches using RANSAC + H, _ = cv2.findHomography(m_keypoints[:, [1, 0]], + m_warped_keypoints[:, [1, 0]], cv2.RANSAC, 3, maxIters=5000) + + if H is None: + return 0, 0, 0 + + shape = shape[::-1] + + # Compute correctness + corners = np.array([[0, 0, 1], + [0, shape[1] - 1, 1], + [shape[0] - 1, 0, 1], + [shape[0] - 1, shape[1] - 1, 1]]) + real_warped_corners = np.dot(corners, np.transpose(real_H)) + real_warped_corners = real_warped_corners[:, :2] / real_warped_corners[:, 2:] + warped_corners = np.dot(corners, np.transpose(H)) + warped_corners = warped_corners[:, :2] / warped_corners[:, 2:] + + mean_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1)) + correctness1 = float(mean_dist <= 1) + correctness3 = float(mean_dist <= 3) + correctness5 = float(mean_dist <= 5) + + return correctness1, correctness3, correctness5 diff --git a/imcui/third_party/lanet/evaluation/detector_evaluation.py b/imcui/third_party/lanet/evaluation/detector_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc8792d17a6fbb6b446f0f9f84a2b82e3cdb57c --- /dev/null +++ b/imcui/third_party/lanet/evaluation/detector_evaluation.py @@ -0,0 +1,121 @@ +# Copyright 2020 Toyota Research Institute. All rights reserved. +# Adapted from: https://github.com/rpautrat/SuperPoint/blob/master/superpoint/evaluations/detector_evaluation.py + +import random +from glob import glob +from os import path as osp + +import cv2 +import numpy as np + +from utils import warp_keypoints + + +def compute_repeatability(data, keep_k_points=300, distance_thresh=3): + """ + Compute the repeatability metric between 2 sets of keypoints inside data. + + Parameters + ---------- + data: dict + Input dictionary containing: + image_shape: tuple (H,W) + Original image shape. + homography: numpy.ndarray (3,3) + Ground truth homography. + prob: numpy.ndarray (N,3) + Keypoint vector, consisting of (x,y,probability). 
+ warped_prob: numpy.ndarray (N,3) + Warped keypoint vector, consisting of (x,y,probability). + keep_k_points: int + Number of keypoints to select, based on probability. + distance_thresh: int + Distance threshold in pixels for a corresponding keypoint to be considered a correct match. + + Returns + ------- + N1: int + Number of true keypoints in the first image. + N2: int + Number of true keypoints in the second image. + repeatability: float + Keypoint repeatability metric. + loc_err: float + Keypoint localization error. + """ + def filter_keypoints(points, shape): + """ Keep only the points whose coordinates are inside the dimensions of shape. """ + mask = (points[:, 0] >= 0) & (points[:, 0] < shape[0]) &\ + (points[:, 1] >= 0) & (points[:, 1] < shape[1]) + return points[mask, :] + + def keep_true_keypoints(points, H, shape): + """ Keep only the points whose warped coordinates by H are still inside shape. """ + warped_points = warp_keypoints(points[:, [1, 0]], H) + warped_points[:, [0, 1]] = warped_points[:, [1, 0]] + mask = (warped_points[:, 0] >= 0) & (warped_points[:, 0] < shape[0]) &\ + (warped_points[:, 1] >= 0) & (warped_points[:, 1] < shape[1]) + return points[mask, :] + + + def select_k_best(points, k): + """ Select the k most probable points (and strip their probability). + points has shape (num_points, 3) where the last coordinate is the probability. """ + sorted_prob = points[points[:, 2].argsort(), :2] + start = min(k, points.shape[0]) + return sorted_prob[-start:, :] + + H = data['homography'] + shape = data['image_shape'] + + # # Filter out predictions + keypoints = data['prob'][:, :2].T + keypoints = keypoints[::-1] + prob = data['prob'][:, 2] + + warped_keypoints = data['warped_prob'][:, :2].T + warped_keypoints = warped_keypoints[::-1] + warped_prob = data['warped_prob'][:, 2] + + keypoints = np.stack([keypoints[0], keypoints[1]], axis=-1) + warped_keypoints = np.stack([warped_keypoints[0], warped_keypoints[1], warped_prob], axis=-1) + warped_keypoints = keep_true_keypoints(warped_keypoints, np.linalg.inv(H), shape) + + # Warp the original keypoints with the true homography + true_warped_keypoints = warp_keypoints(keypoints[:, [1, 0]], H) + true_warped_keypoints = np.stack([true_warped_keypoints[:, 1], true_warped_keypoints[:, 0], prob], axis=-1) + true_warped_keypoints = filter_keypoints(true_warped_keypoints, shape) + + # Keep only the keep_k_points best predictions + warped_keypoints = select_k_best(warped_keypoints, keep_k_points) + true_warped_keypoints = select_k_best(true_warped_keypoints, keep_k_points) + + # Compute the repeatability + N1 = true_warped_keypoints.shape[0] + N2 = warped_keypoints.shape[0] + true_warped_keypoints = np.expand_dims(true_warped_keypoints, 1) + warped_keypoints = np.expand_dims(warped_keypoints, 0) + # shapes are broadcasted to N1 x N2 x 2: + norm = np.linalg.norm(true_warped_keypoints - warped_keypoints, ord=None, axis=2) + count1 = 0 + count2 = 0 + le1 = 0 + le2 = 0 + if N2 != 0: + min1 = np.min(norm, axis=1) + correct1 = (min1 <= distance_thresh) + count1 = np.sum(correct1) + le1 = min1[correct1].sum() + if N1 != 0: + min2 = np.min(norm, axis=0) + correct2 = (min2 <= distance_thresh) + count2 = np.sum(correct2) + le2 = min2[correct2].sum() + if N1 + N2 > 0: + repeatability = (count1 + count2) / (N1 + N2) + loc_err = (le1 + le2) / (count1 + count2) + else: + repeatability = -1 + loc_err = -1 + + return N1, N2, repeatability, loc_err diff --git a/imcui/third_party/lanet/evaluation/evaluate.py 
b/imcui/third_party/lanet/evaluation/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9e91ee6d9cc0142ebbe8f2a3f904f6fae8434c --- /dev/null +++ b/imcui/third_party/lanet/evaluation/evaluate.py @@ -0,0 +1,84 @@ +# Copyright 2020 Toyota Research Institute. All rights reserved. + +import numpy as np +import torch +import torchvision.transforms as transforms +from tqdm import tqdm + +from evaluation.descriptor_evaluation import (compute_homography, + compute_matching_score) +from evaluation.detector_evaluation import compute_repeatability + + +def evaluate_keypoint_net(data_loader, keypoint_net, output_shape=(320, 240), top_k=300): + """Keypoint net evaluation script. + + Parameters + ---------- + data_loader: torch.utils.data.DataLoader + Dataset loader. + keypoint_net: torch.nn.module + Keypoint network. + output_shape: tuple + Original image shape. + top_k: int + Number of keypoints to use to compute metrics, selected based on probability. + use_color: bool + Use color or grayscale images. + """ + keypoint_net.eval() + keypoint_net.training = False + + conf_threshold = 0.0 + localization_err, repeatability = [], [] + correctness1, correctness3, correctness5, MScore = [], [], [], [] + + with torch.no_grad(): + for i, sample in tqdm(enumerate(data_loader), desc="Evaluate point model"): + + image = sample['image'].cuda() + warped_image = sample['warped_image'].cuda() + + score_1, coord_1, desc1 = keypoint_net(image) + score_2, coord_2, desc2 = keypoint_net(warped_image) + B, _, Hc, Wc = desc1.shape + + # Scores & Descriptors + score_1 = torch.cat([coord_1, score_1], dim=1).view(3, -1).t().cpu().numpy() + score_2 = torch.cat([coord_2, score_2], dim=1).view(3, -1).t().cpu().numpy() + desc1 = desc1.view(256, Hc, Wc).view(256, -1).t().cpu().numpy() + desc2 = desc2.view(256, Hc, Wc).view(256, -1).t().cpu().numpy() + + # Filter based on confidence threshold + desc1 = desc1[score_1[:, 2] > conf_threshold, :] + desc2 = desc2[score_2[:, 2] > conf_threshold, :] + score_1 = score_1[score_1[:, 2] > conf_threshold, :] + score_2 = score_2[score_2[:, 2] > conf_threshold, :] + + # Prepare data for eval + data = {'image': sample['image'].numpy().squeeze(), + 'image_shape' : output_shape[::-1], + 'warped_image': sample['warped_image'].numpy().squeeze(), + 'homography': sample['homography'].squeeze().numpy(), + 'prob': score_1, + 'warped_prob': score_2, + 'desc': desc1, + 'warped_desc': desc2} + + # Compute repeatabilty and localization error + _, _, rep, loc_err = compute_repeatability(data, keep_k_points=top_k, distance_thresh=3) + repeatability.append(rep) + localization_err.append(loc_err) + + # Compute correctness + c1, c2, c3 = compute_homography(data, keep_k_points=top_k) + correctness1.append(c1) + correctness3.append(c2) + correctness5.append(c3) + + # Compute matching score + mscore = compute_matching_score(data, keep_k_points=top_k) + MScore.append(mscore) + + return np.mean(repeatability), np.mean(localization_err), \ + np.mean(correctness1), np.mean(correctness3), np.mean(correctness5), np.mean(MScore) diff --git a/imcui/third_party/lanet/loss_function.py b/imcui/third_party/lanet/loss_function.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8be86c41995bfdc0ec04d79ef75a6450fcf5be --- /dev/null +++ b/imcui/third_party/lanet/loss_function.py @@ -0,0 +1,156 @@ +import torch + +def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False): + """ + Desc Head Loss, per-pixel level triplet loss from 
https://arxiv.org/pdf/1902.11046.pdf. + + Parameters + ---------- + source_des: torch.Tensor (B,256,H/8,W/8) + Source image descriptors. + target_des: torch.Tensor (B,256,H/8,W/8) + Target image descriptors. + source_points: torch.Tensor (B,H/8,W/8,2) + Source image keypoints + tar_points: torch.Tensor (B,H/8,W/8,2) + Target image keypoints + tar_points_un: torch.Tensor (B,2,H/8,W/8) + Target image keypoints unnormalized + eval_only: bool + Computes only recall without the loss. + Returns + ------- + loss: torch.Tensor + Descriptor loss. + recall: torch.Tensor + Descriptor match recall. + """ + device = source_des.device + loss = 0 + batch_size = source_des.size(0) + recall = 0. + + relax_field_size = [relax_field] + margins = [1.0] + weights = [1.0] + + isource_dense = top_kk is None + + for b_id in range(batch_size): + + if isource_dense: + ref_desc = source_des[b_id].squeeze().view(256, -1) + tar_desc = target_des[b_id].squeeze().view(256, -1) + tar_points_raw = tar_points_un[b_id].view(2, -1) + else: + top_k = top_kk[b_id].squeeze() + + n_feat = top_k.sum().item() + if n_feat < 20: + continue + + ref_desc = source_des[b_id].squeeze()[:, top_k] + tar_desc = target_des[b_id].squeeze()[:, top_k] + tar_points_raw = tar_points_un[b_id][:, top_k] + + # Compute dense descriptor distance matrix and find nearest neighbor + ref_desc = ref_desc.div(torch.norm(ref_desc, p=2, dim=0)) + tar_desc = tar_desc.div(torch.norm(tar_desc, p=2, dim=0)) + dmat = torch.mm(ref_desc.t(), tar_desc) + + dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1)) + _, idx = torch.sort(dmat, dim=1) + + + # Compute triplet loss and recall + for pyramid in range(len(relax_field_size)): + + candidates = idx.t() + + match_k_x = tar_points_raw[0, candidates] + match_k_y = tar_points_raw[1, candidates] + + tru_x = tar_points_raw[0] + tru_y = tar_points_raw[1] + + if pyramid == 0: + correct2 = (abs(match_k_x[0]-tru_x) == 0) & (abs(match_k_y[0]-tru_y) == 0) + correct2_cnt = correct2.float().sum() + recall += float(1.0 / batch_size) * (float(correct2_cnt) / float( ref_desc.size(1))) + + if eval_only: + continue + correct_k = (abs(match_k_x - tru_x) <= relax_field_size[pyramid]) & (abs(match_k_y - tru_y) <= relax_field_size[pyramid]) + + incorrect_index = torch.arange(start=correct_k.shape[0]-1, end=-1, step=-1).unsqueeze(1).repeat(1,correct_k.shape[1]).to(device) + incorrect_first = torch.argmax(incorrect_index * (1 - correct_k.long()), dim=0) + + incorrect_first_index = candidates.gather(0, incorrect_first.unsqueeze(0)).squeeze() + + anchor_var = ref_desc + posource_var = tar_desc + neg_var = tar_desc[:, incorrect_first_index] + + loss += float(1.0 / batch_size) * torch.nn.functional.triplet_margin_loss(anchor_var.t(), posource_var.t(), neg_var.t(), margin=margins[pyramid]).mul(weights[pyramid]) + + return loss, recall + + +class KeypointLoss(object): + """ + Loss function class encapsulating the location loss, the descriptor loss, and the score loss. 
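+
+    The total loss returned by __call__ combines four weighted terms:
+
+        loss = loc_weight * loc_loss + desc_weight * desc_loss
+               + score_weight * score_loss + corres_weight * corres_loss
+
+    Minimal usage sketch, following train.py (config comes from get_config(),
+    output is the dict returned by PointModel(is_test=False) on
+    (source_img, target_img, homography)):
+
+        loss_func = KeypointLoss(config)
+        loss, loc_loss, desc_loss, score_loss, corres_loss = loss_func(output)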
+ """ + def __init__(self, config): + self.score_weight = config.score_weight + self.loc_weight = config.loc_weight + self.desc_weight = config.desc_weight + self.corres_weight = config.corres_weight + self.corres_threshold = config.corres_threshold + + def __call__(self, data): + B, _, hc, wc = data['source_score'].shape + + loc_mat_abs = torch.abs(data['target_coord_warped'].view(B, 2, -1).unsqueeze(3) - data['target_coord'].view(B, 2, -1).unsqueeze(2)) + l2_dist_loc_mat = torch.norm(loc_mat_abs, p=2, dim=1) + l2_dist_loc_min, l2_dist_loc_min_index = l2_dist_loc_mat.min(dim=2) + + # construct pseudo ground truth matching matrix + loc_min_mat = torch.repeat_interleave(l2_dist_loc_min.unsqueeze(dim=-1), repeats=l2_dist_loc_mat.shape[-1], dim=-1) + pos_mask = l2_dist_loc_mat.eq(loc_min_mat) & l2_dist_loc_mat.le(1.) + neg_mask = l2_dist_loc_mat.ge(4.) + + pos_corres = - torch.log(data['confidence_matrix'][pos_mask]) + neg_corres = - torch.log(1.0 - data['confidence_matrix'][neg_mask]) + corres_loss = pos_corres.mean() + 5e5 * neg_corres.mean() + + # corresponding distance threshold is 4 + dist_norm_valid_mask = l2_dist_loc_min.lt(self.corres_threshold) & data['border_mask'].view(B, hc * wc) + + # location loss + loc_loss = l2_dist_loc_min[dist_norm_valid_mask].mean() + + # desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf. + desc_loss, _ = build_descriptor_loss(data['source_desc'], data['target_desc_warped'], data['target_coord_warped'].detach(), top_kk=data['border_mask'], relax_field=8) + + # score loss + target_score_associated = data['target_score'].view(B, hc * wc).gather(1, l2_dist_loc_min_index).view(B, hc, wc).unsqueeze(1) + dist_norm_valid_mask = dist_norm_valid_mask.view(B, hc, wc).unsqueeze(1) & data['border_mask'].unsqueeze(1) + l2_dist_loc_min = l2_dist_loc_min.view(B, hc, wc).unsqueeze(1) + loc_err = l2_dist_loc_min[dist_norm_valid_mask] + + # repeatable_constrain in score loss + repeatable_constrain = ((target_score_associated[dist_norm_valid_mask] + data['source_score'][dist_norm_valid_mask]) * (loc_err - loc_err.mean())).mean() + + # consistent_constrain in score_loss + consistent_constrain = torch.nn.functional.mse_loss(data['target_score_warped'][data['border_mask'].unsqueeze(1)], data['source_score'][data['border_mask'].unsqueeze(1)]).mean() * 2 + aware_consistent_loss = torch.nn.functional.mse_loss(data['target_aware_warped'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)], data['source_aware'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)]).mean() * 2 + + score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss + + loss = self.loc_weight * loc_loss + self.desc_weight * desc_loss + self.score_weight * score_loss + self.corres_weight * corres_loss + + return loss, self.loc_weight * loc_loss, self.desc_weight * desc_loss, self.score_weight * score_loss, self.corres_weight * corres_loss + + + + diff --git a/imcui/third_party/lanet/main.py b/imcui/third_party/lanet/main.py new file mode 100644 index 0000000000000000000000000000000000000000..105d15856ac79825c747e691ab7f695ee17a1680 --- /dev/null +++ b/imcui/third_party/lanet/main.py @@ -0,0 +1,25 @@ +import torch + +from train import Trainer +from config import get_config +from utils import prepare_dirs +from data_loader import get_data_loader + +def main(config): + # ensure directories are setup + prepare_dirs(config) + + # ensure reproducibility + torch.manual_seed(config.seed) + if config.use_gpu: + torch.cuda.manual_seed(config.seed) + + # instantiate 
train data loaders + train_loader = get_data_loader(config=config) + + trainer = Trainer(config, train_loader=train_loader) + trainer.train() + +if __name__ == '__main__': + config, unparsed = get_config() + main(config) \ No newline at end of file diff --git a/imcui/third_party/lanet/network_v0/__init__.py b/imcui/third_party/lanet/network_v0/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/lanet/network_v0/model.py b/imcui/third_party/lanet/network_v0/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc58aed06f0c60421e8269fbe8210a100f6e8d4 --- /dev/null +++ b/imcui/third_party/lanet/network_v0/model.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torchvision.transforms as tvf + +from .modules import InterestPointModule, CorrespondenceModule + +def warp_homography_batch(sources, homographies): + """ + Batch warp keypoints given homographies. From https://github.com/TRI-ML/KP2D. + + Parameters + ---------- + sources: torch.Tensor (B,H,W,C) + Keypoints vector. + homographies: torch.Tensor (B,3,3) + Homographies. + + Returns + ------- + warped_sources: torch.Tensor (B,H,W,C) + Warped keypoints vector. + """ + B, H, W, _ = sources.shape + warped_sources = [] + for b in range(B): + source = sources[b].clone() + source = source.view(-1,2) + ''' + [X, [M11, M12, M13 [x, M11*x + M12*y + M13 [M11, M12 [M13, + Y, = M21, M22, M23 * y, = M21*x + M22*y + M23 = [x, y] * M21, M22 + M23, + Z] M31, M32, M33] 1] M31*x + M32*y + M33 M31, M32].T M33] + ''' + source = torch.addmm(homographies[b,:,2], source, homographies[b,:,:2].t()) + source.mul_(1/source[:,2].unsqueeze(1)) + source = source[:,:2].contiguous().view(H,W,2) + warped_sources.append(source) + return torch.stack(warped_sources, dim=0) + +class PointModel(nn.Module): + def __init__(self, is_test=True): + super(PointModel, self).__init__() + self.is_test = is_test + self.interestpoint_module = InterestPointModule(is_test=self.is_test) + self.correspondence_module = CorrespondenceModule() + self.norm_rgb = tvf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225]) + + def forward(self, *args): + if self.is_test: + img = args[0] + img = self.norm_rgb(img) + score, coord, desc = self.interestpoint_module(img) + return score, coord, desc + else: + source_score, source_coord, source_desc_block = self.interestpoint_module(args[0]) + target_score, target_coord, target_desc_block = self.interestpoint_module(args[1]) + + B, _, H, W = args[0].shape + B, _, hc, wc = source_score.shape + device = source_score.device + + # Normalize the coordinates from ([0, h], [0, w]) to ([0, 1], [0, 1]). + source_coord_norm = source_coord.clone() + source_coord_norm[:, 0] = (source_coord_norm[:, 0] / (float(W - 1) / 2.)) - 1. + source_coord_norm[:, 1] = (source_coord_norm[:, 1] / (float(H - 1) / 2.)) - 1. + source_coord_norm = source_coord_norm.permute(0, 2, 3, 1) + + target_coord_norm = target_coord.clone() + target_coord_norm[:, 0] = (target_coord_norm[:, 0] / (float(W - 1) / 2.)) - 1. + target_coord_norm[:, 1] = (target_coord_norm[:, 1] / (float(H - 1) / 2.)) - 1. + target_coord_norm = target_coord_norm.permute(0, 2, 3, 1) + + target_coord_warped_norm = warp_homography_batch(source_coord_norm, args[2]) + target_coord_warped = target_coord_warped_norm.clone() + + # de-normlize the coordinates + target_coord_warped[:, :, :, 0] = (target_coord_warped[:, :, :, 0] + 1) * (float(W - 1) / 2.) 
+ target_coord_warped[:, :, :, 1] = (target_coord_warped[:, :, :, 1] + 1) * (float(H - 1) / 2.) + target_coord_warped = target_coord_warped.permute(0, 3, 1, 2) + + # Border mask + border_mask_ori = torch.ones(B, hc, wc) + border_mask_ori[:, 0] = 0 + border_mask_ori[:, hc - 1] = 0 + border_mask_ori[:, :, 0] = 0 + border_mask_ori[:, :, wc - 1] = 0 + border_mask_ori = border_mask_ori.gt(1e-3).to(device) + + oob_mask2 = target_coord_warped_norm[:, :, :, 0].lt(1) & target_coord_warped_norm[:, :, :, 0].gt(-1) & target_coord_warped_norm[:, :, :, 1].lt(1) & target_coord_warped_norm[:, :, :, 1].gt(-1) + border_mask = border_mask_ori & oob_mask2 + + # score + target_score_warped = torch.nn.functional.grid_sample(target_score, target_coord_warped_norm.detach(), align_corners=False) + + # descriptor + source_desc2 = torch.nn.functional.grid_sample(source_desc_block[0], source_coord_norm.detach()) + source_desc3 = torch.nn.functional.grid_sample(source_desc_block[1], source_coord_norm.detach()) + source_aware = source_desc_block[2] + source_desc = torch.mul(source_desc2, source_aware[:, 0, :, :].unsqueeze(1).contiguous()) + torch.mul(source_desc3, source_aware[:, 1, :, :].unsqueeze(1).contiguous()) + + target_desc2 = torch.nn.functional.grid_sample(target_desc_block[0], target_coord_norm.detach()) + target_desc3 = torch.nn.functional.grid_sample(target_desc_block[1], target_coord_norm.detach()) + target_aware = target_desc_block[2] + target_desc = torch.mul(target_desc2, target_aware[:, 0, :, :].unsqueeze(1).contiguous()) + torch.mul(target_desc3, target_aware[:, 1, :, :].unsqueeze(1).contiguous()) + + target_desc2_warped = torch.nn.functional.grid_sample(target_desc_block[0], target_coord_warped_norm.detach()) + target_desc3_warped = torch.nn.functional.grid_sample(target_desc_block[1], target_coord_warped_norm.detach()) + target_aware_warped = torch.nn.functional.grid_sample(target_desc_block[2], target_coord_warped_norm.detach()) + target_desc_warped = torch.mul(target_desc2_warped, target_aware_warped[:, 0, :, :].unsqueeze(1).contiguous()) + torch.mul(target_desc3_warped, target_aware_warped[:, 1, :, :].unsqueeze(1).contiguous()) + + confidence_matrix = self.correspondence_module(source_desc, target_desc) + confidence_matrix = torch.clamp(confidence_matrix, 1e-12, 1 - 1e-12) + + output = { + 'source_score': source_score, + 'source_coord': source_coord, + 'source_desc': source_desc, + 'source_aware': source_aware, + 'target_score': target_score, + 'target_coord': target_coord, + 'target_score_warped': target_score_warped, + 'target_coord_warped': target_coord_warped, + 'target_desc_warped': target_desc_warped, + 'target_aware_warped': target_aware_warped, + 'border_mask': border_mask, + 'confidence_matrix': confidence_matrix + } + + return output diff --git a/imcui/third_party/lanet/network_v0/modules.py b/imcui/third_party/lanet/network_v0/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d95860caed657830869f8a245cbbd2a1b856f8 --- /dev/null +++ b/imcui/third_party/lanet/network_v0/modules.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import image_grid + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvBlock, self).__init__() + + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, 
padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.conv(x) + + +class DilationConv3x3(nn.Module): + def __init__(self, in_channels, out_channels): + super(DilationConv3x3, self).__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=2, dilation=2, bias=False) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class InterestPointModule(nn.Module): + def __init__(self, is_test=False): + super(InterestPointModule, self).__init__() + self.is_test = is_test + + self.conv1 = ConvBlock(3, 32) + self.conv2 = ConvBlock(32, 64) + self.conv3 = ConvBlock(64, 128) + self.conv4 = ConvBlock(128, 256) + + self.maxpool2x2 = nn.MaxPool2d(2, 2) + + # score head + self.score_conv = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False) + self.score_norm = nn.BatchNorm2d(256) + self.score_out = nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1) + self.softmax = nn.Softmax(dim=1) + + # location head + self.loc_conv = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False) + self.loc_norm = nn.BatchNorm2d(256) + self.loc_out = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + + # descriptor out + self.des_conv2 = DilationConv3x3(64, 256) + self.des_conv3 = DilationConv3x3(128, 256) + + # cross_head: + self.shift_out = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + B, _, H, W = x.shape + + x = self.conv1(x) + x = self.maxpool2x2(x) + x2 = self.conv2(x) + x = self.maxpool2x2(x2) + x3 = self.conv3(x) + x = self.maxpool2x2(x3) + x = self.conv4(x) + + B, _, Hc, Wc = x.shape + + # score head + score_x = self.score_out(self.relu(self.score_norm(self.score_conv(x)))) + aware = self.softmax(score_x[:, 0:2, :, :]) + score = score_x[:, 2, :, :].unsqueeze(1).sigmoid() + + border_mask = torch.ones(B, Hc, Wc) + border_mask[:, 0] = 0 + border_mask[:, Hc - 1] = 0 + border_mask[:, :, 0] = 0 + border_mask[:, :, Wc - 1] = 0 + border_mask = border_mask.unsqueeze(1) + score = score * border_mask.to(score.device) + + # location head + coord_x = self.relu(self.loc_norm(self.loc_conv(x))) + coord_cell = self.loc_out(coord_x).tanh() + + shift_ratio = self.shift_out(coord_x).sigmoid() * 2.0 + + step = ((H/Hc)-1) / 2. + center_base = image_grid(B, Hc, Wc, + dtype=coord_cell.dtype, + device=coord_cell.device, + ones=False, normalized=False).mul(H/Hc) + step + + coord_un = center_base.add(coord_cell.mul(shift_ratio * step)) + coord = coord_un.clone() + coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W-1) + coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H-1) + + # descriptor block + desc_block = [] + desc_block.append(self.des_conv2(x2)) + desc_block.append(self.des_conv3(x3)) + desc_block.append(aware) + + if self.is_test: + coord_norm = coord[:, :2].clone() + coord_norm[:, 0] = (coord_norm[:, 0] / (float(W-1)/2.)) - 1. + coord_norm[:, 1] = (coord_norm[:, 1] / (float(H-1)/2.)) - 1. + coord_norm = coord_norm.permute(0, 2, 3, 1) + + desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm) + desc3 = torch.nn.functional.grid_sample(desc_block[1], coord_norm) + aware = desc_block[2] + + desc = torch.mul(desc2, aware[:, 0, :, :]) + torch.mul(desc3, aware[:, 1, :, :]) + desc = desc.div(torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1)) # Divide by norm to normalize. 
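+            # At test time the descriptors are bilinearly sampled at the predicted keypoint
+            # coordinates, fused with the two-channel 'aware' weights and L2-normalized;
+            # in training mode the raw descriptor blocks are returned instead so the loss
+            # can resample them at both the source and the warped target coordinates.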
+ + return score, coord, desc + + return score, coord, desc_block + + +class CorrespondenceModule(nn.Module): + def __init__(self, match_type='dual_softmax'): + super(CorrespondenceModule, self).__init__() + self.match_type = match_type + + if self.match_type == 'dual_softmax': + self.temperature = 0.1 + else: + raise NotImplementedError() + + def forward(self, source_desc, target_desc): + b, c, h, w = source_desc.size() + + source_desc = source_desc.div(torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1)).view(b, -1, h*w) + target_desc = target_desc.div(torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1)).view(b, -1, h*w) + + if self.match_type == 'dual_softmax': + sim_mat = torch.einsum("bcm, bcn -> bmn", source_desc, target_desc) / self.temperature + confidence_matrix = F.softmax(sim_mat, 1) * F.softmax(sim_mat, 2) + else: + raise NotImplementedError() + + return confidence_matrix \ No newline at end of file diff --git a/imcui/third_party/lanet/network_v1/__init__.py b/imcui/third_party/lanet/network_v1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/lanet/network_v1/model.py b/imcui/third_party/lanet/network_v1/model.py new file mode 100644 index 0000000000000000000000000000000000000000..75fe96ac0f05cf6b06b3aae64e627ca730afa56b --- /dev/null +++ b/imcui/third_party/lanet/network_v1/model.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torchvision.transforms as tvf + +from .modules import InterestPointModule, CorrespondenceModule + +def warp_homography_batch(sources, homographies): + """ + Batch warp keypoints given homographies. From https://github.com/TRI-ML/KP2D. + + Parameters + ---------- + sources: torch.Tensor (B,H,W,C) + Keypoints vector. + homographies: torch.Tensor (B,3,3) + Homographies. + + Returns + ------- + warped_sources: torch.Tensor (B,H,W,C) + Warped keypoints vector. 
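+
+    Notes
+    -----
+    Each source point (x, y) is mapped to
+    ((M11*x + M12*y + M13) / (M31*x + M32*y + M33),
+     (M21*x + M22*y + M23) / (M31*x + M32*y + M33)),
+    computed batch-wise with torch.addmm followed by the perspective divide.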
+ """ + B, H, W, _ = sources.shape + warped_sources = [] + for b in range(B): + source = sources[b].clone() + source = source.view(-1,2) + ''' + [X, [M11, M12, M13 [x, M11*x + M12*y + M13 [M11, M12 [M13, + Y, = M21, M22, M23 * y, = M21*x + M22*y + M23 = [x, y] * M21, M22 + M23, + Z] M31, M32, M33] 1] M31*x + M32*y + M33 M31, M32].T M33] + ''' + source = torch.addmm(homographies[b,:,2], source, homographies[b,:,:2].t()) + source.mul_(1/source[:,2].unsqueeze(1)) + source = source[:,:2].contiguous().view(H,W,2) + warped_sources.append(source) + return torch.stack(warped_sources, dim=0) + + +class PointModel(nn.Module): + def __init__(self, is_test=False): + super(PointModel, self).__init__() + self.is_test = is_test + self.interestpoint_module = InterestPointModule(is_test=self.is_test) + self.correspondence_module = CorrespondenceModule() + self.norm_rgb = tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def forward(self, *args): + img = args[0] + img = self.norm_rgb(img) + score, coord, desc = self.interestpoint_module(img) + return score, coord, desc diff --git a/imcui/third_party/lanet/network_v1/modules.py b/imcui/third_party/lanet/network_v1/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba699e19e1a1f04cd8fdb72b66a4c745ce48107 --- /dev/null +++ b/imcui/third_party/lanet/network_v1/modules.py @@ -0,0 +1,174 @@ +from curses import is_term_resized +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchvision import models +from ..utils import image_grid + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvBlock, self).__init__() + + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.conv(x) + +class DilationConv3x3(nn.Module): + def __init__(self, in_channels, out_channels): + super(DilationConv3x3, self).__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=2, dilation=2, bias=False) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class InterestPointModule(nn.Module): + def __init__(self, is_test=False): + super(InterestPointModule, self).__init__() + self.is_test = is_test + + model = models.vgg16_bn(pretrained=True) + + # use the first 23 layers as encoder + self.encoder = nn.Sequential( + *list(model.features.children())[: 33] + ) + + # score head + self.score_head = nn.Sequential( + nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + ) + self.softmax = nn.Softmax(dim=1) + + # location head + self.loc_head = nn.Sequential( + nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + ) + # location out + self.loc_out = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + self.shift_out = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1) + + # descriptor out + self.des_out2 = DilationConv3x3(128, 256) + self.des_out3 = DilationConv3x3(256, 256) + self.des_out4 = DilationConv3x3(512, 256) + + def forward(self, x): + B, _, H, W = x.shape + + x = 
self.encoder[2](self.encoder[1](self.encoder[0](x))) + x = self.encoder[5](self.encoder[4](self.encoder[3](x))) + + x = self.encoder[6](x) + x = self.encoder[9](self.encoder[8](self.encoder[7](x))) + x2 = self.encoder[12](self.encoder[11](self.encoder[10](x))) + + x = self.encoder[13](x2) + x = self.encoder[16](self.encoder[15](self.encoder[14](x))) + x = self.encoder[19](self.encoder[18](self.encoder[17](x))) + x3 = self.encoder[22](self.encoder[21](self.encoder[20](x))) + + x = self.encoder[23](x3) + x = self.encoder[26](self.encoder[25](self.encoder[24](x))) + x = self.encoder[29](self.encoder[28](self.encoder[27](x))) + x = self.encoder[32](self.encoder[31](self.encoder[30](x))) + + + B, _, Hc, Wc = x.shape + + # score head + score_x = self.score_head(x) + aware = self.softmax(score_x[:, 0:3, :, :]) + score = score_x[:, 3, :, :].unsqueeze(1).sigmoid() + + border_mask = torch.ones(B, Hc, Wc) + border_mask[:, 0] = 0 + border_mask[:, Hc - 1] = 0 + border_mask[:, :, 0] = 0 + border_mask[:, :, Wc - 1] = 0 + border_mask = border_mask.unsqueeze(1) + score = score * border_mask.to(score.device) + + # location head + coord_x = self.loc_head(x) + coord_cell = self.loc_out(coord_x).tanh() + + shift_ratio = self.shift_out(coord_x).sigmoid() * 2.0 + + step = ((H/Hc)-1) / 2. + center_base = image_grid(B, Hc, Wc, + dtype=coord_cell.dtype, + device=coord_cell.device, + ones=False, normalized=False).mul(H/Hc) + step + + coord_un = center_base.add(coord_cell.mul(shift_ratio * step)) + coord = coord_un.clone() + coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W-1) + coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H-1) + + # descriptor block + desc_block = [] + desc_block.append(self.des_out2(x2)) + desc_block.append(self.des_out3(x3)) + desc_block.append(self.des_out4(x)) + desc_block.append(aware) + + if self.is_test: + coord_norm = coord[:, :2].clone() + coord_norm[:, 0] = (coord_norm[:, 0] / (float(W-1)/2.)) - 1. + coord_norm[:, 1] = (coord_norm[:, 1] / (float(H-1)/2.)) - 1. + coord_norm = coord_norm.permute(0, 2, 3, 1) + + desc2 = torch.nn.functional.grid_sample(desc_block[0], coord_norm) + desc3 = torch.nn.functional.grid_sample(desc_block[1], coord_norm) + desc4 = torch.nn.functional.grid_sample(desc_block[2], coord_norm) + aware = desc_block[3] + + desc = torch.mul(desc2, aware[:, 0, :, :]) + torch.mul(desc3, aware[:, 1, :, :]) + torch.mul(desc4, aware[:, 2, :, :]) + desc = desc.div(torch.unsqueeze(torch.norm(desc, p=2, dim=1), 1)) # Divide by norm to normalize. 
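+            # Unlike network_v0, v1 fuses three descriptor scales (des_out2/3/4) with the
+            # three-channel 'aware' softmax before L2 normalization; the training path
+            # again returns the raw descriptor blocks for the loss.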
+ + return score, coord, desc + + return score, coord, desc_block + +class CorrespondenceModule(nn.Module): + def __init__(self, match_type='dual_softmax'): + super(CorrespondenceModule, self).__init__() + self.match_type = match_type + + if self.match_type == 'dual_softmax': + self.temperature = 0.1 + else: + raise NotImplementedError() + + def forward(self, source_desc, target_desc): + b, c, h, w = source_desc.size() + + source_desc = source_desc.div(torch.unsqueeze(torch.norm(source_desc, p=2, dim=1), 1)).view(b, -1, h*w) + target_desc = target_desc.div(torch.unsqueeze(torch.norm(target_desc, p=2, dim=1), 1)).view(b, -1, h*w) + + if self.match_type == 'dual_softmax': + sim_mat = torch.einsum("bcm, bcn -> bmn", source_desc, target_desc) / self.temperature + confidence_matrix = F.softmax(sim_mat, 1) * F.softmax(sim_mat, 2) + else: + raise NotImplementedError() + + return confidence_matrix diff --git a/imcui/third_party/lanet/test.py b/imcui/third_party/lanet/test.py new file mode 100644 index 0000000000000000000000000000000000000000..aac8db788c8a5b5a7613f4b4dcafaed36a5798e0 --- /dev/null +++ b/imcui/third_party/lanet/test.py @@ -0,0 +1,87 @@ +import os +import cv2 +import argparse +import numpy as np +import torch +import torchvision + +from torchvision import datasets, transforms +from torch.autograd import Variable +from network_v0.model import PointModel +from datasets.hp_loader import PatchesDataset +from torch.utils.data import DataLoader +from evaluation.evaluate import evaluate_keypoint_net + + +def main(): + parser = argparse.ArgumentParser(description='Testing') + parser.add_argument('--device', default=0, type=int, help='which gpu to run on.') + parser.add_argument('--test_dir', required=True, type=str, help='Test data path.') + opt = parser.parse_args() + + torch.manual_seed(0) + use_gpu = torch.cuda.is_available() + if use_gpu: + torch.cuda.set_device(opt.device) + + # Load data in 320x240 + hp_dataset_320x240 = PatchesDataset(root_dir=opt.test_dir, use_color=True, output_shape=(320, 240), type='all') + data_loader_320x240 = DataLoader(hp_dataset_320x240, + batch_size=1, + pin_memory=False, + shuffle=False, + num_workers=4, + worker_init_fn=None, + sampler=None) + + # Load data in 640x480 + hp_dataset_640x480 = PatchesDataset(root_dir=opt.test_dir, use_color=True, output_shape=(640, 480), type='all') + data_loader_640x480 = DataLoader(hp_dataset_640x480, + batch_size=1, + pin_memory=False, + shuffle=False, + num_workers=4, + worker_init_fn=None, + sampler=None) + + # Load model + model = PointModel(is_test=True) + ckpt = torch.load('./checkpoints/PointModel_v0.pth') + model.load_state_dict(ckpt['model_state']) + model = model.eval() + if use_gpu: + model = model.cuda() + + + print('Evaluating in 320x240, 300 points') + rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net( + data_loader_320x240, + model, + output_shape=(320, 240), + top_k=300) + + print('Repeatability: {0:.3f}'.format(rep)) + print('Localization Error: {0:.3f}'.format(loc)) + print('H-1 Accuracy: {:.3f}'.format(c1)) + print('H-3 Accuracy: {:.3f}'.format(c3)) + print('H-5 Accuracy: {:.3f}'.format(c5)) + print('Matching Score: {:.3f}'.format(mscore)) + print('\n') + + print('Evaluating in 640x480, 1000 points') + rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net( + data_loader_640x480, + model, + output_shape=(640, 480), + top_k=1000) + + print('Repeatability: {0:.3f}'.format(rep)) + print('Localization Error: {0:.3f}'.format(loc)) + print('H-1 Accuracy: {:.3f}'.format(c1)) + print('H-3 Accuracy: 
{:.3f}'.format(c3)) + print('H-5 Accuracy: {:.3f}'.format(c5)) + print('Matching Score: {:.3f}'.format(mscore)) + print('\n') + +if __name__ == '__main__': + main() diff --git a/imcui/third_party/lanet/train.py b/imcui/third_party/lanet/train.py new file mode 100644 index 0000000000000000000000000000000000000000..dd506f567cfe071e33c674346ee95f933cd461e8 --- /dev/null +++ b/imcui/third_party/lanet/train.py @@ -0,0 +1,129 @@ +import os +import torch +import torch.optim as optim +from tqdm import tqdm + +from torch.autograd import Variable + +from network_v0.model import PointModel +from loss_function import KeypointLoss + +class Trainer(object): + def __init__(self, config, train_loader=None): + self.config = config + # data parameters + self.train_loader = train_loader + self.num_train = len(self.train_loader) + + # training parameters + self.max_epoch = config.max_epoch + self.start_epoch = config.start_epoch + self.momentum = config.momentum + self.lr = config.init_lr + self.lr_factor = config.lr_factor + self.display = config.display + + # misc params + self.use_gpu = config.use_gpu + self.random_seed = config.seed + self.gpu = config.gpu + self.ckpt_dir = config.ckpt_dir + self.ckpt_name = '{}-{}'.format(config.ckpt_name, config.seed) + + # build model + self.model = PointModel(is_test=False) + + # training on GPU + if self.use_gpu: + torch.cuda.set_device(self.gpu) + self.model.cuda() + + print('Number of model parameters: {:,}'.format(sum([p.data.nelement() for p in self.model.parameters()]))) + + # build loss functional + self.loss_func = KeypointLoss(config) + + # build optimizer and scheduler + self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) + self.lr_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[4, 8], gamma=self.lr_factor) + + # resume + if int(self.config.start_epoch) > 0: + self.config.start_epoch, self.model, self.optimizer, self.lr_scheduler = self.load_checkpoint(int(self.config.start_epoch), self.model, self.optimizer, self.lr_scheduler) + + def train(self): + print("\nTrain on {} samples".format(self.num_train)) + self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler) + for epoch in range(self.start_epoch, self.max_epoch): + print("\nEpoch: {}/{} --lr: {:.6f}".format(epoch+1, self.max_epoch, self.lr)) + # train for one epoch + self.train_one_epoch(epoch) + if self.lr_scheduler: + self.lr_scheduler.step() + self.save_checkpoint(epoch+1, self.model, self.optimizer, self.lr_scheduler) + + def train_one_epoch(self, epoch): + self.model.train() + for (i, data) in enumerate(tqdm(self.train_loader)): + + if self.use_gpu: + source_img = data['image_aug'].cuda() + target_img = data['image'].cuda() + homography = data['homography'].cuda() + + source_img = Variable(source_img) + target_img = Variable(target_img) + homography = Variable(homography) + + # forward propogation + output = self.model(source_img, target_img, homography) + + # compute loss + loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output) + + # compute gradients and update + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # print training info + msg_batch = "Epoch:{} Iter:{} lr:{:.4f} "\ + "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "\ + "loss={:.4f} "\ + .format((epoch + 1), i, self.lr, loc_loss.data, desc_loss.data, score_loss.data, corres_loss.data, loss.data) + + if((i % self.display) == 0): + print(msg_batch) + return + + def save_checkpoint(self, epoch, model, optimizer, 
lr_scheduler): + filename = self.ckpt_name + '_' + str(epoch) + '.pth' + torch.save( + {'epoch': epoch, + 'model_state': model.state_dict(), + 'optimizer_state': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict()}, + os.path.join(self.ckpt_dir, filename)) + + def load_checkpoint(self, epoch, model, optimizer, lr_scheduler): + filename = self.ckpt_name + '_' + str(epoch) + '.pth' + ckpt = torch.load(os.path.join(self.ckpt_dir, filename)) + epoch = ckpt['epoch'] + model.load_state_dict(ckpt['model_state']) + optimizer.load_state_dict(ckpt['optimizer_state']) + lr_scheduler.load_state_dict(ckpt['lr_scheduler']) + + print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt['epoch'])) + + return epoch, model, optimizer, lr_scheduler + + + + + + + + + + + \ No newline at end of file diff --git a/imcui/third_party/lanet/utils.py b/imcui/third_party/lanet/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..416012d2d367739edcbefa22e00b0030f090eede --- /dev/null +++ b/imcui/third_party/lanet/utils.py @@ -0,0 +1,123 @@ +import os +import torch +import numpy as np + +import torchvision.transforms as transforms +from functools import lru_cache + +@lru_cache(maxsize=None) +def meshgrid(B, H, W, dtype, device, normalized=False): + """ + Create mesh-grid given batch size, height and width dimensions. From https://github.com/TRI-ML/KP2D. + + Parameters + ---------- + B: int + Batch size + H: int + Grid Height + W: int + Batch size + dtype: torch.dtype + Tensor dtype + device: str + Tensor device + normalized: bool + Normalized image coordinates or integer-grid. + + Returns + ------- + xs: torch.Tensor + Batched mesh-grid x-coordinates (BHW). + ys: torch.Tensor + Batched mesh-grid y-coordinates (BHW). + """ + if normalized: + xs = torch.linspace(-1, 1, W, device=device, dtype=dtype) + ys = torch.linspace(-1, 1, H, device=device, dtype=dtype) + else: + xs = torch.linspace(0, W-1, W, device=device, dtype=dtype) + ys = torch.linspace(0, H-1, H, device=device, dtype=dtype) + ys, xs = torch.meshgrid([ys, xs]) + return xs.repeat([B, 1, 1]), ys.repeat([B, 1, 1]) + + +@lru_cache(maxsize=None) +def image_grid(B, H, W, dtype, device, ones=True, normalized=False): + """ + Create an image mesh grid with shape B3HW given image shape BHW. From https://github.com/TRI-ML/KP2D. + + Parameters + ---------- + B: int + Batch size + H: int + Grid Height + W: int + Batch size + dtype: str + Tensor dtype + device: str + Tensor device + ones : bool + Use (x, y, 1) coordinates + normalized: bool + Normalized image coordinates or integer-grid. + + Returns + ------- + grid: torch.Tensor + Mesh-grid for the corresponding image shape (B3HW) + """ + xs, ys = meshgrid(B, H, W, dtype, device, normalized=normalized) + coords = [xs, ys] + if ones: + coords.append(torch.ones_like(xs)) # BHW + grid = torch.stack(coords, dim=1) # B3HW + return grid + +def to_tensor_sample(sample, tensor_type='torch.FloatTensor'): + """ + Casts the keys of sample to tensors. From https://github.com/TRI-ML/KP2D. + + Parameters + ---------- + sample : dict + Input sample + tensor_type : str + Type of tensor we are casting to + + Returns + ------- + sample : dict + Sample with keys cast as tensors + """ + transform = transforms.ToTensor() + sample['image'] = transform(sample['image']).type(tensor_type) + return sample + +def warp_keypoints(keypoints, H): + """Warp keypoints given a homography + + Parameters + ---------- + keypoints: numpy.ndarray (N,2) + Keypoint vector. + H: numpy.ndarray (3,3) + Homography. 
+ + Returns + ------- + warped_keypoints: numpy.ndarray (N,2) + Warped keypoints vector. + """ + num_points = keypoints.shape[0] + homogeneous_points = np.concatenate([keypoints, np.ones((num_points, 1))], axis=1) + warped_points = np.dot(homogeneous_points, np.transpose(H)) + return warped_points[:, :2] / warped_points[:, 2:] + +def prepare_dirs(config): + for path in [config.ckpt_dir]: + if not os.path.exists(path): + os.makedirs(path) + diff --git a/imcui/third_party/mast3r/demo.py b/imcui/third_party/mast3r/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..3ee5ee1030af1214f6204af9826de5e22a53ecfa --- /dev/null +++ b/imcui/third_party/mast3r/demo.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# gradio demo executable +# -------------------------------------------------------- +import os +import torch +import tempfile +from contextlib import nullcontext + +from mast3r.demo import get_args_parser, main_demo + +from mast3r.model import AsymmetricMASt3R +from mast3r.utils.misc import hash_md5 + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.demo import set_print_with_timestamp + +import matplotlib.pyplot as pl +pl.ion() + +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + set_print_with_timestamp() + + if args.server_name is not None: + server_name = args.server_name + else: + server_name = '0.0.0.0' if args.local_network else '127.0.0.1' + + if args.weights is not None: + weights_path = args.weights + else: + weights_path = "naver/" + args.model_name + + model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device) + chkpt_tag = hash_md5(weights_path) + + def get_context(tmp_dir): + return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \ + else nullcontext(tmp_dir) + with get_context(args.tmp_dir) as tmpdirname: + cache_path = os.path.join(tmpdirname, chkpt_tag) + os.makedirs(cache_path, exist_ok=True) + main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent, + share=args.share, gradio_delete_cache=args.gradio_delete_cache) diff --git a/imcui/third_party/mast3r/demo_dust3r_ga.py b/imcui/third_party/mast3r/demo_dust3r_ga.py new file mode 100644 index 0000000000000000000000000000000000000000..361c10e392e42525d57765b3f95fec43a89035a3 --- /dev/null +++ b/imcui/third_party/mast3r/demo_dust3r_ga.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# mast3r gradio demo executable +# -------------------------------------------------------- +import os +import torch +import tempfile + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.model import AsymmetricCroCo3DStereo +from mast3r.model import AsymmetricMASt3R +from dust3r.demo import get_args_parser as dust3r_get_args_parser +from dust3r.demo import main_demo, set_print_with_timestamp + +import matplotlib.pyplot as pl +pl.ion() + +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + + +def get_args_parser(): + parser = dust3r_get_args_parser() + + actions = parser._actions + for action in actions: + if action.dest == 'model_name': + action.choices.append('MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric') + # change defaults + parser.prog = 'mast3r demo' + return parser + + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + set_print_with_timestamp() + + if args.tmp_dir is not None: + tmp_path = args.tmp_dir + os.makedirs(tmp_path, exist_ok=True) + tempfile.tempdir = tmp_path + + if args.server_name is not None: + server_name = args.server_name + else: + server_name = '0.0.0.0' if args.local_network else '127.0.0.1' + + if args.weights is not None: + weights_path = args.weights + else: + weights_path = "naver/" + args.model_name + + try: + model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device) + except Exception as e: + model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) + + # dust3r will write the 3D model inside tmpdirname + with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname: + if not args.silent: + print('Outputing stuff in', tmpdirname) + main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent) diff --git a/imcui/third_party/mast3r/docker/docker-compose-cpu.yml b/imcui/third_party/mast3r/docker/docker-compose-cpu.yml new file mode 100644 index 0000000000000000000000000000000000000000..746fe20a790cf609f467a8eba0ae1461669fa5f6 --- /dev/null +++ b/imcui/third_party/mast3r/docker/docker-compose-cpu.yml @@ -0,0 +1,16 @@ +version: '3.8' +services: + mast3r-demo: + build: + context: ./files + dockerfile: cpu.Dockerfile + ports: + - "7860:7860" + volumes: + - ./files/checkpoints:/mast3r/checkpoints + environment: + - DEVICE=cpu + - MODEL=${MODEL:-MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth} + cap_add: + - IPC_LOCK + - SYS_RESOURCE diff --git a/imcui/third_party/mast3r/docker/docker-compose-cuda.yml b/imcui/third_party/mast3r/docker/docker-compose-cuda.yml new file mode 100644 index 0000000000000000000000000000000000000000..30670bd837c09ecd3f8546e640eca87119784769 --- /dev/null +++ b/imcui/third_party/mast3r/docker/docker-compose-cuda.yml @@ -0,0 +1,23 @@ +version: '3.8' +services: + mast3r-demo: + build: + context: ./files + dockerfile: cuda.Dockerfile + ports: + - "7860:7860" + environment: + - DEVICE=cuda + - MODEL=${MODEL:-MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth} + volumes: + - ./files/checkpoints:/mast3r/checkpoints + cap_add: + - IPC_LOCK + - SYS_RESOURCE + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/__init__.py b/imcui/third_party/mast3r/dust3r/croco/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/imcui/third_party/mast3r/dust3r/croco/datasets/crops/extract_crops_from_images.py b/imcui/third_party/mast3r/dust3r/croco/datasets/crops/extract_crops_from_images.py new file mode 100644 index 0000000000000000000000000000000000000000..eb66a0474ce44b54c44c08887cbafdb045b11ff3 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/datasets/crops/extract_crops_from_images.py @@ -0,0 +1,159 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Extracting crops for pre-training +# -------------------------------------------------------- + +import os +import argparse +from tqdm import tqdm +from PIL import Image +import functools +from multiprocessing import Pool +import math + + +def arg_parser(): + parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list') + + parser.add_argument('--crops', type=str, required=True, help='crop file') + parser.add_argument('--root-dir', type=str, required=True, help='root directory') + parser.add_argument('--output-dir', type=str, required=True, help='output directory') + parser.add_argument('--imsize', type=int, default=256, help='size of the crops') + parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads') + parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories') + parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir') + return parser + + +def main(args): + listing_path = os.path.join(args.output_dir, 'listing.txt') + + print(f'Loading list of crops ... ({args.nthread} threads)') + crops, num_crops_to_generate = load_crop_file(args.crops) + + print(f'Preparing jobs ({len(crops)} candidate image pairs)...') + num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels) + num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels)) + + jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir) + del crops + + os.makedirs(args.output_dir, exist_ok=True) + mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map + call = functools.partial(save_image_crops, args) + + print(f"Generating cropped images to {args.output_dir} ...") + with open(listing_path, 'w') as listing: + listing.write('# pair_path\n') + for results in tqdm(mmap(call, jobs), total=len(jobs)): + for path in results: + listing.write(f'{path}\n') + print('Finished writing listing to', listing_path) + + +def load_crop_file(path): + data = open(path).read().splitlines() + pairs = [] + num_crops_to_generate = 0 + for line in tqdm(data): + if line.startswith('#'): + continue + line = line.split(', ') + if len(line) < 8: + img1, img2, rotation = line + pairs.append((img1, img2, int(rotation), [])) + else: + l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line) + rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2) + pairs[-1][-1].append((rect1, rect2)) + num_crops_to_generate += 1 + return pairs, num_crops_to_generate + + +def prepare_jobs(pairs, num_levels, num_pairs_in_dir): + jobs = [] + powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))] + + def get_path(idx): + idx_array = [] + d = idx + for level in range(num_levels - 1): + idx_array.append(idx // powers[level]) + idx = idx % powers[level] + idx_array.append(d) + return '/'.join(map(lambda x: hex(x)[2:], idx_array)) + + idx = 0 + for 
pair_data in tqdm(pairs): + img1, img2, rotation, crops = pair_data + if -60 <= rotation and rotation <= 60: + rotation = 0 # most likely not a true rotation + paths = [get_path(idx + k) for k in range(len(crops))] + idx += len(crops) + jobs.append(((img1, img2), rotation, crops, paths)) + return jobs + + +def load_image(path): + try: + return Image.open(path).convert('RGB') + except Exception as e: + print('skipping', path, e) + raise OSError() + + +def save_image_crops(args, data): + # load images + img_pair, rot, crops, paths = data + try: + img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair] + except OSError as e: + return [] + + def area(sz): + return sz[0] * sz[1] + + tgt_size = (args.imsize, args.imsize) + + def prepare_crop(img, rect, rot=0): + # actual crop + img = img.crop(rect) + + # resize to desired size + interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC + img = img.resize(tgt_size, resample=interp) + + # rotate the image + rot90 = (round(rot/90) % 4) * 90 + if rot90 == 90: + img = img.transpose(Image.Transpose.ROTATE_90) + elif rot90 == 180: + img = img.transpose(Image.Transpose.ROTATE_180) + elif rot90 == 270: + img = img.transpose(Image.Transpose.ROTATE_270) + return img + + results = [] + for (rect1, rect2), path in zip(crops, paths): + crop1 = prepare_crop(img1, rect1) + crop2 = prepare_crop(img2, rect2, rot) + + fullpath1 = os.path.join(args.output_dir, path+'_1.jpg') + fullpath2 = os.path.join(args.output_dir, path+'_2.jpg') + os.makedirs(os.path.dirname(fullpath1), exist_ok=True) + + assert not os.path.isfile(fullpath1), fullpath1 + assert not os.path.isfile(fullpath2), fullpath2 + crop1.save(fullpath1) + crop2.save(fullpath2) + results.append(path) + + return results + + +if __name__ == '__main__': + args = arg_parser().parse_args() + main(args) + diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/__init__.py b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe0d399084359495250dc8184671ff498adfbf2 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py @@ -0,0 +1,92 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Script to generate image pairs for a given scene reproducing poses provided in a metadata file. +""" +import os +from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator +from datasets.habitat_sim.paths import SCENES_DATASET +import argparse +import quaternion +import PIL.Image +import cv2 +import json +from tqdm import tqdm + +def generate_multiview_images_from_metadata(metadata_filename, + output_dir, + overload_params = dict(), + scene_datasets_paths=None, + exist_ok=False): + """ + Generate images from a metadata file for reproducibility purposes. 
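+
+    The metadata JSON is loaded, scene paths are optionally remapped through
+    scene_datasets_paths, entries in overload_params take precedence, and every
+    recorded viewpoint is re-rendered with MultiviewHabitatSimGenerator. Each
+    view is saved as a JPEG, optionally with an EXR depth map and a JSON file of
+    camera parameters, and the metadata is written back to the output directory.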
+ """ + # Reorder paths by decreasing label length, to avoid collisions when testing if a string by such label + if scene_datasets_paths is not None: + scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True)) + + with open(metadata_filename, 'r') as f: + input_metadata = json.load(f) + metadata = dict() + for key, value in input_metadata.items(): + # Optionally replace some paths + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + if scene_datasets_paths is not None: + for dataset_label, dataset_path in scene_datasets_paths.items(): + if value.startswith(dataset_label): + value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label))) + break + metadata[key] = value + + # Overload some parameters + for key, value in overload_params.items(): + metadata[key] = value + + generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))]) + generate_depth = metadata["generate_depth"] + + os.makedirs(output_dir, exist_ok=exist_ok) + + generator = MultiviewHabitatSimGenerator(**generation_entries) + + # Generate views + for idx_label, data in tqdm(metadata['multiviews'].items()): + positions = data["positions"] + orientations = data["orientations"] + n = len(positions) + for oidx in range(n): + observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx])) + observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1 + # Color image saved using PIL + img = PIL.Image.fromarray(observation['color'][:,:,:3]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg") + img.save(filename) + if generate_depth: + # Depth image as EXR file + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr") + cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + # Camera parameters + camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json") + with open(filename, "w") as f: + json.dump(camera_params, f) + # Save metadata + with open(os.path.join(output_dir, "metadata.json"), "w") as f: + json.dump(metadata, f) + + generator.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_filename", required=True) + parser.add_argument("--output_dir", required=True) + args = parser.parse_args() + + generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename, + output_dir=args.output_dir, + scene_datasets_paths=SCENES_DATASET, + overload_params=dict(), + exist_ok=True) + + \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py new file mode 100644 index 0000000000000000000000000000000000000000..962ef849d8c31397b8622df4f2d9140175d78873 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py @@ -0,0 +1,27 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Script generating commandlines to generate image pairs from metadata files. 
+""" +import os +import glob +from tqdm import tqdm +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", required=True) + parser.add_argument("--output_dir", required=True) + parser.add_argument("--prefix", default="", help="Commanline prefix, useful e.g. to setup environment.") + args = parser.parse_args() + + input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True) + + for metadata_filename in tqdm(input_metadata_filenames): + output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir)) + # Do not process the scene if the metadata file already exists + if os.path.exists(os.path.join(output_dir, "metadata.json")): + continue + commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}" + print(commandline) diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py new file mode 100644 index 0000000000000000000000000000000000000000..421d49a1696474415940493296b3f2d982398850 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py @@ -0,0 +1,177 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import os +from tqdm import tqdm +import argparse +import PIL.Image +import numpy as np +import json +from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError +from datasets.habitat_sim.paths import list_scenes_available +import cv2 +import quaternion +import shutil + +def generate_multiview_images_for_scene(scene_dataset_config_file, + scene, + navmesh, + output_dir, + views_count, + size, + exist_ok=False, + generate_depth=False, + **kwargs): + """ + Generate tuples of overlapping views for a given scene. + generate_depth: generate depth images and camera parameters. + """ + if os.path.exists(output_dir) and not exist_ok: + print(f"Scene {scene}: data already generated. Ignoring generation.") + return + try: + print(f"Scene {scene}: {size} multiview acquisitions to generate...") + os.makedirs(output_dir, exist_ok=exist_ok) + + metadata_filename = os.path.join(output_dir, "metadata.json") + + metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count=views_count, + size=size, + generate_depth=generate_depth, + **kwargs) + metadata_template["multiviews"] = dict() + + if os.path.exists(metadata_filename): + print("Metadata file already exists:", metadata_filename) + print("Loading already generated metadata file...") + with open(metadata_filename, "r") as f: + metadata = json.load(f) + + for key in metadata_template.keys(): + if key != "multiviews": + assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}." + else: + print("No temporary file found. 
Starting generation from scratch...") + metadata = metadata_template + + starting_id = len(metadata["multiviews"]) + print(f"Starting generation from index {starting_id}/{size}...") + if starting_id >= size: + print("Generation already done.") + return + + generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count = views_count, + size = size, + **kwargs) + + for idx in tqdm(range(starting_id, size)): + # Generate / re-generate the observations + try: + data = generator[idx] + observations = data["observations"] + positions = data["positions"] + orientations = data["orientations"] + + idx_label = f"{idx:08}" + for oidx, observation in enumerate(observations): + observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1 + # Color image saved using PIL + img = PIL.Image.fromarray(observation['color'][:,:,:3]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg") + img.save(filename) + if generate_depth: + # Depth image as EXR file + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr") + cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + # Camera parameters + camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json") + with open(filename, "w") as f: + json.dump(camera_params, f) + metadata["multiviews"][idx_label] = {"positions": positions.tolist(), + "orientations": orientations.tolist(), + "covisibility_ratios": data["covisibility_ratios"].tolist(), + "valid_fractions": data["valid_fractions"].tolist(), + "pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()} + except RecursionError: + print("Recursion error: unable to sample observations for this scene. We will stop there.") + break + + # Regularly save a temporary metadata file, in case we need to restart the generation + if idx % 10 == 0: + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + # Save metadata + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + generator.close() + except NoNaviguableSpaceError: + pass + +def create_commandline(scene_data, generate_depth, exist_ok=False): + """ + Create a commandline string to generate a scene. 
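+
+    Example output (scene paths are placeholders; the string is returned on a single line):
+        python generate_multiview_images.py --scene scene.basis.glb --scene_dataset_config_file "" --navmesh scene.basis.navmesh --output_dir output/scene --generate_depth 1 --exist_ok 0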
+ """ + def my_formatting(val): + if val is None or val == "": + return '""' + else: + return val + commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)} + --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)} + --navmesh {my_formatting(scene_data.navmesh)} + --output_dir {my_formatting(scene_data.output_dir)} + --generate_depth {int(generate_depth)} + --exist_ok {int(exist_ok)} + """ + commandline = " ".join(commandline.split()) + return commandline + +if __name__ == "__main__": + os.umask(2) + + parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for scenes available: + > python datasets/habitat_sim/generate_multiview_habitat_images.py --list_commands + """) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--list_commands", action='store_true', help="list commandlines to run if true") + parser.add_argument("--scene", type=str, default="") + parser.add_argument("--scene_dataset_config_file", type=str, default="") + parser.add_argument("--navmesh", type=str, default="") + + parser.add_argument("--generate_depth", type=int, default=1) + parser.add_argument("--exist_ok", type=int, default=0) + + kwargs = dict(resolution=(256,256), hfov=60, views_count = 2, size=1000) + + args = parser.parse_args() + generate_depth=bool(args.generate_depth) + exist_ok = bool(args.exist_ok) + + if args.list_commands: + # Listing scenes available... + scenes_data = list_scenes_available(base_output_dir=args.output_dir) + + for scene_data in scenes_data: + print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok)) + else: + if args.scene == "" or args.output_dir == "": + print("Missing scene or output dir argument!") + print(parser.format_help()) + else: + generate_multiview_images_for_scene(scene=args.scene, + scene_dataset_config_file = args.scene_dataset_config_file, + navmesh = args.navmesh, + output_dir = args.output_dir, + exist_ok=exist_ok, + generate_depth=generate_depth, + **kwargs) \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..91e5f923b836a645caf5d8e4aacc425047e3c144 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py @@ -0,0 +1,390 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+ +import os +import numpy as np +import quaternion +import habitat_sim +import json +from sklearn.neighbors import NearestNeighbors +import cv2 + +# OpenCV to habitat camera convention transformation +R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0) +R_HABITAT2OPENCV = R_OPENCV2HABITAT.T +DEG2RAD = np.pi / 180 + +def compute_camera_intrinsics(height, width, hfov): + f = width/2 / np.tan(hfov/2 * np.pi/180) + cu, cv = width/2, height/2 + return f, cu, cv + +def compute_camera_pose_opencv_convention(camera_position, camera_orientation): + R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT + t_cam2world = np.asarray(camera_position) + return R_cam2world, t_cam2world + +def compute_pointmap(depthmap, hfov): + """ Compute a HxWx3 pointmap in camera frame from a HxW depth map.""" + height, width = depthmap.shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + # Cast depth map to point + z_cam = depthmap + u, v = np.meshgrid(range(width), range(height)) + x_cam = (u - cu) / f * z_cam + y_cam = (v - cv) / f * z_cam + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1) + return X_cam + +def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation): + """Return a 3D point cloud corresponding to valid pixels of the depth map""" + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation) + + X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov) + valid_mask = (X_cam[:,:,2] != 0.0) + + X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()] + X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3) + return X_world + +def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False): + """ + Compute 'overlapping' metrics based on a distance threshold between two point clouds. + """ + nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2) + distances, indices = nbrs.kneighbors(pointcloud1) + intersection1 = np.count_nonzero(distances.flatten() < distance_threshold) + + data = {"intersection1": intersection1, + "size1": len(pointcloud1)} + if compute_symmetric: + nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1) + distances, indices = nbrs.kneighbors(pointcloud2) + intersection2 = np.count_nonzero(distances.flatten() < distance_threshold) + data["intersection2"] = intersection2 + data["size2"] = len(pointcloud2) + + return data + +def _append_camera_parameters(observation, hfov, camera_location, camera_rotation): + """ + Add camera parameters to the observation dictionnary produced by Habitat-Sim + In-place modifications. + """ + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation) + height, width = observation['depth'].shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + K = np.asarray([[f, 0, cu], + [0, f, cv], + [0, 0, 1.0]]) + observation["camera_intrinsics"] = K + observation["t_cam2world"] = t_cam2world + observation["R_cam2world"] = R_cam2world + +def look_at(eye, center, up, return_cam2world=True): + """ + Return camera pose looking at a given center point. + Analogous of gluLookAt function, using OpenCV camera convention. 
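+    With return_cam2world=True, R stacks the camera x/y/z axes as columns (camera-to-world)
+    and t is the eye position; otherwise the transposed (world-to-camera) rotation is
+    returned together with t = -R @ eye.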
+ """ + z = center - eye + z /= np.linalg.norm(z, axis=-1, keepdims=True) + y = -up + y = y - np.sum(y * z, axis=-1, keepdims=True) * z + y /= np.linalg.norm(y, axis=-1, keepdims=True) + x = np.cross(y, z, axis=-1) + + if return_cam2world: + R = np.stack((x, y, z), axis=-1) + t = eye + else: + # World to camera transformation + # Transposed matrix + R = np.stack((x, y, z), axis=-2) + t = - np.einsum('...ij, ...j', R, eye) + return R, t + +def look_at_for_habitat(eye, center, up, return_cam2world=True): + R, t = look_at(eye, center, up) + orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T) + return orientation, t + +def generate_orientation_noise(pan_range, tilt_range, roll_range): + return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP) + * quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT) + * quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT)) + + +class NoNaviguableSpaceError(RuntimeError): + def __init__(self, *args): + super().__init__(*args) + +class MultiviewHabitatSimGenerator: + def __init__(self, + scene, + navmesh, + scene_dataset_config_file, + resolution = (240, 320), + views_count=2, + hfov = 60, + gpu_id = 0, + size = 10000, + minimum_covisibility = 0.5, + transform = None): + self.scene = scene + self.navmesh = navmesh + self.scene_dataset_config_file = scene_dataset_config_file + self.resolution = resolution + self.views_count = views_count + assert(self.views_count >= 1) + self.hfov = hfov + self.gpu_id = gpu_id + self.size = size + self.transform = transform + + # Noise added to camera orientation + self.pan_range = (-3, 3) + self.tilt_range = (-10, 10) + self.roll_range = (-5, 5) + + # Height range to sample cameras + self.height_range = (1.2, 1.8) + + # Random steps between the camera views + self.random_steps_count = 5 + self.random_step_variance = 2.0 + + # Minimum fraction of the scene which should be valid (well defined depth) + self.minimum_valid_fraction = 0.7 + + # Distance threshold to see to select pairs + self.distance_threshold = 0.05 + # Minimum IoU of a view point cloud with respect to the reference view to be kept. + self.minimum_covisibility = minimum_covisibility + + # Maximum number of retries. 
+ self.max_attempts_count = 100 + + self.seed = None + self._lazy_initialization() + + def _lazy_initialization(self): + # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly + if self.seed == None: + # Re-seed numpy generator + np.random.seed() + self.seed = np.random.randint(2**32-1) + sim_cfg = habitat_sim.SimulatorConfiguration() + sim_cfg.scene_id = self.scene + if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "": + sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file + sim_cfg.random_seed = self.seed + sim_cfg.load_semantic_mesh = False + sim_cfg.gpu_device_id = self.gpu_id + + depth_sensor_spec = habitat_sim.CameraSensorSpec() + depth_sensor_spec.uuid = "depth" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.resolution + depth_sensor_spec.hfov = self.hfov + depth_sensor_spec.position = [0.0, 0.0, 0] + depth_sensor_spec.orientation + + rgb_sensor_spec = habitat_sim.CameraSensorSpec() + rgb_sensor_spec.uuid = "color" + rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR + rgb_sensor_spec.resolution = self.resolution + rgb_sensor_spec.hfov = self.hfov + rgb_sensor_spec.position = [0.0, 0.0, 0] + agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec]) + + cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg]) + self.sim = habitat_sim.Simulator(cfg) + if self.navmesh is not None and self.navmesh != "": + # Use pre-computed navmesh when available (usually better than those generated automatically) + self.sim.pathfinder.load_nav_mesh(self.navmesh) + + if not self.sim.pathfinder.is_loaded: + # Try to compute a navmesh + navmesh_settings = habitat_sim.NavMeshSettings() + navmesh_settings.set_defaults() + self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True) + + # Ensure that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})") + + self.agent = self.sim.initialize_agent(agent_id=0) + + def close(self): + self.sim.close() + + def __del__(self): + self.sim.close() + + def __len__(self): + return self.size + + def sample_random_viewpoint(self): + """ Sample a random viewpoint using the navmesh """ + nav_point = self.sim.pathfinder.get_random_navigable_point() + + # Sample a random viewpoint height + viewpoint_height = np.random.uniform(*self.height_range) + viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP + viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range) + return viewpoint_position, viewpoint_orientation, nav_point + + def sample_other_random_viewpoint(self, observed_point, nav_point): + """ Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point.""" + other_nav_point = nav_point + + walk_directions = self.random_step_variance * np.asarray([1,0,1]) + for i in range(self.random_steps_count): + temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3)) + # Snapping may return nan when it fails + if not np.isnan(temp[0]): + other_nav_point = temp + + other_viewpoint_height = np.random.uniform(*self.height_range) + other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP + + # Set viewing 
direction towards the central point + rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True) + rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range) + return position, rotation, other_nav_point + + def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud): + """ Check if a viewpoint is valid and overlaps significantly with a reference one. """ + # Observation + pixels_count = self.resolution[0] * self.resolution[1] + valid_fraction = len(other_pointcloud) / pixels_count + assert valid_fraction <= 1.0 and valid_fraction >= 0.0 + overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True) + covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count) + is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility) + return is_valid, valid_fraction, covisibility + + def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation): + """ Check if a viewpoint is valid and overlaps significantly with a reference one. """ + # Observation + other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation) + return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud) + + def render_viewpoint(self, viewpoint_position, viewpoint_orientation): + agent_state = habitat_sim.AgentState() + agent_state.position = viewpoint_position + agent_state.rotation = viewpoint_orientation + self.agent.set_state(agent_state) + viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0) + _append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation) + return viewpoint_observations + + def __getitem__(self, useless_idx): + ref_position, ref_orientation, nav_point = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + # Extract point cloud + ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov, + camera_position=ref_position, camera_rotation=ref_orientation) + + pixels_count = self.resolution[0] * self.resolution[1] + ref_valid_fraction = len(ref_pointcloud) / pixels_count + assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0 + if ref_valid_fraction < self.minimum_valid_fraction: + # This should produce a recursion error at some point when something is very wrong. 
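+            # (Python's recursion limit, typically 1000, is what ultimately bounds this retry-by-recursion)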
+ return self[0] + # Pick an reference observed point in the point cloud + observed_point = np.mean(ref_pointcloud, axis=0) + + # Add the first image as reference + viewpoints_observations = [ref_observations] + viewpoints_covisibility = [ref_valid_fraction] + viewpoints_positions = [ref_position] + viewpoints_orientations = [quaternion.as_float_array(ref_orientation)] + viewpoints_clouds = [ref_pointcloud] + viewpoints_valid_fractions = [ref_valid_fraction] + + for _ in range(self.views_count - 1): + # Generate an other viewpoint using some dummy random walk + successful_sampling = False + for sampling_attempt in range(self.max_attempts_count): + position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point) + # Observation + other_viewpoint_observations = self.render_viewpoint(position, rotation) + other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation) + + is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud) + if is_valid: + successful_sampling = True + break + if not successful_sampling: + print("WARNING: Maximum number of attempts reached.") + # Dirty hack, try using a novel original viewpoint + return self[0] + viewpoints_observations.append(other_viewpoint_observations) + viewpoints_covisibility.append(covisibility) + viewpoints_positions.append(position) + viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding. + viewpoints_clouds.append(other_pointcloud) + viewpoints_valid_fractions.append(valid_fraction) + + # Estimate relations between all pairs of images + pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations))) + for i in range(len(viewpoints_observations)): + pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i] + for j in range(i+1, len(viewpoints_observations)): + overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True) + pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count + pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count + + # IoU is relative to the image 0 + data = {"observations": viewpoints_observations, + "positions": np.asarray(viewpoints_positions), + "orientations": np.asarray(viewpoints_orientations), + "covisibility_ratios": np.asarray(viewpoints_covisibility), + "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float), + "pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float), + } + + if self.transform is not None: + data = self.transform(data) + return data + + def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False): + """ + Return a list of images corresponding to a spiral trajectory from a random starting point. + Useful to generate nice visualisations. 
+ Use an even number of half turns to get a nice "C1-continuous" loop effect + """ + ref_position, ref_orientation, navpoint = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov, + camera_position=ref_position, camera_rotation=ref_orientation) + pixels_count = self.resolution[0] * self.resolution[1] + if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction: + # Dirty hack: ensure that the valid part of the image is significant + return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation) + + # Pick an observed point in the point cloud + observed_point = np.mean(ref_pointcloud, axis=0) + ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation) + + images = [] + is_valid = [] + # Spiral trajectory, use_constant orientation + for i, alpha in enumerate(np.linspace(0, 1, images_count)): + r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius + theta = alpha * half_turns * np.pi + x = r * np.cos(theta) + y = r * np.sin(theta) + z = 0.0 + position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten() + if use_constant_orientation: + orientation = ref_orientation + else: + # trajectory looking at a mean point in front of the ref observation + orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP) + observations = self.render_viewpoint(position, orientation) + images.append(observations['color'][...,:3]) + _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation) + is_valid.append(_is_valid) + return images, np.all(is_valid) \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py new file mode 100644 index 0000000000000000000000000000000000000000..10672a01f7dd615d3b4df37781f7f6f97e753ba6 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py @@ -0,0 +1,69 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +""" +Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere. 
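+
+Illustrative invocation (directories are placeholders; the output directory must not already exist):
+    python datasets/habitat_sim/pack_metadata_files.py ./data/habitat_release ./data/habitat_release_metadata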
+""" +import os +import glob +from tqdm import tqdm +import shutil +import json +from datasets.habitat_sim.paths import * +import argparse +import collections + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input_dir") + parser.add_argument("output_dir") + args = parser.parse_args() + + input_dirname = args.input_dir + output_dirname = args.output_dir + + input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True) + + images_count = collections.defaultdict(lambda : 0) + + os.makedirs(output_dirname) + for input_filename in tqdm(input_metadata_filenames): + # Ignore empty files + with open(input_filename, "r") as f: + original_metadata = json.load(f) + if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0: + print("No views in", input_filename) + continue + + relpath = os.path.relpath(input_filename, input_dirname) + print(relpath) + + # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability. + # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern. + scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True)) + metadata = dict() + for key, value in original_metadata.items(): + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + known_path = False + for dataset, dataset_path in scenes_dataset_paths.items(): + if value.startswith(dataset_path): + value = os.path.join(dataset, os.path.relpath(value, dataset_path)) + known_path = True + break + if not known_path: + raise KeyError("Unknown path:" + value) + metadata[key] = value + + # Compile some general statistics while packing data + scene_split = metadata["scene"].split("/") + upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0] + images_count[upper_level] += len(metadata["multiviews"]) + + output_filename = os.path.join(output_dirname, relpath) + os.makedirs(os.path.dirname(output_filename), exist_ok=True) + with open(output_filename, "w") as f: + json.dump(metadata, f) + + # Print statistics + print("Images count:") + for upper_level, count in images_count.items(): + print(f"- {upper_level}: {count}") \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/paths.py b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/paths.py new file mode 100644 index 0000000000000000000000000000000000000000..4d63b5fa29c274ddfeae084734a35ba66d7edee8 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/datasets/habitat_sim/paths.py @@ -0,0 +1,129 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+ +""" +Paths to Habitat-Sim scenes +""" + +import os +import json +import collections +from tqdm import tqdm + + +# Hardcoded path to the different scene datasets +SCENES_DATASET = { + "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/", + "gibson": "./data/habitat-sim-data/scene_datasets/gibson/", + "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/", + "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/", + "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/", + "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/", + "scannet": "./data/habitat-sim/scene_datasets/scannet/" +} + +SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"]) + +def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]): + scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json") + scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"] + navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"] + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx]) + # Add scene + data = SceneData(scene_dataset_config_file=scene_dataset_config_file, + scene = scenes[idx] + ".scene_instance.json", + navmesh = os.path.join(base_path, navmeshes[idx]), + output_dir = output_dir) + scenes_data.append(data) + return scenes_data + +def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]): + scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json") + scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], []) + navmeshes = ""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"] + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx]) + data = SceneData(scene_dataset_config_file=scene_dataset_config_file, + scene = scenes[idx], + navmesh = "", + output_dir = output_dir) + scenes_data.append(data) + return scenes_data + +def list_replica_scenes(base_output_dir, base_path): + scenes_data = [] + for scene_id in os.listdir(base_path): + scene = os.path.join(base_path, scene_id, "mesh.ply") + navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it + scene_dataset_config_file = "" + output_dir = os.path.join(base_output_dir, scene_id) + # Add scene only if it does not exist already, or if exist_ok + data = SceneData(scene_dataset_config_file = scene_dataset_config_file, + scene = scene, + navmesh = navmesh, + output_dir = output_dir) + scenes_data.append(data) + return scenes_data + + +def list_scenes(base_output_dir, base_path): + """ + Generic method iterating through a base_path folder to find scenes. 
+ """ + scenes_data = [] + for root, dirs, files in os.walk(base_path, followlinks=True): + folder_scenes_data = [] + for file in files: + name, ext = os.path.splitext(file) + if ext == ".glb": + scene = os.path.join(root, name + ".glb") + navmesh = os.path.join(root, name + ".navmesh") + if not os.path.exists(navmesh): + navmesh = "" + relpath = os.path.relpath(root, base_path) + output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name)) + data = SceneData(scene_dataset_config_file="", + scene = scene, + navmesh = navmesh, + output_dir = output_dir) + folder_scenes_data.append(data) + + # Specific check for HM3D: + # When two meshesxxxx.basis.glb and xxxx.glb are present, use the 'basis' version. + basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")] + if len(basis_scenes) != 0: + folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)] + + scenes_data.extend(folder_scenes_data) + return scenes_data + +def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET): + scenes_data = [] + + # HM3D + for split in ("minival", "train", "val", "examples"): + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"), + base_path=f"{scenes_dataset_paths['hm3d']}/{split}") + + # Gibson + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"), + base_path=scenes_dataset_paths["gibson"]) + + # Habitat test scenes (just a few) + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"), + base_path=scenes_dataset_paths["habitat-test-scenes"]) + + # ReplicaCAD (baked lightning) + scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir) + + # ScanNet + scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"), + base_path=scenes_dataset_paths["scannet"]) + + # Replica + list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"), + base_path=scenes_dataset_paths["replica"]) + return scenes_data diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/pairs_dataset.py b/imcui/third_party/mast3r/dust3r/croco/datasets/pairs_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9f107526b34e154d9013a9a7a0bde3d5ff6f581c --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/datasets/pairs_dataset.py @@ -0,0 +1,109 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+ +import os +from torch.utils.data import Dataset +from PIL import Image + +from datasets.transforms import get_pair_transforms + +def load_image(impath): + return Image.open(impath) + +def load_pairs_from_cache_file(fname, root=''): + assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, 'r') as fid: + lines = fid.read().strip().splitlines() + pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines] + return pairs + +def load_pairs_from_list_file(fname, root=''): + assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, 'r') as fid: + lines = fid.read().strip().splitlines() + pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')] + return pairs + + +def write_cache_file(fname, pairs, root=''): + if len(root)>0: + if not root.endswith('/'): root+='/' + assert os.path.isdir(root) + s = '' + for im1, im2 in pairs: + if len(root)>0: + assert im1.startswith(root), im1 + assert im2.startswith(root), im2 + s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):]) + with open(fname, 'w') as fid: + fid.write(s[:-1]) + +def parse_and_cache_all_pairs(dname, data_dir='./data/'): + if dname=='habitat_release': + dirname = os.path.join(data_dir, 'habitat_release') + assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname + cache_file = os.path.join(dirname, 'pairs.txt') + assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file + + print('Parsing pairs for dataset: '+dname) + pairs = [] + for root, dirs, files in os.walk(dirname): + if 'val' in root: continue + dirs.sort() + pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')] + print('Found {:,} pairs'.format(len(pairs))) + print('Writing cache to: '+cache_file) + write_cache_file(cache_file, pairs, root=dirname) + + else: + raise NotImplementedError('Unknown dataset: '+dname) + +def dnames_to_image_pairs(dnames, data_dir='./data/'): + """ + dnames: list of datasets with image pairs, separated by + + """ + all_pairs = [] + for dname in dnames.split('+'): + if dname=='habitat_release': + dirname = os.path.join(data_dir, 'habitat_release') + assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname + cache_file = os.path.join(dirname, 'pairs.txt') + assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file + pairs = load_pairs_from_cache_file(cache_file, root=dirname) + elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']: + dirname = os.path.join(data_dir, dname+'_crops') + assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname) + list_file = os.path.join(dirname, 'listing.txt') + assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. 
{:s}".format(dname, list_file) + pairs = load_pairs_from_list_file(list_file, root=dirname) + print(' {:s}: {:,} pairs'.format(dname, len(pairs))) + all_pairs += pairs + if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs))) + return all_pairs + + +class PairsDataset(Dataset): + + def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'): + super().__init__() + self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir) + self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize) + + def __len__(self): + return len(self.image_pairs) + + def __getitem__(self, index): + im1path, im2path = self.image_pairs[index] + im1 = load_image(im1path) + im2 = load_image(im2path) + if self.transforms is not None: im1, im2 = self.transforms(im1, im2) + return im1, im2 + + +if __name__=="__main__": + import argparse + parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset") + parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored") + parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset") + args = parser.parse_args() + parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir) diff --git a/imcui/third_party/mast3r/dust3r/croco/datasets/transforms.py b/imcui/third_party/mast3r/dust3r/croco/datasets/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..216bac61f8254fd50e7f269ee80301f250a2d11e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/datasets/transforms.py @@ -0,0 +1,95 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch +import torchvision.transforms +import torchvision.transforms.functional as F + +# "Pair": apply a transform on a pair +# "Both": apply the exact same transform to both images + +class ComposePair(torchvision.transforms.Compose): + def __call__(self, img1, img2): + for t in self.transforms: + img1, img2 = t(img1, img2) + return img1, img2 + +class NormalizeBoth(torchvision.transforms.Normalize): + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + +class ToTensorBoth(torchvision.transforms.ToTensor): + def __call__(self, img1, img2): + img1 = super().__call__(img1) + img2 = super().__call__(img2) + return img1, img2 + +class RandomCropPair(torchvision.transforms.RandomCrop): + # the crop will be intentionally different for the two images with this class + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + +class ColorJitterPair(torchvision.transforms.ColorJitter): + # can be symmetric (same for both images) or assymetric (different jitter params for each image) depending on assymetric_prob + def __init__(self, assymetric_prob, **kwargs): + super().__init__(**kwargs) + self.assymetric_prob = assymetric_prob + def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor): + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, 
hue_factor) + return img + + def forward(self, img1, img2): + + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) + if torch.rand(1) < self.assymetric_prob: # assymetric: + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) + return img1, img2 + +def get_pair_transforms(transform_str, totensor=True, normalize=True): + # transform_str is eg crop224+color + trfs = [] + for s in transform_str.split('+'): + if s.startswith('crop'): + size = int(s[len('crop'):]) + trfs.append(RandomCropPair(size)) + elif s=='acolor': + trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0)) + elif s=='': # if transform_str was "" + pass + else: + raise NotImplementedError('Unknown augmentation: '+s) + + if totensor: + trfs.append( ToTensorBoth() ) + if normalize: + trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ) + + if len(trfs)==0: + return None + elif len(trfs)==1: + return trfs + else: + return ComposePair(trfs) + + + + + diff --git a/imcui/third_party/mast3r/dust3r/croco/demo.py b/imcui/third_party/mast3r/dust3r/croco/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..91b80ccc5c98c18e20d1ce782511aa824ef28f77 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/demo.py @@ -0,0 +1,55 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
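+#
+# Stand-alone reconstruction demo: runs a pretrained CroCo model on the two Chateau images
+# from assets/, reconstructs the masked first image conditioned on the second, and saves a
+# (reference | masked input | reconstruction | original) strip to demo_output.png.
+# The checkpoint is expected at pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth.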
+ +import torch +from models.croco import CroCoNet +from PIL import Image +import torchvision.transforms +from torchvision.transforms import ToTensor, Normalize, Compose + +def main(): + device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu') + + # load 224x224 images and transform them to tensor + imagenet_mean = [0.485, 0.456, 0.406] + imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True) + imagenet_std = [0.229, 0.224, 0.225] + imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True) + trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)]) + image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0) + image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0) + + # load model + ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu') + model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device) + model.eval() + msg = model.load_state_dict(ckpt['model'], strict=True) + + # forward + with torch.inference_mode(): + out, mask, target = model(image1, image2) + + # the output is normalized, thus use the mean/std of the actual image to go back to RGB space + patchified = model.patchify(image1) + mean = patchified.mean(dim=-1, keepdim=True) + var = patchified.var(dim=-1, keepdim=True) + decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean) + # undo imagenet normalization, prepare masked image + decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor + input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor + ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor + image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None]) + masked_input_image = ((1 - image_masks) * input_image) + + # make visualization + visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4 + B, C, H, W = visualization.shape + visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W) + visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1)) + fname = "demo_output.png" + visualization.save(fname) + print('Visualization save in '+fname) + + +if __name__=="__main__": + main() diff --git a/imcui/third_party/mast3r/dust3r/croco/models/blocks.py b/imcui/third_party/mast3r/dust3r/croco/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..18133524f0ae265b0bd8d062d7c9eeaa63858a9b --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/models/blocks.py @@ -0,0 +1,241 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
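+#
+# Shape conventions in this file: token tensors are (B, N, C); the companion position tensors
+# xpos / ypos are (B, N, 2) integer patch coordinates produced by PositionGetter, and they are
+# only consumed by the attention layers when a RoPE module is passed in (rope is None otherwise).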
+ + +# -------------------------------------------------------- +# Main encoder/decoder blocks +# -------------------------------------------------------- +# References: +# timm +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py + + +import torch +import torch.nn as nn + +from itertools import repeat +import collections.abc + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + return parse +to_2tuple = _ntuple(2) + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + +class Attention(nn.Module): + + def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x, xpos): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3) + q, k, v = [qkv[:,:,i] for i in range(3)] + # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = 
attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, xpos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class CrossAttention(nn.Module): + + def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.projq = nn.Linear(dim, dim, bias=qkv_bias) + self.projk = nn.Linear(dim, dim, bias=qkv_bias) + self.projv = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + + def forward(self, query, key, value, qpos, kpos): + B, Nq, C = query.shape + Nk = key.shape[1] + Nv = value.shape[1] + + q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) + k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) + v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class DecoderBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() + + def forward(self, x, y, xpos, ypos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + y_ = self.norm_y(y) + x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) + x = x + self.drop_path(self.mlp(self.norm3(x))) + return x, y + + +# patch embedding +class PositionGetter(object): + """ return positions of patches """ + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if not (h,w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone() + return pos + +class PatchEmbed(nn.Module): + """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + self.position_getter = PositionGetter() + + def forward(self, x): + B, C, H, W = x.shape + torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + def _init_weights(self): + w = self.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + diff --git a/imcui/third_party/mast3r/dust3r/croco/models/criterion.py b/imcui/third_party/mast3r/dust3r/croco/models/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..11696c40865344490f23796ea45e8fbd5e654731 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/models/criterion.py @@ -0,0 +1,37 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
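+#
+# Wiring sketch (how this criterion is meant to be fed, based on CroCoNet's outputs; variable
+# names are illustrative):
+#     out, mask, target = model(img1, img2)                  # CroCoNet forward pass
+#     loss = MaskedMSE(norm_pix_loss=True)(out, mask, target)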
+# +# -------------------------------------------------------- +# Criterion to train CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# -------------------------------------------------------- + +import torch + +class MaskedMSE(torch.nn.Module): + + def __init__(self, norm_pix_loss=False, masked=True): + """ + norm_pix_loss: normalize each patch by their pixel mean and variance + masked: compute loss over the masked patches only + """ + super().__init__() + self.norm_pix_loss = norm_pix_loss + self.masked = masked + + def forward(self, pred, mask, target): + + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + if self.masked: + loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches + else: + loss = loss.mean() # mean loss + return loss diff --git a/imcui/third_party/mast3r/dust3r/croco/models/croco.py b/imcui/third_party/mast3r/dust3r/croco/models/croco.py new file mode 100644 index 0000000000000000000000000000000000000000..14c68634152d75555b4c35c25af268394c5821fe --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/models/croco.py @@ -0,0 +1,249 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# CroCo model during pretraining +# -------------------------------------------------------- + + + +import torch +import torch.nn as nn +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 +from functools import partial + +from models.blocks import Block, DecoderBlock, PatchEmbed +from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D +from models.masking import RandomMask + + +class CroCoNet(nn.Module): + + def __init__(self, + img_size=224, # input image size + patch_size=16, # patch_size + mask_ratio=0.9, # ratios of masked tokens + enc_embed_dim=768, # encoder feature dimension + enc_depth=12, # encoder depth + enc_num_heads=12, # encoder number of heads in the transformer block + dec_embed_dim=512, # decoder feature dimension + dec_depth=8, # decoder depth + dec_num_heads=16, # decoder number of heads in the transformer block + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder + pos_embed='cosine', # positional embedding (either cosine or RoPE100) + ): + + super(CroCoNet, self).__init__() + + # patch embeddings (with initialization done as in MAE) + self._set_patch_embed(img_size, patch_size, enc_embed_dim) + + # mask generations + self._set_mask_generator(self.patch_embed.num_patches, mask_ratio) + + self.pos_embed = pos_embed + if pos_embed=='cosine': + # positional embedding of the encoder + enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0) + self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float()) + # positional embedding of the decoder + dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0) + self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float()) + # pos embedding in each block + self.rope = None # nothing for cosine + elif 
pos_embed.startswith('RoPE'): # eg RoPE100 + self.enc_pos_embed = None # nothing to add in the encoder with RoPE + self.dec_pos_embed = None # nothing to add in the decoder with RoPE + if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions") + freq = float(pos_embed[len('RoPE'):]) + self.rope = RoPE2D(freq=freq) + else: + raise NotImplementedError('Unknown pos_embed '+pos_embed) + + # transformer for the encoder + self.enc_depth = enc_depth + self.enc_embed_dim = enc_embed_dim + self.enc_blocks = nn.ModuleList([ + Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope) + for i in range(enc_depth)]) + self.enc_norm = norm_layer(enc_embed_dim) + + # masked tokens + self._set_mask_token(dec_embed_dim) + + # decoder + self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec) + + # prediction head + self._set_prediction_head(dec_embed_dim, patch_size) + + # initializer weights + self.initialize_weights() + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim) + + def _set_mask_generator(self, num_patches, mask_ratio): + self.mask_generator = RandomMask(num_patches, mask_ratio) + + def _set_mask_token(self, dec_embed_dim): + self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim)) + + def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec): + self.dec_depth = dec_depth + self.dec_embed_dim = dec_embed_dim + # transfer from encoder to decoder + self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True) + # transformer for the decoder + self.dec_blocks = nn.ModuleList([ + DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope) + for i in range(dec_depth)]) + # final norm layer + self.dec_norm = norm_layer(dec_embed_dim) + + def _set_prediction_head(self, dec_embed_dim, patch_size): + self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True) + + + def initialize_weights(self): + # patch embed + self.patch_embed._init_weights() + # mask tokens + if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02) + # linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _encode_image(self, image, do_mask=False, return_all_blocks=False): + """ + image has B x 3 x img_size x img_size + do_mask: whether to perform masking or not + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + """ + # embed the image into patches (x has size B x Npatches x C) + # and get position if each return patch (pos has size B x Npatches x 2) + x, pos = self.patch_embed(image) + # add positional embedding without cls token + if self.enc_pos_embed is not None: + x = x + self.enc_pos_embed[None,...] 
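A quick, standalone broadcasting check for the `x = x + self.enc_pos_embed[None, ...]` step just above (cosine mode only; the sizes below are illustrative, not tied to any particular checkpoint):

```python
import torch

# A (N, C) table of per-patch embeddings is added to every sample in the batch.
B, N, C = 2, 196, 768
x = torch.randn(B, N, C)                   # patch tokens
enc_pos_embed = torch.randn(N, C)          # stand-in for the registered cosine buffer
x = x + enc_pos_embed[None, ...]           # (1, N, C) broadcasts over the batch dim
print(x.shape)                             # torch.Size([2, 196, 768])
```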
+ # apply masking + B,N,C = x.size() + if do_mask: + masks = self.mask_generator(x) + x = x[~masks].view(B, -1, C) + posvis = pos[~masks].view(B, -1, 2) + else: + B,N,C = x.size() + masks = torch.zeros((B,N), dtype=bool) + posvis = pos + # now apply the transformer encoder and normalization + if return_all_blocks: + out = [] + for blk in self.enc_blocks: + x = blk(x, posvis) + out.append(x) + out[-1] = self.enc_norm(out[-1]) + return out, pos, masks + else: + for blk in self.enc_blocks: + x = blk(x, posvis) + x = self.enc_norm(x) + return x, pos, masks + + def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False): + """ + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + + masks1 can be None => assume image1 fully visible + """ + # encoder to decoder layer + visf1 = self.decoder_embed(feat1) + f2 = self.decoder_embed(feat2) + # append masked tokens to the sequence + B,Nenc,C = visf1.size() + if masks1 is None: # downstreams + f1_ = visf1 + else: # pretraining + Ntotal = masks1.size(1) + f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype) + f1_[~masks1] = visf1.view(B * Nenc, C) + # add positional embedding + if self.dec_pos_embed is not None: + f1_ = f1_ + self.dec_pos_embed + f2 = f2 + self.dec_pos_embed + # apply Transformer blocks + out = f1_ + out2 = f2 + if return_all_blocks: + _out, out = out, [] + for blk in self.dec_blocks: + _out, out2 = blk(_out, out2, pos1, pos2) + out.append(_out) + out[-1] = self.dec_norm(out[-1]) + else: + for blk in self.dec_blocks: + out, out2 = blk(out, out2, pos1, pos2) + out = self.dec_norm(out) + return out + + def patchify(self, imgs): + """ + imgs: (B, 3, H, W) + x: (B, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + + return x + + def unpatchify(self, x, channels=3): + """ + x: (N, L, patch_size**2 *channels) + imgs: (N, 3, H, W) + """ + patch_size = self.patch_embed.patch_size[0] + h = w = int(x.shape[1]**.5) + assert h * w == x.shape[1] + x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size)) + return imgs + + def forward(self, img1, img2): + """ + img1: tensor of size B x 3 x img_size x img_size + img2: tensor of size B x 3 x img_size x img_size + + out will be B x N x (3*patch_size*patch_size) + masks are also returned as B x N just in case + """ + # encoder of the masked first image + feat1, pos1, mask1 = self._encode_image(img1, do_mask=True) + # encoder of the second image + feat2, pos2, _ = self._encode_image(img2, do_mask=False) + # decoder + decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2) + # prediction head + out = self.prediction_head(decfeat) + # get target + target = self.patchify(img1) + return out, mask1, target diff --git a/imcui/third_party/mast3r/dust3r/croco/models/croco_downstream.py b/imcui/third_party/mast3r/dust3r/croco/models/croco_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..159dfff4d2c1461bc235e21441b57ce1e2088f76 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/models/croco_downstream.py @@ -0,0 +1,122 @@ +# Copyright (C) 
2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# CroCo model for downstream tasks +# -------------------------------------------------------- + +import torch + +from .croco import CroCoNet + + +def croco_args_from_ckpt(ckpt): + if 'croco_kwargs' in ckpt: # CroCo v2 released models + return ckpt['croco_kwargs'] + elif 'args' in ckpt and hasattr(ckpt['args'], 'model'): # pretrained using the official code release + s = ckpt['args'].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)" + assert s.startswith('CroCoNet(') + return eval('dict'+s[len('CroCoNet'):]) # transform it into the string of a dictionary and evaluate it + else: # CroCo v1 released models + return dict() + +class CroCoDownstreamMonocularEncoder(CroCoNet): + + def __init__(self, + head, + **kwargs): + """ Build network for monocular downstream task, only using the encoder. + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + NOTE: It works by *calling super().__init__() but with redefined setters + + """ + super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """ No mask generator """ + return + + def _set_mask_token(self, *args, **kwargs): + """ No mask token """ + self.mask_token = None + return + + def _set_decoder(self, *args, **kwargs): + """ No decoder """ + return + + def _set_prediction_head(self, *args, **kwargs): + """ No 'prediction head' for downstream tasks.""" + return + + def forward(self, img): + """ + img if of size batch_size x 3 x h x w + """ + B, C, H, W = img.size() + img_info = {'height': H, 'width': W} + need_all_layers = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks + out, _, _ = self._encode_image(img, do_mask=False, return_all_blocks=need_all_layers) + return self.head(out, img_info) + + +class CroCoDownstreamBinocular(CroCoNet): + + def __init__(self, + head, + **kwargs): + """ Build network for binocular downstream task + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + """ + super(CroCoDownstreamBinocular, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """ No mask generator """ + return + + def _set_mask_token(self, *args, **kwargs): + """ No mask token """ + self.mask_token = None + return + + def _set_prediction_head(self, *args, **kwargs): + """ No prediction head for downstream tasks, define your own head """ + return + + def encode_image_pairs(self, img1, img2, return_all_blocks=False): + """ run encoder for a pair of images + it is actually ~5% faster to concatenate the images along the batch dimension + than to encode them separately + """ + ## the two commented lines below is the naive version with separate encoding + #out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks) + #out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False) + ## and now the faster version + out, pos, _ = self._encode_image( torch.cat( (img1,img2), dim=0), do_mask=False, return_all_blocks=return_all_blocks ) + if 
return_all_blocks: + out,out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out]))) + out2 = out2[-1] + else: + out,out2 = out.chunk(2, dim=0) + pos,pos2 = pos.chunk(2, dim=0) + return out, out2, pos, pos2 + + def forward(self, img1, img2): + B, C, H, W = img1.size() + img_info = {'height': H, 'width': W} + return_all_blocks = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks + out, out2, pos, pos2 = self.encode_image_pairs(img1, img2, return_all_blocks=return_all_blocks) + if return_all_blocks: + decout = self._decoder(out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks) + decout = out+decout + else: + decout = self._decoder(out, pos, None, out2, pos2, return_all_blocks=return_all_blocks) + return self.head(decout, img_info) \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/models/curope/__init__.py b/imcui/third_party/mast3r/dust3r/croco/models/curope/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25e3d48a162760260826080f6366838e83e26878 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/models/curope/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from .curope2d import cuRoPE2D diff --git a/imcui/third_party/mast3r/dust3r/croco/models/curope/curope2d.py b/imcui/third_party/mast3r/dust3r/croco/models/curope/curope2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a49c12f8c529e9a889b5ac20c5767158f238e17d --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/models/curope/curope2d.py @@ -0,0 +1,40 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch + +try: + import curope as _kernels # run `python setup.py install` +except ModuleNotFoundError: + from . import curope as _kernels # run `python setup.py build_ext --inplace` + + +class cuRoPE2D_func (torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, positions, base, F0=1): + ctx.save_for_backward(positions) + ctx.saved_base = base + ctx.saved_F0 = F0 + # tokens = tokens.clone() # uncomment this if inplace doesn't work + _kernels.rope_2d( tokens, positions, base, F0 ) + ctx.mark_dirty(tokens) + return tokens + + @staticmethod + def backward(ctx, grad_res): + positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 + _kernels.rope_2d( grad_res, positions, base, -F0 ) + ctx.mark_dirty(grad_res) + return grad_res, None, None, None + + +class cuRoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + + def forward(self, tokens, positions): + cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) + return tokens \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/models/curope/setup.py b/imcui/third_party/mast3r/dust3r/croco/models/curope/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..230632ed05e309200e8f93a3a852072333975009 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/models/curope/setup.py @@ -0,0 +1,34 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
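`encode_image_pairs` above relies on a simple concat-then-chunk pattern: run one forward pass over both images stacked along the batch dimension, then split the result back per image. A self-contained sketch of that pattern, using a single conv layer as a stand-in for the real ViT encoder (sizes are illustrative):

```python
import torch
import torch.nn as nn

encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)   # stand-in encoder, not this repo's
img1 = torch.randn(4, 3, 32, 32)
img2 = torch.randn(4, 3, 32, 32)

out = encoder(torch.cat((img1, img2), dim=0))   # one forward pass over 2*B images
out1, out2 = out.chunk(2, dim=0)                # split back into per-image features
print(out1.shape, out2.shape)                   # torch.Size([4, 8, 32, 32]) twice
```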
+ +from setuptools import setup +from torch import cuda +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +# compile for all possible CUDA architectures +all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() +# alternatively, you can list cuda archs that you want, eg: +# all_cuda_archs = [ + # '-gencode', 'arch=compute_70,code=sm_70', + # '-gencode', 'arch=compute_75,code=sm_75', + # '-gencode', 'arch=compute_80,code=sm_80', + # '-gencode', 'arch=compute_86,code=sm_86' +# ] + +setup( + name = 'curope', + ext_modules = [ + CUDAExtension( + name='curope', + sources=[ + "curope.cpp", + "kernels.cu", + ], + extra_compile_args = dict( + nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, + cxx=['-O3']) + ) + ], + cmdclass = { + 'build_ext': BuildExtension + }) diff --git a/imcui/third_party/mast3r/dust3r/croco/models/dpt_block.py b/imcui/third_party/mast3r/dust3r/croco/models/dpt_block.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ddfb74e2769ceca88720d4c730e00afd71c763 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/models/dpt_block.py @@ -0,0 +1,450 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# DPT head for ViTs +# -------------------------------------------------------- +# References: +# https://github.com/isl-org/DPT +# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from typing import Union, Tuple, Iterable, List, Optional, Dict + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +def make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + scratch.layer_rn = nn.ModuleList([ + scratch.layer1_rn, + scratch.layer2_rn, + scratch.layer3_rn, + scratch.layer4_rn, + ]) + + return scratch + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. 
+ Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + width_ratio=1, + ): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + self.width_ratio = width_ratio + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + if self.width_ratio != 1: + res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear') + + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if self.width_ratio != 1: + # and output.shape[3] < self.width_ratio * output.shape[2] + #size=(image.shape[]) + if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio: + shape = 3 * output.shape[3] + else: + shape = int(self.width_ratio * 2 * output.shape[2]) + output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear') + else: + output = nn.functional.interpolate(output, scale_factor=2, + mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + return output + +def make_fusion_block(features, use_bn, width_ratio=1): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + width_ratio=width_ratio, + ) + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. 
+ Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + +class DPTOutputAdapter(nn.Module): + """DPT output adapter. + + :param num_cahnnels: Number of output channels + :param stride_level: tride level compared to the full-sized image. + E.g. 4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. + :param hooks: Index of intermediate layers + :param layer_dims: Dimension of intermediate layers + :param feature_dim: Feature dimension + :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression + :param use_bn: If set to True, activates batch norm + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + + def __init__(self, + num_channels: int = 1, + stride_level: int = 1, + patch_size: Union[int, Tuple[int, int]] = 16, + main_tasks: Iterable[str] = ('rgb',), + hooks: List[int] = [2, 5, 8, 11], + layer_dims: List[int] = [96, 192, 384, 768], + feature_dim: int = 256, + last_dim: int = 32, + use_bn: bool = False, + dim_tokens_enc: Optional[int] = None, + head_type: str = 'regression', + output_width_ratio=1, + **kwargs): + super().__init__() + self.num_channels = num_channels + self.stride_level = stride_level + self.patch_size = pair(patch_size) + self.main_tasks = main_tasks + self.hooks = hooks + self.layer_dims = layer_dims + self.feature_dim = feature_dim + self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None + self.head_type = head_type + + # Actual patch height and width, taking into account stride of input + self.P_H = max(1, self.patch_size[0] // stride_level) + self.P_W = max(1, self.patch_size[1] // stride_level) + + self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False) + + self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + + if self.head_type == 'regression': + # The "DPTDepthModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0) + ) + elif self.head_type == 'semseg': + # The "DPTSegmentationModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(feature_dim, self.num_channels, kernel_size=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + ) + else: + raise ValueError('DPT head_type must be "regression" or "semseg".') + + if self.dim_tokens_enc is not None: + self.init(dim_tokens_enc=dim_tokens_enc) + + def init(self, dim_tokens_enc=768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. 
+ + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + #print(dim_tokens_enc) + + # Set up activation postprocessing layers + if isinstance(dim_tokens_enc, int): + dim_tokens_enc = 4 * [dim_tokens_enc] + + self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc] + + self.act_1_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[0], + out_channels=self.layer_dims[0], + kernel_size=1, stride=1, padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[0], + out_channels=self.layer_dims[0], + kernel_size=4, stride=4, padding=0, + bias=True, dilation=1, groups=1, + ) + ) + + self.act_2_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[1], + out_channels=self.layer_dims[1], + kernel_size=1, stride=1, padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[1], + out_channels=self.layer_dims[1], + kernel_size=2, stride=2, padding=0, + bias=True, dilation=1, groups=1, + ) + ) + + self.act_3_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[2], + out_channels=self.layer_dims[2], + kernel_size=1, stride=1, padding=0, + ) + ) + + self.act_4_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[3], + out_channels=self.layer_dims[3], + kernel_size=1, stride=1, padding=0, + ), + nn.Conv2d( + in_channels=self.layer_dims[3], + out_channels=self.layer_dims[3], + kernel_size=3, stride=2, padding=1, + ) + ) + + self.act_postprocess = nn.ModuleList([ + self.act_1_postprocess, + self.act_2_postprocess, + self.act_3_postprocess, + self.act_4_postprocess + ]) + + def adapt_tokens(self, encoder_tokens): + # Adapt tokens + x = [] + x.append(encoder_tokens[:, :]) + x = torch.cat(x, dim=-1) + return x + + def forward(self, encoder_tokens: List[torch.Tensor], image_size): + #input_info: Dict): + assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' + H, W = image_size + + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. + layers = [self.adapt_tokens(l) for l in layers] + + # Reshape tokens to spatial representation + layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3]) + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Output head + out = self.head(path_1) + + return out diff --git a/imcui/third_party/mast3r/dust3r/croco/models/head_downstream.py b/imcui/third_party/mast3r/dust3r/croco/models/head_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..bd40c91ba244d6c3522c6efd4ed4d724b7bdc650 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/models/head_downstream.py @@ -0,0 +1,58 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
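The central reshaping step in `DPTOutputAdapter.forward` above is turning each hooked layer's token sequence back into a 2D feature map before the convolutional refinement stages. A standalone sketch of that `rearrange` call with assumed sizes (a 14x14 grid for a 224x224 input with patch size 16):

```python
import torch
from einops import rearrange

B, C, N_H, N_W = 2, 768, 14, 14
layer_tokens = torch.randn(B, N_H * N_W, C)   # one hooked layer's tokens

fmap = rearrange(layer_tokens, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W)
print(fmap.shape)                             # torch.Size([2, 768, 14, 14])
```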
+ +# -------------------------------------------------------- +# Heads for downstream tasks +# -------------------------------------------------------- + +""" +A head is a module where the __init__ defines only the head hyperparameters. +A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes. +The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height' +""" + +import torch +import torch.nn as nn +from .dpt_block import DPTOutputAdapter + + +class PixelwiseTaskWithDPT(nn.Module): + """ DPT module for CroCo. + by default, hooks_idx will be equal to: + * for encoder-only: 4 equally spread layers + * for encoder+decoder: last encoder + 3 equally spread layers of the decoder + """ + + def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768], + output_width_ratio=1, num_channels=1, postprocess=None, **kwargs): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_blocks = True # backbone needs to return all layers + self.postprocess = postprocess + self.output_width_ratio = output_width_ratio + self.num_channels = num_channels + self.hooks_idx = hooks_idx + self.layer_dims = layer_dims + + def setup(self, croconet): + dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels} + if self.hooks_idx is None: + if hasattr(croconet, 'dec_blocks'): # encoder + decoder + step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth] + hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)] + else: # encoder only + step = croconet.enc_depth//4 + hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)] + self.hooks_idx = hooks_idx + print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}') + dpt_args['hooks'] = self.hooks_idx + dpt_args['layer_dims'] = self.layer_dims + self.dpt = DPTOutputAdapter(**dpt_args) + dim_tokens = [croconet.enc_embed_dim if hook0: + pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +#---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +#---------------------------------------------------------- + +try: + from models.curope import cuRoPE2D + RoPE2D = cuRoPE2D +except ImportError: + print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') + + class RoPE2D(torch.nn.Module): + + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D,seq_len,device,dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D,seq_len,device,dtype] = (cos,sin) + return self.cache[D,seq_len,device,dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim==2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert 
tokens.size(3)%2==0, "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:,:,0], cos, sin) + x = self.apply_rope1d(x, positions[:,:,1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/pretrain.py b/imcui/third_party/mast3r/dust3r/croco/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..2c45e488015ef5380c71d0381ff453fdb860759e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/pretrain.py @@ -0,0 +1,254 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Pre-training CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import argparse +import datetime +import json +import numpy as np +import os +import sys +import time +import math +from pathlib import Path +from typing import Iterable + +import torch +import torch.distributed as dist +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +import utils.misc as misc +from utils.misc import NativeScalerWithGradNormCount as NativeScaler +from models.croco import CroCoNet +from models.criterion import MaskedMSE +from datasets.pairs_dataset import PairsDataset + + +def get_args_parser(): + parser = argparse.ArgumentParser('CroCo pre-training', add_help=False) + # model and criterion + parser.add_argument('--model', default='CroCoNet()', type=str, help="string containing the model to build") + parser.add_argument('--norm_pix_loss', default=1, choices=[0,1], help="apply per-patch mean/std normalization before applying the loss") + # dataset + parser.add_argument('--dataset', default='habitat_release', type=str, help="training set") + parser.add_argument('--transforms', default='crop224+acolor', type=str, help="transforms to apply") # in the paper, we also use some homography and rotation, but find later that they were not useful or even harmful + # training + parser.add_argument('--seed', default=0, type=int, help="Random seed") + parser.add_argument('--batch_size', default=64, type=int, help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus") + parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler") + parser.add_argument('--max_epoch', default=400, type=int, help="Stop training at this epoch") + parser.add_argument('--accum_iter', default=1, type=int, help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)") + parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)") + parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') + parser.add_argument('--blr', 
type=float, default=1.5e-4, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') + parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR') + parser.add_argument('--amp', type=int, default=1, choices=[0,1], help="Use Automatic Mixed Precision for pretraining") + # others + parser.add_argument('--num_workers', default=8, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--save_freq', default=1, type=int, help='frequence (number of epochs) to save checkpoint in checkpoint-last.pth') + parser.add_argument('--keep_freq', default=20, type=int, help='frequence (number of epochs) to save checkpoint in checkpoint-%d.pth') + parser.add_argument('--print_freq', default=20, type=int, help='frequence (number of iterations) to print infos while training') + # paths + parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output") + parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored") + return parser + + + + +def main(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + world_size = misc.get_world_size() + + print("output_dir: "+args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + # auto resume + last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth') + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # fix the seed + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + ## training dataset and loader + print('Building dataset for {:s} with transforms {:s}'.format(args.dataset, args.transforms)) + dataset = PairsDataset(args.dataset, trfs=args.transforms, data_dir=args.data_dir) + if world_size>1: + sampler_train = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + else: + sampler_train = torch.utils.data.RandomSampler(dataset) + data_loader_train = torch.utils.data.DataLoader( + dataset, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + + ## model + print('Loading model: {:s}'.format(args.model)) + model = eval(args.model) + print('Loading criterion: MaskedMSE(norm_pix_loss={:s})'.format(str(bool(args.norm_pix_loss)))) + criterion = MaskedMSE(norm_pix_loss=bool(args.norm_pix_loss)) + + model.to(device) + model_without_ddp = model + print("Model = %s" % str(model_without_ddp)) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + 
print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True) + model_without_ddp = model.module + + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) # following timm: set wd as 0 for bias and norm layers + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) + print(optimizer) + loss_scaler = NativeScaler() + + misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training until {args.max_epoch} epochs") + start_time = time.time() + for epoch in range(args.start_epoch, args.max_epoch): + if world_size>1: + data_loader_train.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + log_writer=log_writer, + args=args + ) + + if args.output_dir and epoch % args.save_freq == 0 : + misc.save_model( + args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, fname='last') + + if args.output_dir and (epoch % args.keep_freq == 0 or epoch + 1 == args.max_epoch) and (epoch>0 or args.max_epoch==1): + misc.save_model( + args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch,} + + if args.output_dir and misc.is_main_process(): + if log_writer is not None: + log_writer.flush() + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + log_writer=None, + args=None): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + accum_iter = args.accum_iter + + optimizer.zero_grad() + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + for data_iter_step, (image1, image2) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + + # we use a per iteration lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + with torch.cuda.amp.autocast(enabled=bool(args.amp)): + out, mask, target = model(image1, image2) + loss = criterion(out, mask, target) + + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) 
+ + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(lr=lr) + + loss_value_reduce = misc.all_reduce_mean(loss_value) + if log_writer is not None and ((data_iter_step + 1) % (accum_iter*args.print_freq)) == 0: + # x-axis is based on epoch_1000x in the tensorboard, calibrating differences curves when batch size changes + epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) + log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('lr', lr, epoch_1000x) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) diff --git a/imcui/third_party/mast3r/dust3r/croco/stereoflow/augmentor.py b/imcui/third_party/mast3r/dust3r/croco/stereoflow/augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..69e6117151988d94cbc4b385e0d88e982133bf10 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/stereoflow/augmentor.py @@ -0,0 +1,290 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Data augmentation for training stereo and flow +# -------------------------------------------------------- + +# References +# https://github.com/autonomousvision/unimatch/blob/master/dataloader/stereo/transforms.py +# https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/transforms.py + + +import numpy as np +import random +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torchvision.transforms.functional as FF + +class StereoAugmentor(object): + + def __init__(self, crop_size, scale_prob=0.5, scale_xonly=True, lhth=800., lminscale=0.0, lmaxscale=1.0, hminscale=-0.2, hmaxscale=0.4, scale_interp_nearest=True, rightjitterprob=0.5, v_flip_prob=0.5, color_aug_asym=True, color_choice_prob=0.5): + self.crop_size = crop_size + self.scale_prob = scale_prob + self.scale_xonly = scale_xonly + self.lhth = lhth + self.lminscale = lminscale + self.lmaxscale = lmaxscale + self.hminscale = hminscale + self.hmaxscale = hmaxscale + self.scale_interp_nearest = scale_interp_nearest + self.rightjitterprob = rightjitterprob + self.v_flip_prob = v_flip_prob + self.color_aug_asym = color_aug_asym + self.color_choice_prob = color_choice_prob + + def _random_scale(self, img1, img2, disp): + ch,cw = self.crop_size + h,w = img1.shape[:2] + if self.scale_prob>0. 
and np.random.rand()1.: + scale_x = clip_scale + scale_y = scale_x if not self.scale_xonly else 1.0 + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + disp = cv2.resize(disp, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR if not self.scale_interp_nearest else cv2.INTER_NEAREST) * scale_x + return img1, img2, disp + + def _random_crop(self, img1, img2, disp): + h,w = img1.shape[:2] + ch,cw = self.crop_size + assert ch<=h and cw<=w, (img1.shape, h,w,ch,cw) + offset_x = np.random.randint(w - cw + 1) + offset_y = np.random.randint(h - ch + 1) + img1 = img1[offset_y:offset_y+ch,offset_x:offset_x+cw] + img2 = img2[offset_y:offset_y+ch,offset_x:offset_x+cw] + disp = disp[offset_y:offset_y+ch,offset_x:offset_x+cw] + return img1, img2, disp + + def _random_vflip(self, img1, img2, disp): + # vertical flip + if self.v_flip_prob>0 and np.random.rand() < self.v_flip_prob: + img1 = np.copy(np.flipud(img1)) + img2 = np.copy(np.flipud(img2)) + disp = np.copy(np.flipud(disp)) + return img1, img2, disp + + def _random_rotate_shift_right(self, img2): + if self.rightjitterprob>0. and np.random.rand() 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow = np.inf * np.ones([ht1, wd1, 2], dtype=np.float32) # invalid value every where, before we fill it with the correct ones + flow[yy, xx] = flow1 + return flow + + def spatial_transform(self, img1, img2, flow, dname): + + if np.random.rand() < self.spatial_aug_prob: + # randomly sample scale + ht, wd = img1.shape[:2] + clip_min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + min_scale, max_scale = self.min_scale, self.max_scale + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_x = np.clip(scale_x, clip_min_scale, None) + scale_y = np.clip(scale_y, clip_min_scale, None) + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = self._resize_flow(flow, scale_x, scale_y, factor=2.0 if dname=='Spring' else 1.0) + elif dname=="Spring": + flow = self._resize_flow(flow, 1.0, 1.0, factor=2.0) + + if self.h_flip_prob>0. and np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if self.v_flip_prob>0. 
and np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + # In case no cropping + if img1.shape[0] - self.crop_size[0] > 0: + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + else: + y0 = 0 + if img1.shape[1] - self.crop_size[1] > 0: + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + else: + x0 = 0 + + img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow, dname): + img1, img2, flow = self.spatial_transform(img1, img2, flow, dname) + img1, img2 = self.color_transform(img1, img2) + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + return img1, img2, flow \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/stereoflow/criterion.py b/imcui/third_party/mast3r/dust3r/croco/stereoflow/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..57792ebeeee34827b317a4d32b7445837bb33f17 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/stereoflow/criterion.py @@ -0,0 +1,251 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Losses, metrics per batch, metrics per dataset +# -------------------------------------------------------- + +import torch +from torch import nn +import torch.nn.functional as F + +def _get_gtnorm(gt): + if gt.size(1)==1: # stereo + return gt + # flow + return torch.sqrt(torch.sum(gt**2, dim=1, keepdims=True)) # Bx1xHxW + +############ losses without confidence + +class L1Loss(nn.Module): + + def __init__(self, max_gtnorm=None): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = False + + def _error(self, gt, predictions): + return torch.abs(gt-predictions) + + def forward(self, predictions, gt, inspect=False): + mask = torch.isfinite(gt) + if self.max_gtnorm is not None: + mask *= _get_gtnorm(gt).expand(-1,gt.size(1),-1,-1) which is a constant + + +class LaplacianLossBounded(nn.Module): # used for CroCo-Flow ; in the equation of the paper, we have a=1/b + def __init__(self, max_gtnorm=10000., a=0.25, b=4.): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = True + self.a, self.b = a, b + + def forward(self, predictions, gt, conf): + mask = torch.isfinite(gt) + mask = mask[:,0,:,:] + if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:] which is a constant + +class LaplacianLossBounded2(nn.Module): # used for CroCo-Stereo (except for ETH3D) ; in the equation of the paper, we have a=b + def __init__(self, max_gtnorm=None, a=3.0, b=3.0): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = True + self.a, self.b = a, b + + def forward(self, predictions, gt, conf): + mask = torch.isfinite(gt) + mask = mask[:,0,:,:] + if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:] which is a constant + +############## metrics per batch + +class StereoMetrics(nn.Module): + + def __init__(self, do_quantile=False): + super().__init__() + self.bad_ths = [0.5,1,2,3] + self.do_quantile = do_quantile + + def forward(self, predictions, gt): + B = predictions.size(0) + metrics = {} + gtcopy = gt.clone() + mask = torch.isfinite(gtcopy) + gtcopy[~mask] 
= 999999.0 # we make a copy and put a non-infinite value, such that it does not become nan once multiplied by the mask value 0 + Npx = mask.view(B,-1).sum(dim=1) + L1error = (torch.abs(gtcopy-predictions)*mask).view(B,-1) + L2error = (torch.square(gtcopy-predictions)*mask).view(B,-1) + # avgerr + metrics['avgerr'] = torch.mean(L1error.sum(dim=1)/Npx ) + # rmse + metrics['rmse'] = torch.sqrt(L2error.sum(dim=1)/Npx).mean(dim=0) + # err > t for t in [0.5,1,2,3] + for ths in self.bad_ths: + metrics['bad@{:.1f}'.format(ths)] = (((L1error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100 + return metrics + +class FlowMetrics(nn.Module): + def __init__(self): + super().__init__() + self.bad_ths = [1,3,5] + + def forward(self, predictions, gt): + B = predictions.size(0) + metrics = {} + mask = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite + Npx = mask.view(B,-1).sum(dim=1) + gtcopy = gt.clone() # to compute L1/L2 error, we need to have non-infinite value, the error computed at this locations will be ignored + gtcopy[:,0,:,:][~mask] = 999999.0 + gtcopy[:,1,:,:][~mask] = 999999.0 + L1error = (torch.abs(gtcopy-predictions).sum(dim=1)*mask).view(B,-1) + L2error = (torch.sqrt(torch.sum(torch.square(gtcopy-predictions),dim=1))*mask).view(B,-1) + metrics['L1err'] = torch.mean(L1error.sum(dim=1)/Npx ) + metrics['EPE'] = torch.mean(L2error.sum(dim=1)/Npx ) + for ths in self.bad_ths: + metrics['bad@{:.1f}'.format(ths)] = (((L2error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100 + return metrics + +############## metrics per dataset +## we update the average and maintain the number of pixels while adding data batch per batch +## at the beggining, call reset() +## after each batch, call add_batch(...) +## at the end: call get_results() + +class StereoDatasetMetrics(nn.Module): + + def __init__(self): + super().__init__() + self.bad_ths = [0.5,1,2,3] + + def reset(self): + self.agg_N = 0 # number of pixels so far + self.agg_L1err = torch.tensor(0.0) # L1 error so far + self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels + self._metrics = None + + def add_batch(self, predictions, gt): + assert predictions.size(1)==1, predictions.size() + assert gt.size(1)==1, gt.size() + if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ... 
+ L1err = torch.minimum( torch.minimum( torch.minimum( + torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1), + torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1)) + valid = torch.isfinite(L1err) + else: + valid = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite + L1err = torch.sum(torch.abs(gt-predictions),dim=1) + N = valid.sum() + Nnew = self.agg_N + N + self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew + self.agg_N = Nnew + for i,th in enumerate(self.bad_ths): + self.agg_Nbad[i] += (L1err[valid]>th).sum().cpu() + + def _compute_metrics(self): + if self._metrics is not None: return + out = {} + out['L1err'] = self.agg_L1err.item() + for i,th in enumerate(self.bad_ths): + out['bad@{:.1f}'.format(th)] = (float(self.agg_Nbad[i]) / self.agg_N).item() * 100.0 + self._metrics = out + + def get_results(self): + self._compute_metrics() # to avoid recompute them multiple times + return self._metrics + +class FlowDatasetMetrics(nn.Module): + + def __init__(self): + super().__init__() + self.bad_ths = [0.5,1,3,5] + self.speed_ths = [(0,10),(10,40),(40,torch.inf)] + + def reset(self): + self.agg_N = 0 # number of pixels so far + self.agg_L1err = torch.tensor(0.0) # L1 error so far + self.agg_L2err = torch.tensor(0.0) # L2 (=EPE) error so far + self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels + self.agg_EPEspeed = [torch.tensor(0.0) for _ in self.speed_ths] # EPE per speed bin so far + self.agg_Nspeed = [0 for _ in self.speed_ths] # N pixels per speed bin so far + self._metrics = None + self.pairname_results = {} + + def add_batch(self, predictions, gt): + assert predictions.size(1)==2, predictions.size() + assert gt.size(1)==2, gt.size() + if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ... 
+ L1err = torch.minimum( torch.minimum( torch.minimum( + torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1), + torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)), + torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1)) + L2err = torch.minimum( torch.minimum( torch.minimum( + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]-predictions),dim=1)), + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]-predictions),dim=1))), + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]-predictions),dim=1))), + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]-predictions),dim=1))) + valid = torch.isfinite(L1err) + gtspeed = (torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]),dim=1)) +\ + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]),dim=1)) ) / 4.0 # let's just average them + else: + valid = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite + L1err = torch.sum(torch.abs(gt-predictions),dim=1) + L2err = torch.sqrt(torch.sum(torch.square(gt-predictions),dim=1)) + gtspeed = torch.sqrt(torch.sum(torch.square(gt),dim=1)) + N = valid.sum() + Nnew = self.agg_N + N + self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew + self.agg_L2err = float(self.agg_N)/Nnew * self.agg_L2err + L2err[valid].mean().cpu() * float(N)/Nnew + self.agg_N = Nnew + for i,th in enumerate(self.bad_ths): + self.agg_Nbad[i] += (L2err[valid]>th).sum().cpu() + for i,(th1,th2) in enumerate(self.speed_ths): + vv = (gtspeed[valid]>=th1) * (gtspeed[valid] don't use batch_size>1 at test time) + self._prepare_data() + self._load_or_build_cache() + + def prepare_data(self): + """ + to be defined for each dataset + """ + raise NotImplementedError + + def __len__(self): + return len(self.pairnames) # each pairname is typically of the form (str, int1, int2) + + def __getitem__(self, index): + pairname = self.pairnames[index] + + # get filenames + img1name = self.pairname_to_img1name(pairname) + img2name = self.pairname_to_img2name(pairname) + flowname = self.pairname_to_flowname(pairname) if self.pairname_to_flowname is not None else None + + # load images and disparities + img1 = _read_img(img1name) + img2 = _read_img(img2name) + flow = self.load_flow(flowname) if flowname is not None else None + + # apply augmentations + if self.augmentor is not None: + img1, img2, flow = self.augmentor(img1, img2, flow, self.name) + + if self.totensor: + img1 = img_to_tensor(img1) + img2 = img_to_tensor(img2) + if flow is not None: + flow = flow_to_tensor(flow) + else: + flow = torch.tensor([]) # to allow dataloader batching with default collate_gn + pairname = str(pairname) # transform potential tuple to str to be able to batch it + + return img1, img2, flow, pairname + + def __rmul__(self, v): + self.rmul *= v + self.pairnames = v * self.pairnames + return self + + def __str__(self): + return f'{self.__class__.__name__}_{self.split}' + + def __repr__(self): + s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})' + if self.rmul==1: + s+=f'\n\tnum pairs: {len(self.pairnames)}' + else: + s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})' + return s + + def _set_root(self): + self.root = dataset_to_root[self.name] + assert os.path.isdir(self.root), f"could not 
find root directory for dataset {self.name}: {self.root}" + + def _load_or_build_cache(self): + cache_file = osp.join(cache_dir, self.name+'.pkl') + if osp.isfile(cache_file): + with open(cache_file, 'rb') as fid: + self.pairnames = pickle.load(fid)[self.split] + else: + tosave = self._build_cache() + os.makedirs(cache_dir, exist_ok=True) + with open(cache_file, 'wb') as fid: + pickle.dump(tosave, fid) + self.pairnames = tosave[self.split] + +class TartanAirDataset(FlowDataset): + + def _prepare_data(self): + self.name = "TartanAir" + self._set_root() + assert self.split in ['train'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[1])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[2])) + self.pairname_to_flowname = lambda pairname: osp.join(self.root, pairname[0], 'flow/{:06d}_{:06d}_flow.npy'.format(pairname[1],pairname[2])) + self.pairname_to_str = lambda pairname: os.path.join(pairname[0][pairname[0].find('/')+1:], '{:06d}_{:06d}'.format(pairname[1], pairname[2])) + self.load_flow = _read_numpy_flow + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + pairs = [(osp.join(s,s,difficulty,Pxxx),int(a[:6]),int(a[:6])+1) for s in seqs for difficulty in ['Easy','Hard'] for Pxxx in sorted(os.listdir(osp.join(self.root,s,s,difficulty))) for a in sorted(os.listdir(osp.join(self.root,s,s,difficulty,Pxxx,'image_left/')))[:-1]] + assert len(pairs)==306268, "incorrect parsing of pairs in TartanAir" + tosave = {'train': pairs} + return tosave + +class FlyingChairsDataset(FlowDataset): + + def _prepare_data(self): + self.name = "FlyingChairs" + self._set_root() + assert self.split in ['train','val'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, 'data', pairname+'_img1.ppm') + self.pairname_to_img2name = lambda pairname: osp.join(self.root, 'data', pairname+'_img2.ppm') + self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'data', pairname+'_flow.flo') + self.pairname_to_str = lambda pairname: pairname + self.load_flow = _read_flo_file + + def _build_cache(self): + split_file = osp.join(self.root, 'chairs_split.txt') + split_list = np.loadtxt(split_file, dtype=np.int32) + trainpairs = ['{:05d}'.format(i) for i in np.where(split_list==1)[0]+1] + valpairs = ['{:05d}'.format(i) for i in np.where(split_list==2)[0]+1] + assert len(trainpairs)==22232 and len(valpairs)==640, "incorrect parsing of pairs in MPI-Sintel" + tosave = {'train': trainpairs, 'val': valpairs} + return tosave + +class FlyingThingsDataset(FlowDataset): + + def _prepare_data(self): + self.name = "FlyingThings" + self._set_root() + assert self.split in [f'{set_}_{pass_}pass{camstr}' for set_ in ['train','test','test1024'] for camstr in ['','_rightcam'] for pass_ in ['clean','final','all']] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[1])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[2])) + self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'optical_flow', pairname[0], 'OpticalFlowInto{f:s}_{i:04d}_{c:s}.pfm'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' )) + self.pairname_to_str 
= lambda pairname: os.path.join(pairname[3]+'pass', pairname[0], 'Into{f:s}_{i:04d}_{c:s}'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' )) + self.load_flow = _read_pfm_flow + + def _build_cache(self): + tosave = {} + # train and test splits for the different passes + for set_ in ['train', 'test']: + sroot = osp.join(self.root, 'optical_flow', set_.upper()) + fname_to_i = lambda f: int(f[len('OpticalFlowIntoFuture_'):-len('_L.pfm')]) + pp = [(osp.join(set_.upper(), d, s, 'into_future/left'),fname_to_i(fname)) for d in sorted(os.listdir(sroot)) for s in sorted(os.listdir(osp.join(sroot,d))) for fname in sorted(os.listdir(osp.join(sroot,d, s, 'into_future/left')))[:-1]] + pairs = [(a,i,i+1) for a,i in pp] + pairs += [(a.replace('into_future','into_past'),i+1,i) for a,i in pp] + assert len(pairs)=={'train': 40302, 'test': 7866}[set_], "incorrect parsing of pairs Flying Things" + for cam in ['left','right']: + camstr = '' if cam=='left' else f'_{cam}cam' + for pass_ in ['final', 'clean']: + tosave[f'{set_}_{pass_}pass{camstr}'] = [(a.replace('left',cam),i,j,pass_) for a,i,j in pairs] + tosave[f'{set_}_allpass{camstr}'] = tosave[f'{set_}_cleanpass{camstr}'] + tosave[f'{set_}_finalpass{camstr}'] + # test1024: this is the same split as unimatch 'validation' split + # see https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/datasets.py#L229 + test1024_nsamples = 1024 + alltest_nsamples = len(tosave['test_cleanpass']) # 7866 + stride = alltest_nsamples // test1024_nsamples + remove = alltest_nsamples % test1024_nsamples + for cam in ['left','right']: + camstr = '' if cam=='left' else f'_{cam}cam' + for pass_ in ['final','clean']: + tosave[f'test1024_{pass_}pass{camstr}'] = sorted(tosave[f'test_{pass_}pass{camstr}'])[:-remove][::stride] # warning, it was not sorted before + assert len(tosave['test1024_cleanpass'])==1024, "incorrect parsing of pairs in Flying Things" + tosave[f'test1024_allpass{camstr}'] = tosave[f'test1024_cleanpass{camstr}'] + tosave[f'test1024_finalpass{camstr}'] + return tosave + + +class MPISintelDataset(FlowDataset): + + def _prepare_data(self): + self.name = "MPISintel" + self._set_root() + assert self.split in [s+'_'+p for s in ['train','test','subval','subtrain'] for p in ['cleanpass','finalpass','allpass']] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1]+1)) + self.pairname_to_flowname = lambda pairname: None if pairname[0].startswith('test/') else osp.join(self.root, pairname[0].replace('/clean/','/flow/').replace('/final/','/flow/'), 'frame_{:04d}.flo'.format(pairname[1])) + self.pairname_to_str = lambda pairname: osp.join(pairname[0], 'frame_{:04d}'.format(pairname[1])) + self.load_flow = _read_flo_file + + def _build_cache(self): + trainseqs = sorted(os.listdir(self.root+'training/clean')) + trainpairs = [ (osp.join('training/clean', s),i) for s in trainseqs for i in range(1, len(os.listdir(self.root+'training/clean/'+s)))] + subvalseqs = ['temple_2','temple_3'] + subtrainseqs = [s for s in trainseqs if s not in subvalseqs] + subvalpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subvalseqs)] + subtrainpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subtrainseqs)] + testseqs = sorted(os.listdir(self.root+'test/clean')) + testpairs = [ (osp.join('test/clean', s),i) for s in 
testseqs for i in range(1, len(os.listdir(self.root+'test/clean/'+s)))] + assert len(trainpairs)==1041 and len(testpairs)==552 and len(subvalpairs)==98 and len(subtrainpairs)==943, "incorrect parsing of pairs in MPI-Sintel" + tosave = {} + tosave['train_cleanpass'] = trainpairs + tosave['test_cleanpass'] = testpairs + tosave['subval_cleanpass'] = subvalpairs + tosave['subtrain_cleanpass'] = subtrainpairs + for t in ['train','test','subval','subtrain']: + tosave[t+'_finalpass'] = [(p.replace('/clean/','/final/'),i) for p,i in tosave[t+'_cleanpass']] + tosave[t+'_allpass'] = tosave[t+'_cleanpass'] + tosave[t+'_finalpass'] + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, _time): + assert prediction.shape[2]==2 + outfile = os.path.join(outdir, 'submission', self.pairname_to_str(pairname)+'.flo') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlowFile(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split == 'test_allpass' + bundle_exe = "/nfs/data/ffs-3d/datasets/StereoFlow/MPI-Sintel/bundler/linux-x64/bundler" # eg + if os.path.isfile(bundle_exe): + cmd = f'{bundle_exe} "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at: "{outdir}/submission/bundled.lzma"') + else: + print('Could not find bundler executable for submission.') + print('Please download it and run:') + print(f' "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"') + +class SpringDataset(FlowDataset): + + def _prepare_data(self): + self.name = "Spring" + self._set_root() + assert self.split in ['train','test','subtrain','subval'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], pairname[1], 'frame_'+pairname[3], 'frame_{:s}_{:04d}.png'.format(pairname[3], pairname[4])) + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], pairname[1], 'frame_'+pairname[3], 'frame_{:s}_{:04d}.png'.format(pairname[3], pairname[4]+(1 if pairname[2]=='FW' else -1))) + self.pairname_to_flowname = lambda pairname: None if pairname[0]=='test' else osp.join(self.root, pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5') + self.pairname_to_str = lambda pairname: osp.join(pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}') + self.load_flow = _read_hdf5_flow + + def _build_cache(self): + # train + trainseqs = sorted(os.listdir( osp.join(self.root,'train'))) + trainpairs = [] + for leftright in ['left','right']: + for fwbw in ['FW','BW']: + trainpairs += [('train',s,fwbw,leftright,int(f[len(f'flow_{fwbw}_{leftright}_'):-len('.flo5')])) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,f'flow_{fwbw}_{leftright}')))] + # test + testseqs = sorted(os.listdir( osp.join(self.root,'test'))) + testpairs = [] + for leftright in ['left','right']: + testpairs += [('test',s,'FW',leftright,int(f[len(f'frame_{leftright}_'):-len('.png')])) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,f'frame_{leftright}')))[:-1]] + testpairs += [('test',s,'BW',leftright,int(f[len(f'frame_{leftright}_'):-len('.png')])+1) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,f'frame_{leftright}')))[:-1]] + # subtrain / subval + subtrainpairs = [p for p in trainpairs if p[1]!='0041'] + 
subvalpairs = [p for p in trainpairs if p[1]=='0041'] + assert len(trainpairs)==19852 and len(testpairs)==3960 and len(subtrainpairs)==19472 and len(subvalpairs)==380, "incorrect parsing of pairs in Spring" + tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==3 + assert prediction.shape[2]==2 + assert prediction.dtype==np.float32 + outfile = osp.join(outdir, pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlo5File(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + exe = "{self.root}/flow_subsampling" + if os.path.isfile(exe): + cmd = f'cd "{outdir}/test"; {exe} .' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/test/flow_submission.hdf5') + else: + print('Could not find flow_subsampling executable for submission.') + print('Please download it and run:') + print(f'cd "{outdir}/test"; .') + + +class Kitti12Dataset(FlowDataset): + + def _prepare_data(self): + self.name = "Kitti12" + self._set_root() + assert self.split in ['train','test'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname+'_11.png') + self.pairname_to_flowname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/flow_occ/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/') + self.load_flow = _read_kitti_flow + + def _build_cache(self): + trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)] + testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)] + assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12" + tosave = {'train': trainseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==3 + assert prediction.shape[2]==2 + outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlowKitti(outfile, prediction) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti12_flow_results.zip" .' + print(cmd) + os.system(cmd) + print(f'Done. 
Submission file at {outdir}/kitti12_flow_results.zip') + + +class Kitti15Dataset(FlowDataset): + + def _prepare_data(self): + self.name = "Kitti15" + self._set_root() + assert self.split in ['train','subtrain','subval','test'] + self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname+'_11.png') + self.pairname_to_flowname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/flow_occ/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/') + self.load_flow = _read_kitti_flow + + def _build_cache(self): + trainseqs = ["training/image_2/%06d"%(i) for i in range(200)] + subtrainseqs = trainseqs[:-10] + subvalseqs = trainseqs[-10:] + testseqs = ["testing/image_2/%06d"%(i) for i in range(200)] + assert len(trainseqs)==200 and len(subtrainseqs)==190 and len(subvalseqs)==10 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15" + tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==3 + assert prediction.shape[2]==2 + outfile = os.path.join(outdir, 'flow', pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeFlowKitti(outfile, prediction) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti15_flow_results.zip" flow' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/kitti15_flow_results.zip') + + +import cv2 +def _read_numpy_flow(filename): + return np.load(filename) + +def _read_pfm_flow(filename): + f, _ = _read_pfm(filename) + assert np.all(f[:,:,2]==0.0) + return np.ascontiguousarray(f[:,:,:2]) + +TAG_FLOAT = 202021.25 # tag to check the sanity of the file +TAG_STRING = 'PIEH' # string containing the tag +MIN_WIDTH = 1 +MAX_WIDTH = 99999 +MIN_HEIGHT = 1 +MAX_HEIGHT = 99999 +def readFlowFile(filename): + """ + readFlowFile() reads a flow file into a 2-band np.array. + if does not exist, an IOError is raised. + if does not finish by '.flo' or the tag, the width, the height or the file's size is illegal, an Expcetion is raised. + ---- PARAMETERS ---- + filename: string containg the name of the file to read a flow + ---- OUTPUTS ---- + a np.array of dimension (height x width x 2) containing the flow of type 'float32' + """ + + # check filename + if not filename.endswith(".flo"): + raise Exception("readFlowFile({:s}): filename must finish with '.flo'".format(filename)) + + # open the file and read it + with open(filename,'rb') as f: + # check tag + tag = struct.unpack('f',f.read(4))[0] + if tag != TAG_FLOAT: + raise Exception("flow_utils.readFlowFile({:s}): wrong tag".format(filename)) + # read dimension + w,h = struct.unpack('ii',f.read(8)) + if w < MIN_WIDTH or w > MAX_WIDTH: + raise Exception("flow_utils.readFlowFile({:s}: illegal width {:d}".format(filename,w)) + if h < MIN_HEIGHT or h > MAX_HEIGHT: + raise Exception("flow_utils.readFlowFile({:s}: illegal height {:d}".format(filename,h)) + flow = np.fromfile(f,'float32') + if not flow.shape == (h*w*2,): + raise Exception("flow_utils.readFlowFile({:s}: illegal size of the file".format(filename)) + flow.shape = (h,w,2) + return flow + +def writeFlowFile(flow,filename): + """ + writeFlowFile(flow,) write flow to the file . + if does not exist, an IOError is raised. 
+ if does not finish with '.flo' or the flow has not 2 bands, an Exception is raised. + ---- PARAMETERS ---- + flow: np.array of dimension (height x width x 2) containing the flow to write + filename: string containg the name of the file to write a flow + """ + + # check filename + if not filename.endswith(".flo"): + raise Exception("flow_utils.writeFlowFile(,{:s}): filename must finish with '.flo'".format(filename)) + + if not flow.shape[2:] == (2,): + raise Exception("flow_utils.writeFlowFile(,{:s}): must have 2 bands".format(filename)) + + + # open the file and write it + with open(filename,'wb') as f: + # write TAG + f.write( TAG_STRING.encode('utf-8') ) + # write dimension + f.write( struct.pack('ii',flow.shape[1],flow.shape[0]) ) + # write the flow + + flow.astype(np.float32).tofile(f) + +_read_flo_file = readFlowFile + +def _read_kitti_flow(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + valid = flow[:, :, 2]>0 + flow = flow[:, :, :2] + flow = (flow - 2 ** 15) / 64.0 + flow[~valid,0] = np.inf + flow[~valid,1] = np.inf + return flow +_read_hd1k_flow = _read_kitti_flow + + +def writeFlowKitti(filename, uv): + uv = 64.0 * uv + 2 ** 15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + +def writeFlo5File(flow, filename): + with h5py.File(filename, "w") as f: + f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5) + +def _read_hdf5_flow(filename): + flow = np.asarray(h5py.File(filename)['flow']) + flow[np.isnan(flow)] = np.inf # make invalid values as +inf + return flow.astype(np.float32) + +# flow visualization +RY = 15 +YG = 6 +GC = 4 +CB = 11 +BM = 13 +MR = 6 +UNKNOWN_THRESH = 1e9 + +def colorTest(): + """ + flow_utils.colorTest(): display an example of image showing the color encoding scheme + """ + import matplotlib.pylab as plt + truerange = 1 + h,w = 151,151 + trange = truerange*1.04 + s2 = round(h/2) + x,y = np.meshgrid(range(w),range(h)) + u = x*trange/s2-trange + v = y*trange/s2-trange + img = _computeColor(np.concatenate((u[:,:,np.newaxis],v[:,:,np.newaxis]),2)/trange/np.sqrt(2)) + plt.imshow(img) + plt.axis('off') + plt.axhline(round(h/2),color='k') + plt.axvline(round(w/2),color='k') + +def flowToColor(flow, maxflow=None, maxmaxflow=None, saturate=False): + """ + flow_utils.flowToColor(flow): return a color code flow field, normalized based on the maximum l2-norm of the flow + flow_utils.flowToColor(flow,maxflow): return a color code flow field, normalized by maxflow + ---- PARAMETERS ---- + flow: flow to display of shape (height x width x 2) + maxflow (default:None): if given, normalize the flow by its value, otherwise by the flow norm + maxmaxflow (default:None): if given, normalize the flow by the max of its value and the flow norm + ---- OUTPUT ---- + an np.array of shape (height x width x 3) of type uint8 containing a color code of the flow + """ + h,w,n = flow.shape + # check size of flow + assert n == 2, "flow_utils.flowToColor(flow): flow must have 2 bands" + # fix unknown flow + unknown_idx = np.max(np.abs(flow),2)>UNKNOWN_THRESH + flow[unknown_idx] = 0.0 + # compute max flow if needed + if maxflow is None: + maxflow = flowMaxNorm(flow) + if maxmaxflow is not None: + maxflow = min(maxmaxflow, maxflow) + # normalize flow + eps = np.spacing(1) # minimum positive float value to avoid division by 0 + # compute the flow + img = _computeColor(flow/(maxflow+eps), 
saturate=saturate) + # put black pixels in unknown location + img[ np.tile( unknown_idx[:,:,np.newaxis],[1,1,3]) ] = 0.0 + return img + +def flowMaxNorm(flow): + """ + flow_utils.flowMaxNorm(flow): return the maximum of the l2-norm of the given flow + ---- PARAMETERS ---- + flow: the flow + + ---- OUTPUT ---- + a float containing the maximum of the l2-norm of the flow + """ + return np.max( np.sqrt( np.sum( np.square( flow ) , 2) ) ) + +def _computeColor(flow, saturate=True): + """ + flow_utils._computeColor(flow): compute color codes for the flow field flow + + ---- PARAMETERS ---- + flow: np.array of dimension (height x width x 2) containing the flow to display + ---- OUTPUTS ---- + an np.array of dimension (height x width x 3) containing the color conversion of the flow + """ + # set nan to 0 + nanidx = np.isnan(flow[:,:,0]) + flow[nanidx] = 0.0 + + # colorwheel + ncols = RY + YG + GC + CB + BM + MR + nchans = 3 + colorwheel = np.zeros((ncols,nchans),'uint8') + col = 0; + #RY + colorwheel[:RY,0] = 255 + colorwheel[:RY,1] = [(255*i) // RY for i in range(RY)] + col += RY + # YG + colorwheel[col:col+YG,0] = [255 - (255*i) // YG for i in range(YG)] + colorwheel[col:col+YG,1] = 255 + col += YG + # GC + colorwheel[col:col+GC,1] = 255 + colorwheel[col:col+GC,2] = [(255*i) // GC for i in range(GC)] + col += GC + # CB + colorwheel[col:col+CB,1] = [255 - (255*i) // CB for i in range(CB)] + colorwheel[col:col+CB,2] = 255 + col += CB + # BM + colorwheel[col:col+BM,0] = [(255*i) // BM for i in range(BM)] + colorwheel[col:col+BM,2] = 255 + col += BM + # MR + colorwheel[col:col+MR,0] = 255 + colorwheel[col:col+MR,2] = [255 - (255*i) // MR for i in range(MR)] + + # compute utility variables + rad = np.sqrt( np.sum( np.square(flow) , 2) ) # magnitude + a = np.arctan2( -flow[:,:,1] , -flow[:,:,0]) / np.pi # angle + fk = (a+1)/2 * (ncols-1) # map [-1,1] to [0,ncols-1] + k0 = np.floor(fk).astype('int') + k1 = k0+1 + k1[k1==ncols] = 0 + f = fk-k0 + + if not saturate: + rad = np.minimum(rad,1) + + # compute the image + img = np.zeros( (flow.shape[0],flow.shape[1],nchans), 'uint8' ) + for i in range(nchans): + tmp = colorwheel[:,i].astype('float') + col0 = tmp[k0]/255 + col1 = tmp[k1]/255 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1-rad[idx]*(1-col[idx]) # increase saturation with radius + col[~idx] *= 0.75 # out of range + img[:,:,i] = (255*col*(1-nanidx.astype('float'))).astype('uint8') + + return img + +# flow dataset getter + +def get_train_dataset_flow(dataset_str, augmentor=True, crop_size=None): + dataset_str = dataset_str.replace('(','Dataset(') + if augmentor: + dataset_str = dataset_str.replace(')',', augmentor=True)') + if crop_size is not None: + dataset_str = dataset_str.replace(')',', crop_size={:s})'.format(str(crop_size))) + return eval(dataset_str) + +def get_test_datasets_flow(dataset_str): + dataset_str = dataset_str.replace('(','Dataset(') + return [eval(s) for s in dataset_str.split('+')] \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/stereoflow/datasets_stereo.py b/imcui/third_party/mast3r/dust3r/croco/stereoflow/datasets_stereo.py new file mode 100644 index 0000000000000000000000000000000000000000..dbdf841a6650afa71ae5782702902c79eba31a5c --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/stereoflow/datasets_stereo.py @@ -0,0 +1,674 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
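The `get_train_dataset_flow` / `get_test_datasets_flow` helpers defined at the end of `datasets_flow.py` above build datasets from a compact string spec: `'('` is rewritten to `'Dataset('`, an `N*` prefix over-samples a dataset through `__rmul__`, and `+` either concatenates datasets (training) or yields a list of independent datasets (testing). A minimal usage sketch, assuming the roots listed in `dataset_to_root` and the pair-name caches are in place; the module path, split names and crop size below are illustrative, not prescribed by the code:

```python
from stereoflow.datasets_flow import get_train_dataset_flow, get_test_datasets_flow

# 'MPISintel(...)' and 'Kitti15(...)' are rewritten to MPISintelDataset(...) /
# Kitti15Dataset(...) before eval; '50*' relies on FlowDataset.__rmul__ to
# over-sample the smaller dataset.
train_set = get_train_dataset_flow(
    "MPISintel('subtrain_allpass')+50*Kitti15('subtrain')",
    augmentor=True,
    crop_size=(368, 496),
)

# For evaluation, '+' separates independent datasets and a list is returned,
# each built without augmentation and at full resolution.
test_sets = get_test_datasets_flow("MPISintel('subval_cleanpass')+Kitti15('subval')")

print(train_set)                      # a torch ConcatDataset of the two training sets
print([repr(d) for d in test_sets])   # per-dataset summaries via __repr__
```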
+ +# -------------------------------------------------------- +# Dataset structure for stereo +# -------------------------------------------------------- + +import sys, os +import os.path as osp +import pickle +import numpy as np +from PIL import Image +import json +import h5py +from glob import glob +import cv2 + +import torch +from torch.utils import data + +from .augmentor import StereoAugmentor + + + +dataset_to_root = { + 'CREStereo': './data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/', + 'SceneFlow': './data/stereoflow//SceneFlow/', + 'ETH3DLowRes': './data/stereoflow/eth3d_lowres/', + 'Booster': './data/stereoflow/booster_gt/', + 'Middlebury2021': './data/stereoflow/middlebury/2021/data/', + 'Middlebury2014': './data/stereoflow/middlebury/2014/', + 'Middlebury2006': './data/stereoflow/middlebury/2006/', + 'Middlebury2005': './data/stereoflow/middlebury/2005/train/', + 'MiddleburyEval3': './data/stereoflow/middlebury/MiddEval3/', + 'Spring': './data/stereoflow/spring/', + 'Kitti15': './data/stereoflow/kitti-stereo-2015/', + 'Kitti12': './data/stereoflow/kitti-stereo-2012/', +} +cache_dir = "./data/stereoflow/datasets_stereo_cache/" + + +in1k_mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1) +in1k_std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1) +def img_to_tensor(img): + img = torch.from_numpy(img).permute(2, 0, 1).float() / 255. + img = (img-in1k_mean)/in1k_std + return img +def disp_to_tensor(disp): + return torch.from_numpy(disp)[None,:,:] + +class StereoDataset(data.Dataset): + + def __init__(self, split, augmentor=False, crop_size=None, totensor=True): + self.split = split + if not augmentor: assert crop_size is None + if crop_size: assert augmentor + self.crop_size = crop_size + self.augmentor_str = augmentor + self.augmentor = StereoAugmentor(crop_size) if augmentor else None + self.totensor = totensor + self.rmul = 1 # keep track of rmul + self.has_constant_resolution = True # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time) + self._prepare_data() + self._load_or_build_cache() + + def prepare_data(self): + """ + to be defined for each dataset + """ + raise NotImplementedError + + def __len__(self): + return len(self.pairnames) + + def __getitem__(self, index): + pairname = self.pairnames[index] + + # get filenames + Limgname = self.pairname_to_Limgname(pairname) + Rimgname = self.pairname_to_Rimgname(pairname) + Ldispname = self.pairname_to_Ldispname(pairname) if self.pairname_to_Ldispname is not None else None + + # load images and disparities + Limg = _read_img(Limgname) + Rimg = _read_img(Rimgname) + disp = self.load_disparity(Ldispname) if Ldispname is not None else None + + # sanity check + if disp is not None: assert np.all(disp>0) or self.name=="Spring", (self.name, pairname, Ldispname) + + # apply augmentations + if self.augmentor is not None: + Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name) + + if self.totensor: + Limg = img_to_tensor(Limg) + Rimg = img_to_tensor(Rimg) + if disp is None: + disp = torch.tensor([]) # to allow dataloader batching with default collate_gn + else: + disp = disp_to_tensor(disp) + + return Limg, Rimg, disp, str(pairname) + + def __rmul__(self, v): + self.rmul *= v + self.pairnames = v * self.pairnames + return self + + def __str__(self): + return f'{self.__class__.__name__}_{self.split}' + + def __repr__(self): + s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})' 
+ if self.rmul==1: + s+=f'\n\tnum pairs: {len(self.pairnames)}' + else: + s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})' + return s + + def _set_root(self): + self.root = dataset_to_root[self.name] + assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}" + + def _load_or_build_cache(self): + cache_file = osp.join(cache_dir, self.name+'.pkl') + if osp.isfile(cache_file): + with open(cache_file, 'rb') as fid: + self.pairnames = pickle.load(fid)[self.split] + else: + tosave = self._build_cache() + os.makedirs(cache_dir, exist_ok=True) + with open(cache_file, 'wb') as fid: + pickle.dump(tosave, fid) + self.pairnames = tosave[self.split] + +class CREStereoDataset(StereoDataset): + + def _prepare_data(self): + self.name = 'CREStereo' + self._set_root() + assert self.split in ['train'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_left.jpg') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'_right.jpg') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname+'_left.disp.png') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_crestereo_disp + + + def _build_cache(self): + allpairs = [s+'/'+f[:-len('_left.jpg')] for s in sorted(os.listdir(self.root)) for f in sorted(os.listdir(self.root+'/'+s)) if f.endswith('_left.jpg')] + assert len(allpairs)==200000, "incorrect parsing of pairs in CreStereo" + tosave = {'train': allpairs} + return tosave + +class SceneFlowDataset(StereoDataset): + + def _prepare_data(self): + self.name = "SceneFlow" + self._set_root() + assert self.split in ['train_finalpass','train_cleanpass','train_allpass','test_finalpass','test_cleanpass','test_allpass','test1of100_cleanpass','test1of100_finalpass'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/left/','/right/') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname).replace('/frames_finalpass/','/disparity/').replace('/frames_cleanpass/','/disparity/')[:-4]+'.pfm' + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_sceneflow_disp + + def _build_cache(self): + trainpairs = [] + # driving + pairs = sorted(glob(self.root+'Driving/frames_finalpass/*/*/*/left/*.png')) + pairs = list(map(lambda x: x[len(self.root):], pairs)) + assert len(pairs) == 4400, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + # monkaa + pairs = sorted(glob(self.root+'Monkaa/frames_finalpass/*/left/*.png')) + pairs = list(map(lambda x: x[len(self.root):], pairs)) + assert len(pairs) == 8664, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + # flyingthings + pairs = sorted(glob(self.root+'FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png')) + pairs = list(map(lambda x: x[len(self.root):], pairs)) + assert len(pairs) == 22390, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow" + testpairs = sorted(glob(self.root+'FlyingThings/frames_finalpass/TEST/*/*/left/*.png')) + testpairs = list(map(lambda x: x[len(self.root):], testpairs)) + assert len(testpairs) == 4370, "incorrect parsing of pairs in SceneFlow" + test1of100pairs = testpairs[::100] + assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow" + # all + tosave = {'train_finalpass': trainpairs, 
+ 'train_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), trainpairs)), + 'test_finalpass': testpairs, + 'test_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), testpairs)), + 'test1of100_finalpass': test1of100pairs, + 'test1of100_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), test1of100pairs)), + } + tosave['train_allpass'] = tosave['train_finalpass']+tosave['train_cleanpass'] + tosave['test_allpass'] = tosave['test_finalpass']+tosave['test_cleanpass'] + return tosave + +class Md21Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2021" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/im0','/im1')) + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp0.pfm') + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury_disp + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + #trainpairs += [s+'/im0.png'] # we should remove it, it is included as such in other lightings + trainpairs += [s+'/ambient/'+b+'/'+a for b in sorted(os.listdir(osp.join(self.root,s,'ambient'))) for a in sorted(os.listdir(osp.join(self.root,s,'ambient',b))) if a.startswith('im0')] + assert len(trainpairs)==355 + subtrainpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in seqs[:-2])] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in seqs[-2:])] + assert len(subtrainpairs)==335 and len(subvalpairs)==20, "incorrect parsing of pairs in Middlebury 2021" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class Md14Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2014" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'im0.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'disp0.pfm') + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury_disp + self.has_constant_resolution = False + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + trainpairs += [s+'/im1.png',s+'/im1E.png',s+'/im1L.png'] + assert len(trainpairs)==138 + valseqs = ['Umbrella-imperfect','Vintage-perfect'] + assert all(s in seqs for s in valseqs) + subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)] + assert len(subtrainpairs)==132 and len(subvalpairs)==6, "incorrect parsing of pairs in Middlebury 2014" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class Md06Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2006" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png') + 
self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png') + self.load_disparity = _read_middlebury20052006_disp + self.has_constant_resolution = False + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + for i in ['Illum1','Illum2','Illum3']: + for e in ['Exp0','Exp1','Exp2']: + trainpairs.append(osp.join(s,i,e,'view1.png')) + assert len(trainpairs)==189 + valseqs = ['Rocks1','Wood2'] + assert all(s in seqs for s in valseqs) + subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)] + assert len(subtrainpairs)==171 and len(subvalpairs)==18, "incorrect parsing of pairs in Middlebury 2006" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class Md05Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Middlebury2005" + self._set_root() + assert self.split in ['train','subtrain','subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png') + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury20052006_disp + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + for i in ['Illum1','Illum2','Illum3']: + for e in ['Exp0','Exp1','Exp2']: + trainpairs.append(osp.join(s,i,e,'view1.png')) + assert len(trainpairs)==54, "incorrect parsing of pairs in Middlebury 2005" + valseqs = ['Reindeer'] + assert all(s in seqs for s in valseqs) + subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)] + subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)] + assert len(subtrainpairs)==45 and len(subvalpairs)==9, "incorrect parsing of pairs in Middlebury 2005" + tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + +class MdEval3Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "MiddleburyEval3" + self._set_root() + assert self.split in [s+'_'+r for s in ['train','subtrain','subval','test','all'] for r in ['full','half','quarter']] + if self.split.endswith('_full'): + self.root = self.root.replace('/MiddEval3','/MiddEval3_F') + elif self.split.endswith('_half'): + self.root = self.root.replace('/MiddEval3','/MiddEval3_H') + else: + assert self.split.endswith('_quarter') + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png') + self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname, 'disp0GT.pfm') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_middlebury_disp + # for submission only + self.submission_methodname = "CroCo-Stereo" + self.submission_sresolution = 'F' if self.split.endswith('_full') else ('H' if self.split.endswith('_half') else 'Q') + + def _build_cache(self): + trainpairs = ['train/'+s for s in sorted(os.listdir(self.root+'train/'))] + testpairs = ['test/'+s for s in sorted(os.listdir(self.root+'test/'))] + subvalpairs = trainpairs[-1:] + subtrainpairs = trainpairs[:-1] + 
allpairs = trainpairs+testpairs + assert len(trainpairs)==15 and len(testpairs)==15 and len(subvalpairs)==1 and len(subtrainpairs)==14 and len(allpairs)==30, "incorrect parsing of pairs in Middlebury Eval v3" + tosave = {} + for r in ['full','half','quarter']: + tosave.update(**{'train_'+r: trainpairs, 'subtrain_'+r: subtrainpairs, 'subval_'+r: subvalpairs, 'test_'+r: testpairs, 'all_'+r: allpairs}) + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, pairname.split('/')[0].replace('train','training')+self.submission_sresolution, pairname.split('/')[1], 'disp0'+self.submission_methodname+'.pfm') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writePFM(outfile, prediction) + timefile = os.path.join( os.path.dirname(outfile), "time"+self.submission_methodname+'.txt') + with open(timefile, 'w') as fid: + fid.write(str(time)) + + def finalize_submission(self, outdir): + cmd = f'cd {outdir}/; zip -r "{self.submission_methodname}.zip" .' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/{self.submission_methodname}.zip') + +class ETH3DLowResDataset(StereoDataset): + + def _prepare_data(self): + self.name = "ETH3DLowRes" + self._set_root() + assert self.split in ['train','test','subtrain','subval','all'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png') + self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: None if pairname.startswith('test/') else osp.join(self.root, pairname.replace('train/','train_gt/'), 'disp0GT.pfm') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_eth3d_disp + self.has_constant_resolution = False + + def _build_cache(self): + trainpairs = ['train/' + s for s in sorted(os.listdir(self.root+'train/'))] + testpairs = ['test/' + s for s in sorted(os.listdir(self.root+'test/'))] + assert len(trainpairs) == 27 and len(testpairs) == 20, "incorrect parsing of pairs in ETH3D Low Res" + subvalpairs = ['train/delivery_area_3s','train/electro_3l','train/playground_3l'] + assert all(p in trainpairs for p in subvalpairs) + subtrainpairs = [p for p in trainpairs if not p in subvalpairs] + assert len(subvalpairs)==3 and len(subtrainpairs)==24, "incorrect parsing of pairs in ETH3D Low Res" + tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs, 'all': trainpairs+testpairs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, 'low_res_two_view', pairname.split('/')[1]+'.pfm') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writePFM(outfile, prediction) + timefile = outfile[:-4]+'.txt' + with open(timefile, 'w') as fid: + fid.write('runtime '+str(time)) + + def finalize_submission(self, outdir): + cmd = f'cd {outdir}/; zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view' + print(cmd) + os.system(cmd) + print(f'Done. 
Submission file at {outdir}/eth3d_low_res_two_view_results.zip') + +class BoosterDataset(StereoDataset): + + def _prepare_data(self): + self.name = "Booster" + self._set_root() + assert self.split in ['train_balanced','test_balanced','subtrain_balanced','subval_balanced'] # we use only the balanced version + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/camera_00/','/camera_02/') + self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), '../disp_00.npy') # same images with different colors, same gt per sequence + self.pairname_to_str = lambda pairname: pairname[:-4].replace('/camera_00/','/') + self.load_disparity = _read_booster_disp + + + def _build_cache(self): + trainseqs = sorted(os.listdir(self.root+'train/balanced')) + trainpairs = ['train/balanced/'+s+'/camera_00/'+imname for s in trainseqs for imname in sorted(os.listdir(self.root+'train/balanced/'+s+'/camera_00/'))] + testpairs = ['test/balanced/'+s+'/camera_00/'+imname for s in sorted(os.listdir(self.root+'test/balanced')) for imname in sorted(os.listdir(self.root+'test/balanced/'+s+'/camera_00/'))] + assert len(trainpairs) == 228 and len(testpairs) == 191 + subtrainpairs = [p for p in trainpairs if any(s in p for s in trainseqs[:-2])] + subvalpairs = [p for p in trainpairs if any(s in p for s in trainseqs[-2:])] + # warning: if we do validation split, we should split scenes!!! + tosave = {'train_balanced': trainpairs, 'test_balanced': testpairs, 'subtrain_balanced': subtrainpairs, 'subval_balanced': subvalpairs,} + return tosave + +class SpringDataset(StereoDataset): + + def _prepare_data(self): + self.name = "Spring" + self._set_root() + assert self.split in ['train', 'test', 'subtrain', 'subval'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'.png').replace('frame_right','').replace('frame_left','frame_right').replace('','frame_left') + self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right') + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_hdf5_disp + + def _build_cache(self): + trainseqs = sorted(os.listdir( osp.join(self.root,'train'))) + trainpairs = [osp.join('train',s,'frame_left',f[:-4]) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,'frame_left')))] + testseqs = sorted(os.listdir( osp.join(self.root,'test'))) + testpairs = [osp.join('test',s,'frame_left',f[:-4]) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,'frame_left')))] + testpairs += [p.replace('frame_left','frame_right') for p in testpairs] + """maxnorm = {'0001': 32.88, '0002': 228.5, '0004': 298.2, '0005': 142.5, '0006': 113.6, '0007': 27.3, '0008': 554.5, '0009': 155.6, '0010': 126.1, '0011': 87.6, '0012': 303.2, '0013': 24.14, '0014': 82.56, '0015': 98.44, '0016': 156.9, '0017': 28.17, '0018': 21.03, '0020': 178.0, '0021': 58.06, '0022': 354.2, '0023': 8.79, '0024': 97.06, '0025': 55.16, '0026': 91.9, '0027': 156.6, '0030': 200.4, '0032': 58.66, '0033': 373.5, '0036': 149.4, '0037': 5.625, '0038': 37.0, '0039': 12.2, '0041': 453.5, '0043': 457.0, '0044': 379.5, '0045': 161.8, '0047': 105.44} # => let'use 0041""" + subtrainpairs = [p for p in trainpairs if 
p.split('/')[1]!='0041'] + subvalpairs = [p for p in trainpairs if p.split('/')[1]=='0041'] + assert len(trainpairs)==5000 and len(testpairs)==2000 and len(subtrainpairs)==4904 and len(subvalpairs)==96, "incorrect parsing of pairs in Spring" + tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + writeDsp5File(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + exe = "{self.root}/disp1_subsampling" + if os.path.isfile(exe): + cmd = f'cd "{outdir}/test"; {exe} .' + print(cmd) + os.system(cmd) + else: + print('Could not find disp1_subsampling executable for submission.') + print('Please download it and run:') + print(f'cd "{outdir}/test"; .') + +class Kitti12Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Kitti12" + self._set_root() + assert self.split in ['train','test'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/colored_1/')+'_10.png') + self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/disp_occ/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/') + self.load_disparity = _read_kitti_disp + + def _build_cache(self): + trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)] + testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)] + assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12" + tosave = {'train': trainseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + img = (prediction * 256).astype('uint16') + Image.fromarray(img).save(outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti12_results.zip" .' + print(cmd) + os.system(cmd) + print(f'Done. 
Submission file at {outdir}/kitti12_results.zip') + +class Kitti15Dataset(StereoDataset): + + def _prepare_data(self): + self.name = "Kitti15" + self._set_root() + assert self.split in ['train','subtrain','subval','test'] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png') + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/image_3/')+'_10.png') + self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/disp_occ_0/')+'_10.png') + self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/') + self.load_disparity = _read_kitti_disp + + def _build_cache(self): + trainseqs = ["training/image_2/%06d"%(i) for i in range(200)] + subtrainseqs = trainseqs[:-5] + subvalseqs = trainseqs[-5:] + testseqs = ["testing/image_2/%06d"%(i) for i in range(200)] + assert len(trainseqs)==200 and len(subtrainseqs)==195 and len(subvalseqs)==5 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15" + tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim==2 + assert prediction.dtype==np.float32 + outfile = os.path.join(outdir, 'disp_0', pairname.split('/')[-1]+'_10.png') + os.makedirs( os.path.dirname(outfile), exist_ok=True) + img = (prediction * 256).astype('uint16') + Image.fromarray(img).save(outfile) + + def finalize_submission(self, outdir): + assert self.split=='test' + cmd = f'cd {outdir}/; zip -r "kitti15_results.zip" disp_0' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at {outdir}/kitti15_results.zip') + + +### auxiliary functions + +def _read_img(filename): + # convert to RGB for scene flow finalpass data + img = np.asarray(Image.open(filename).convert('RGB')) + return img + +def _read_booster_disp(filename): + disp = np.load(filename) + disp[disp==0.0] = np.inf + return disp + +def _read_png_disp(filename, coef=1.0): + disp = np.asarray(Image.open(filename)) + disp = disp.astype(np.float32) / coef + disp[disp==0.0] = np.inf + return disp + +def _read_pfm_disp(filename): + disp = np.ascontiguousarray(_read_pfm(filename)[0]) + disp[disp<=0] = np.inf # eg /nfs/data/ffs-3d/datasets/middlebury/2014/Shopvac-imperfect/disp0.pfm + return disp + +def _read_npy_disp(filename): + return np.load(filename) + +def _read_crestereo_disp(filename): return _read_png_disp(filename, coef=32.0) +def _read_middlebury20052006_disp(filename): return _read_png_disp(filename, coef=1.0) +def _read_kitti_disp(filename): return _read_png_disp(filename, coef=256.0) +_read_sceneflow_disp = _read_pfm_disp +_read_eth3d_disp = _read_pfm_disp +_read_middlebury_disp = _read_pfm_disp +_read_carla_disp = _read_pfm_disp +_read_tartanair_disp = _read_npy_disp + +def _read_hdf5_disp(filename): + disp = np.asarray(h5py.File(filename)['disparity']) + disp[np.isnan(disp)] = np.inf # make invalid values as +inf + #disp[disp==0.0] = np.inf # make invalid values as +inf + return disp.astype(np.float32) + +import re +def _read_pfm(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == 'PF': + color = True + elif header.decode("ascii") == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) + 
if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + +def writePFM(file, image, scale=1): + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n' if color else 'Pf\n'.encode()) + file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + file.write('%f\n'.encode() % scale) + + image.tofile(file) + +def writeDsp5File(disp, filename): + with h5py.File(filename, "w") as f: + f.create_dataset("disparity", data=disp, compression="gzip", compression_opts=5) + + +# disp visualization + +def vis_disparity(disp, m=None, M=None): + if m is None: m = disp.min() + if M is None: M = disp.max() + disp_vis = (disp - m) / (M-m) * 255.0 + disp_vis = disp_vis.astype("uint8") + disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) + return disp_vis + +# dataset getter + +def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None): + dataset_str = dataset_str.replace('(','Dataset(') + if augmentor: + dataset_str = dataset_str.replace(')',', augmentor=True)') + if crop_size is not None: + dataset_str = dataset_str.replace(')',', crop_size={:s})'.format(str(crop_size))) + return eval(dataset_str) + +def get_test_datasets_stereo(dataset_str): + dataset_str = dataset_str.replace('(','Dataset(') + return [eval(s) for s in dataset_str.split('+')] \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/stereoflow/engine.py b/imcui/third_party/mast3r/dust3r/croco/stereoflow/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..c057346b99143bf6b9c4666a58215b2b91aca7a6 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/stereoflow/engine.py @@ -0,0 +1,280 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
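Before the training/evaluation engine that follows, a short sketch of the disparity encoding convention shared by the KITTI readers and submission writers above: ground truth is stored as a 16-bit PNG scaled by 256, `_read_png_disp(..., coef=256.0)` undoes the scaling and marks zero-valued pixels as invalid with `+inf`, and the `submission_save_pairname` methods apply the inverse mapping when writing predictions. The two helper names here are hypothetical and only make the round trip explicit:

```python
import numpy as np
from PIL import Image

def encode_kitti_disparity(disp: np.ndarray, outfile: str) -> None:
    # same convention as Kitti12/Kitti15 submission_save_pairname:
    # scale by 256 and store as a 16-bit PNG (the value 0 is reserved for "invalid")
    Image.fromarray((disp * 256.0).astype("uint16")).save(outfile)

def decode_kitti_disparity(infile: str) -> np.ndarray:
    # same convention as _read_png_disp(..., coef=256.0):
    # undo the scaling and flag missing pixels as +inf
    disp = np.asarray(Image.open(infile)).astype(np.float32) / 256.0
    disp[disp == 0.0] = np.inf
    return disp
```

Because values are stored as 16-bit integers, representable disparities are capped at 65535 / 256 ≈ 256 px, with a quantization step of 1/256 px.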
+ +# -------------------------------------------------------- +# Main function for training one epoch or testing +# -------------------------------------------------------- + +import math +import sys +from typing import Iterable +import numpy as np +import torch +import torchvision + +from utils import misc as misc + + +def split_prediction_conf(predictions, with_conf=False): + if not with_conf: + return predictions, None + conf = predictions[:,-1:,:,:] + predictions = predictions[:,:-1,:,:] + return predictions, conf + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, metrics: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + log_writer=None, print_freq = 20, + args=None): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + + accum_iter = args.accum_iter + + optimizer.zero_grad() + + details = {} + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + if args.img_per_epoch: + iter_per_epoch = args.img_per_epoch // args.batch_size + int(args.img_per_epoch % args.batch_size > 0) + assert len(data_loader) >= iter_per_epoch, 'Dataset is too small for so many iterations' + len_data_loader = iter_per_epoch + else: + len_data_loader, iter_per_epoch = len(data_loader), None + + for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_logger.log_every(data_loader, print_freq, header, max_iter=iter_per_epoch)): + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) + + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, data_iter_step / len_data_loader + epoch, args) + + with torch.cuda.amp.autocast(enabled=bool(args.amp)): + prediction = model(image1, image2) + prediction, conf = split_prediction_conf(prediction, criterion.with_conf) + batch_metrics = metrics(prediction.detach(), gt) + loss = criterion(prediction, gt) if conf is None else criterion(prediction, gt, conf) + + loss_value = loss.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + for k,v in batch_metrics.items(): + metric_logger.update(**{k: v.item()}) + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(lr=lr) + + #if args.dsitributed: loss_value_reduce = misc.all_reduce_mean(loss_value) + time_to_log = ((data_iter_step + 1) % (args.tboard_log_step * accum_iter) == 0 or data_iter_step == len_data_loader-1) + loss_value_reduce = misc.all_reduce_mean(loss_value) + if log_writer is not None and time_to_log: + epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000) + # We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. 
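# Illustrative note on the x-axis used above: epoch_1000x spreads each epoch
# over 1000 tensorboard ticks, independent of the batch size. For example,
# epoch=3, data_iter_step=250 and len_data_loader=1000 give
# int((250 / 1000 + 3) * 1000) = 3250, i.e. a quarter of the way into epoch 3.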
+ log_writer.add_scalar('train/loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('lr', lr, epoch_1000x) + for k,v in batch_metrics.items(): + log_writer.add_scalar('train/'+k, v.item(), epoch_1000x) + + # gather the stats from all processes + #if args.distributed: metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def validate_one_epoch(model: torch.nn.Module, + criterion: torch.nn.Module, + metrics: torch.nn.Module, + data_loaders: list[Iterable], + device: torch.device, + epoch: int, + log_writer=None, + args=None): + + model.eval() + metric_loggers = [] + header = 'Epoch: [{}]'.format(epoch) + print_freq = 20 + + conf_mode = args.tile_conf_mode + crop = args.crop + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + results = {} + dnames = [] + image1, image2, gt, prediction = None, None, None, None + for didx, data_loader in enumerate(data_loaders): + dname = str(data_loader.dataset) + dnames.append(dname) + metric_loggers.append(misc.MetricLogger(delimiter=" ")) + for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_loggers[didx].log_every(data_loader, print_freq, header)): + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) + if dname.startswith('Spring'): + assert gt.size(2)==image1.size(2)*2 and gt.size(3)==image1.size(3)*2 + gt = (gt[:,:,0::2,0::2] + gt[:,:,0::2,1::2] + gt[:,:,1::2,0::2] + gt[:,:,1::2,1::2] ) / 4.0 # we approximate the gt based on the 2x upsampled ones + + with torch.inference_mode(): + prediction, tiled_loss, c = tiled_pred(model, criterion, image1, image2, gt, conf_mode=conf_mode, overlap=args.val_overlap, crop=crop, with_conf=criterion.with_conf) + batch_metrics = metrics(prediction.detach(), gt) + loss = criterion(prediction.detach(), gt) if not criterion.with_conf else criterion(prediction.detach(), gt, c) + loss_value = loss.item() + metric_loggers[didx].update(loss_tiled=tiled_loss.item()) + metric_loggers[didx].update(**{f'loss': loss_value}) + for k,v in batch_metrics.items(): + metric_loggers[didx].update(**{dname+'_' + k: v.item()}) + + results = {k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()} + if len(dnames)>1: + for k in batch_metrics.keys(): + results['AVG_'+k] = sum(results[dname+'_'+k] for dname in dnames) / len(dnames) + + if log_writer is not None : + epoch_1000x = int((1 + epoch) * 1000) + for k,v in results.items(): + log_writer.add_scalar('val/'+k, v, epoch_1000x) + + print("Averaged stats:", results) + return results + +import torch.nn.functional as F +def _resize_img(img, new_size): + return F.interpolate(img, size=new_size, mode='bicubic', align_corners=False) +def _resize_stereo_or_flow(data, new_size): + assert data.ndim==4 + assert data.size(1) in [1,2] + scale_x = new_size[1]/float(data.size(3)) + out = F.interpolate(data, size=new_size, mode='bicubic', align_corners=False) + out[:,0,:,:] *= scale_x + if out.size(1)==2: + scale_y = new_size[0]/float(data.size(2)) + out[:,1,:,:] *= scale_y + print(scale_x, new_size, data.shape) + return out + + +@torch.no_grad() +def tiled_pred(model, criterion, img1, img2, gt, + overlap=0.5, bad_crop_thr=0.05, + downscale=False, crop=512, ret='loss', + conf_mode='conf_expsigmoid_10_5', with_conf=False, + return_time=False): + + # for each image, we are going to run inference on many 
overlapping patches + # then, all predictions will be weighted-averaged + if gt is not None: + B, C, H, W = gt.shape + else: + B, _, H, W = img1.shape + C = model.head.num_channels-int(with_conf) + win_height, win_width = crop[0], crop[1] + + # upscale to be larger than the crop + do_change_scale = H= window and 0 <= overlap < 1, (total, window, overlap) + num_windows = 1 + int(np.ceil( (total - window) / ((1-overlap) * window) )) + offsets = np.linspace(0, total-window, num_windows).round().astype(int) + yield from (slice(x, x+window) for x in offsets) + +def _crop(img, sy, sx): + B, THREE, H, W = img.shape + if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W: + return img[:,:,sy,sx] + l, r = max(0,-sx.start), max(0,sx.stop-W) + t, b = max(0,-sy.start), max(0,sy.stop-H) + img = torch.nn.functional.pad(img, (l,r,t,b), mode='constant') + return img[:, :, slice(sy.start+t,sy.stop+t), slice(sx.start+l,sx.stop+l)] \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/stereoflow/test.py b/imcui/third_party/mast3r/dust3r/croco/stereoflow/test.py new file mode 100644 index 0000000000000000000000000000000000000000..0248e56664c769752595af251e1eadcfa3a479d9 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/stereoflow/test.py @@ -0,0 +1,216 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Main test function +# -------------------------------------------------------- + +import os +import argparse +import pickle +from PIL import Image +import numpy as np +from tqdm import tqdm + +import torch +from torch.utils.data import DataLoader + +import utils.misc as misc +from models.croco_downstream import CroCoDownstreamBinocular +from models.head_downstream import PixelwiseTaskWithDPT + +from stereoflow.criterion import * +from stereoflow.datasets_stereo import get_test_datasets_stereo +from stereoflow.datasets_flow import get_test_datasets_flow +from stereoflow.engine import tiled_pred + +from stereoflow.datasets_stereo import vis_disparity +from stereoflow.datasets_flow import flowToColor + +def get_args_parser(): + parser = argparse.ArgumentParser('Test CroCo models on stereo/flow', add_help=False) + # important argument + parser.add_argument('--model', required=True, type=str, help='Path to the model to evaluate') + parser.add_argument('--dataset', required=True, type=str, help="test dataset (there can be multiple dataset separated by a +)") + # tiling + parser.add_argument('--tile_conf_mode', type=str, default='', help='Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint') + parser.add_argument('--tile_overlap', type=float, default=0.7, help='overlap between tiles') + # save (it will automatically go to _/_) + parser.add_argument('--save', type=str, nargs='+', default=[], + help='what to save: \ + metrics (pickle file), \ + pred (raw prediction save as torch tensor), \ + visu (visualization in png of each prediction), \ + err10 (visualization in png of the error clamp at 10 for each prediction), \ + submission (submission file)') + # other (no impact) + parser.add_argument('--num_workers', default=4, type=int) + return parser + + +def _load_model_and_criterion(model_path, do_load_metrics, device): + print('loading model from', model_path) + assert os.path.isfile(model_path) + ckpt = torch.load(model_path, 'cpu') + + ckpt_args = ckpt['args'] + task = 
ckpt_args.task + tile_conf_mode = ckpt_args.tile_conf_mode + num_channels = {'stereo': 1, 'flow': 2}[task] + with_conf = eval(ckpt_args.criterion).with_conf + if with_conf: num_channels += 1 + print('head: PixelwiseTaskWithDPT()') + head = PixelwiseTaskWithDPT() + head.num_channels = num_channels + print('croco_args:', ckpt_args.croco_args) + model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args) + msg = model.load_state_dict(ckpt['model'], strict=True) + model.eval() + model = model.to(device) + + if do_load_metrics: + if task=='stereo': + metrics = StereoDatasetMetrics().to(device) + else: + metrics = FlowDatasetMetrics().to(device) + else: + metrics = None + + return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode + + +def _save_batch(pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None): + + for i in range(len(pairnames)): + + pairname = eval(pairnames[i]) if pairnames[i].startswith('(') else pairnames[i] # unbatch pairname + fname = os.path.join(outdir, dataset.pairname_to_str(pairname)) + os.makedirs(os.path.dirname(fname), exist_ok=True) + + predi = pred[i,...] + if gt is not None: gti = gt[i,...] + + if 'pred' in save: + torch.save(predi.squeeze(0).cpu(), fname+'_pred.pth') + + if 'visu' in save: + if task=='stereo': + disparity = predi.permute((1,2,0)).squeeze(2).cpu().numpy() + m,M = None + if gt is not None: + mask = torch.isfinite(gti) + m = gt[mask].min() + M = gt[mask].max() + img_disparity = vis_disparity(disparity, m=m, M=M) + Image.fromarray(img_disparity).save(fname+'_pred.png') + else: + # normalize flowToColor according to the maxnorm of gt (or prediction if not available) + flowNorm = torch.sqrt(torch.sum( (gti if gt is not None else predi)**2, dim=0)).max().item() + imgflow = flowToColor(predi.permute((1,2,0)).cpu().numpy(), maxflow=flowNorm) + Image.fromarray(imgflow).save(fname+'_pred.png') + + if 'err10' in save: + assert gt is not None + L2err = torch.sqrt(torch.sum( (gti-predi)**2, dim=0)) + valid = torch.isfinite(gti[0,:,:]) + L2err[~valid] = 0.0 + L2err = torch.clamp(L2err, max=10.0) + red = (L2err*255.0/10.0).to(dtype=torch.uint8)[:,:,None] + zer = torch.zeros_like(red) + imgerr = torch.cat( (red,zer,zer), dim=2).cpu().numpy() + Image.fromarray(imgerr).save(fname+'_err10.png') + + if 'submission' in save: + assert submission_dir is not None + predi_np = predi.permute(1,2,0).squeeze(2).cpu().numpy() # transform into HxWx2 for flow or HxW for stereo + dataset.submission_save_pairname(pairname, predi_np, submission_dir, time) + +def main(args): + + # load the pretrained model and metrics + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + model, metrics, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion(args.model, 'metrics' in args.save, device) + if args.tile_conf_mode=='': args.tile_conf_mode = tile_conf_mode + + # load the datasets + datasets = (get_test_datasets_stereo if task=='stereo' else get_test_datasets_flow)(args.dataset) + dataloaders = [DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) for dataset in datasets] + + # run + for i,dataloader in enumerate(dataloaders): + dataset = datasets[i] + dstr = args.dataset.split('+')[i] + + outdir = args.model+'_'+misc.filename(dstr) + if 'metrics' in args.save and len(args.save)==1: + fname = os.path.join(outdir, f'conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl') + if os.path.isfile(fname) and len(args.save)==1: + print(' metrics 
already compute in '+fname) + with open(fname, 'rb') as fid: + results = pickle.load(fid) + for k,v in results.items(): + print('{:s}: {:.3f}'.format(k, v)) + continue + + if 'submission' in args.save: + dirname = f'submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}' + submission_dir = os.path.join(outdir, dirname) + else: + submission_dir = None + + print('') + print('saving {:s} in {:s}'.format('+'.join(args.save), outdir)) + print(repr(dataset)) + + if metrics is not None: + metrics.reset() + + for data_iter_step, (image1, image2, gt, pairnames) in enumerate(tqdm(dataloader)): + + do_flip = (task=='stereo' and dstr.startswith('Spring') and any("right" in p for p in pairnames)) # we flip the images and will flip the prediction after as we assume img1 is on the left + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) if gt.numel()>0 else None # special case for test time + if do_flip: + assert all("right" in p for p in pairnames) + image1 = image1.flip(dims=[3]) # this is already the right frame, let's flip it + image2 = image2.flip(dims=[3]) + gt = gt # that is ok + + with torch.inference_mode(): + pred, _, _, time = tiled_pred(model, None, image1, image2, None if dataset.name=='Spring' else gt, conf_mode=args.tile_conf_mode, overlap=args.tile_overlap, crop=cropsize, with_conf=with_conf, return_time=True) + + if do_flip: + pred = pred.flip(dims=[3]) + + if metrics is not None: + metrics.add_batch(pred, gt) + + if any(k in args.save for k in ['pred','visu','err10','submission']): + _save_batch(pred, gt, pairnames, dataset, task, args.save, outdir, time, submission_dir=submission_dir) + + + # print + if metrics is not None: + results = metrics.get_results() + for k,v in results.items(): + print('{:s}: {:.3f}'.format(k, v)) + + # save if needed + if 'metrics' in args.save: + os.makedirs(os.path.dirname(fname), exist_ok=True) + with open(fname, 'wb') as fid: + pickle.dump(results, fid) + print('metrics saved in', fname) + + # finalize submission if needed + if 'submission' in args.save: + dataset.finalize_submission(submission_dir) + + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/stereoflow/train.py b/imcui/third_party/mast3r/dust3r/croco/stereoflow/train.py new file mode 100644 index 0000000000000000000000000000000000000000..91f2414ffbe5ecd547d31c0e2455478d402719d6 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/stereoflow/train.py @@ -0,0 +1,253 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
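At test time the predictions above are produced tile by tile: tiled_pred covers each image with overlapping crops whose offsets come from the _overlapping helper in stereoflow/engine.py, and the per-tile outputs are weight-averaged, with weights derived from the predicted confidence as selected by --tile_conf_mode. The following self-contained sketch reproduces just the window placement; the 1216-pixel width is an arbitrary example, while 704 and 0.7 match the default stereo crop width and tile overlap used elsewhere in this patch.

import numpy as np

def overlapping_slices(total, window, overlap=0.5):
    # same placement rule as _overlapping in stereoflow/engine.py
    assert total >= window and 0 <= overlap < 1, (total, window, overlap)
    num_windows = 1 + int(np.ceil((total - window) / ((1 - overlap) * window)))
    offsets = np.linspace(0, total - window, num_windows).round().astype(int)
    return [slice(x, x + window) for x in offsets]

# covering a 1216-pixel-wide image with 704-pixel tiles at 70% overlap
print([(s.start, s.stop) for s in overlapping_slices(1216, 704, overlap=0.7)])
# -> [(0, 704), (171, 875), (341, 1045), (512, 1216)]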
+ +# -------------------------------------------------------- +# Main training function +# -------------------------------------------------------- + +import argparse +import datetime +import json +import numpy as np +import os +import sys +import time + +import torch +import torch.distributed as dist +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from torch.utils.data import DataLoader + +import utils +import utils.misc as misc +from utils.misc import NativeScalerWithGradNormCount as NativeScaler +from models.croco_downstream import CroCoDownstreamBinocular, croco_args_from_ckpt +from models.pos_embed import interpolate_pos_embed +from models.head_downstream import PixelwiseTaskWithDPT + +from stereoflow.datasets_stereo import get_train_dataset_stereo, get_test_datasets_stereo +from stereoflow.datasets_flow import get_train_dataset_flow, get_test_datasets_flow +from stereoflow.engine import train_one_epoch, validate_one_epoch +from stereoflow.criterion import * + + +def get_args_parser(): + # prepare subparsers + parser = argparse.ArgumentParser('Finetuning CroCo models on stereo or flow', add_help=False) + subparsers = parser.add_subparsers(title="Task (stereo or flow)", dest="task", required=True) + parser_stereo = subparsers.add_parser('stereo', help='Training stereo model') + parser_flow = subparsers.add_parser('flow', help='Training flow model') + def add_arg(name_or_flags, default=None, default_stereo=None, default_flow=None, **kwargs): + if default is not None: assert default_stereo is None and default_flow is None, "setting default makes default_stereo and default_flow disabled" + parser_stereo.add_argument(name_or_flags, default=default if default is not None else default_stereo, **kwargs) + parser_flow.add_argument(name_or_flags, default=default if default is not None else default_flow, **kwargs) + # output dir + add_arg('--output_dir', required=True, type=str, help='path where to save, if empty, automatically created') + # model + add_arg('--crop', type=int, nargs = '+', default_stereo=[352, 704], default_flow=[320, 384], help = "size of the random image crops used during training.") + add_arg('--pretrained', required=True, type=str, help="Load pretrained model (required as croco arguments come from there)") + # criterion + add_arg('--criterion', default_stereo='LaplacianLossBounded2()', default_flow='LaplacianLossBounded()', type=str, help='string to evaluate to get criterion') + add_arg('--bestmetric', default_stereo='avgerr', default_flow='EPE', type=str) + # dataset + add_arg('--dataset', type=str, required=True, help="training set") + # training + add_arg('--seed', default=0, type=int, help='seed') + add_arg('--batch_size', default_stereo=6, default_flow=8, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') + add_arg('--epochs', default=32, type=int, help='number of training epochs') + add_arg('--img_per_epoch', type=int, default=None, help='Fix the number of images seen in an epoch (None means use all training pairs)') + add_arg('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') + add_arg('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)') + add_arg('--lr', type=float, default_stereo=3e-5, default_flow=2e-5, metavar='LR', help='learning rate (absolute lr)') + add_arg('--min_lr', 
type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') + add_arg('--warmup_epochs', type=int, default=1, metavar='N', help='epochs to warmup LR') + add_arg('--optimizer', default='AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))', type=str, + help="Optimizer from torch.optim [ default: AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) ]") + add_arg('--amp', default=0, type=int, choices=[0,1], help='enable automatic mixed precision training') + # validation + add_arg('--val_dataset', type=str, default='', help="Validation sets, multiple separated by + (empty string means that no validation is performed)") + add_arg('--tile_conf_mode', type=str, default_stereo='conf_expsigmoid_15_3', default_flow='conf_expsigmoid_10_5', help='Weights for tile aggregation') + add_arg('--val_overlap', default=0.7, type=float, help='Overlap value for the tiling') + # others + add_arg('--num_workers', default=8, type=int) + add_arg('--eval_every', type=int, default=1, help='Val loss evaluation frequency') + add_arg('--save_every', type=int, default=1, help='Save checkpoint frequency') + add_arg('--start_from', type=str, default=None, help='Start training using weights from an other model (eg for finetuning)') + add_arg('--tboard_log_step', type=int, default=100, help='Log to tboard every so many steps') + add_arg('--dist_url', default='env://', help='url used to set up distributed training') + + return parser + + +def main(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + num_tasks = misc.get_world_size() + + assert os.path.isfile(args.pretrained) + print("output_dir: "+args.output_dir) + os.makedirs(args.output_dir, exist_ok=True) + + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + cudnn.benchmark = True + + # Metrics / criterion + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + metrics = (StereoMetrics if args.task=='stereo' else FlowMetrics)().to(device) + criterion = eval(args.criterion).to(device) + print('Criterion: ', args.criterion) + + # Prepare model + assert os.path.isfile(args.pretrained) + ckpt = torch.load(args.pretrained, 'cpu') + croco_args = croco_args_from_ckpt(ckpt) + croco_args['img_size'] = (args.crop[0], args.crop[1]) + print('Croco args: '+str(croco_args)) + args.croco_args = croco_args # saved for test time + # prepare head + num_channels = {'stereo': 1, 'flow': 2}[args.task] + if criterion.with_conf: num_channels += 1 + print(f'Building head PixelwiseTaskWithDPT() with {num_channels} channel(s)') + head = PixelwiseTaskWithDPT() + head.num_channels = num_channels + # build model and load pretrained weights + model = CroCoDownstreamBinocular(head, **croco_args) + interpolate_pos_embed(model, ckpt['model']) + msg = model.load_state_dict(ckpt['model'], strict=False) + print(msg) + + total_params = sum(p.numel() for p in model.parameters()) + total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Total params: {total_params}") + print(f"Total params trainable: {total_params_trainable}") + model_without_ddp = model.to(device) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + print("lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], static_graph=True) + 
model_without_ddp = model.module + + # following timm: set wd as 0 for bias and norm layers + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) + optimizer = eval(f"torch.optim.{args.optimizer}") + print(optimizer) + loss_scaler = NativeScaler() + + # automatic restart + last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth') + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + if not args.resume and args.start_from: + print(f"Starting from an other model's weights: {args.start_from}") + best_so_far = None + args.start_epoch = 0 + ckpt = torch.load(args.start_from, 'cpu') + msg = model_without_ddp.load_state_dict(ckpt['model'], strict=False) + print(msg) + else: + best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + + if best_so_far is None: best_so_far = np.inf + + # tensorboard + log_writer = None + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir, purge_step=args.start_epoch*1000) + + # dataset and loader + print('Building Train Data loader for dataset: ', args.dataset) + train_dataset = (get_train_dataset_stereo if args.task=='stereo' else get_train_dataset_flow)(args.dataset, crop_size=args.crop) + def _print_repr_dataset(d): + if isinstance(d, torch.utils.data.dataset.ConcatDataset): + for dd in d.datasets: + _print_repr_dataset(dd) + else: + print(repr(d)) + _print_repr_dataset(train_dataset) + print(' total length:', len(train_dataset)) + if args.distributed: + sampler_train = torch.utils.data.DistributedSampler( + train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_train = torch.utils.data.RandomSampler(train_dataset) + data_loader_train = torch.utils.data.DataLoader( + train_dataset, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + if args.val_dataset=='': + data_loaders_val = None + else: + print('Building Val Data loader for datasets: ', args.val_dataset) + val_datasets = (get_test_datasets_stereo if args.task=='stereo' else get_test_datasets_flow)(args.val_dataset) + for val_dataset in val_datasets: print(repr(val_dataset)) + data_loaders_val = [DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) for val_dataset in val_datasets] + bestmetric = ("AVG_" if len(data_loaders_val)>1 else str(data_loaders_val[0].dataset)+'_')+args.bestmetric + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + # Training Loop + for epoch in range(args.start_epoch, args.epochs): + + if args.distributed: data_loader_train.sampler.set_epoch(epoch) + + # Train + epoch_start = time.time() + train_stats = train_one_epoch(model, criterion, metrics, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args) + epoch_time = time.time() - epoch_start + + if args.distributed: dist.barrier() + + # Validation (current naive implementation runs the validation on every gpu ... not smart ...) 
+ if data_loaders_val is not None and args.eval_every > 0 and (epoch+1) % args.eval_every == 0: + val_epoch_start = time.time() + val_stats = validate_one_epoch(model, criterion, metrics, data_loaders_val, device, epoch, log_writer=log_writer, args=args) + val_epoch_time = time.time() - val_epoch_start + + val_best = val_stats[bestmetric] + + # Save best of all + if val_best <= best_so_far: + best_so_far = val_best + misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, best_so_far=best_so_far, fname='best') + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + **{f'val_{k}': v for k, v in val_stats.items()}} + else: + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch,} + + if args.distributed: dist.barrier() + + # Save stuff + if args.output_dir and ((epoch+1) % args.save_every == 0 or epoch + 1 == args.epochs): + misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, best_so_far=best_so_far, fname='last') + + if args.output_dir: + if log_writer is not None: + log_writer.flush() + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/croco/utils/misc.py b/imcui/third_party/mast3r/dust3r/croco/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..132e102a662c987dce5282633cb8730b0e0d5c2d --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/croco/utils/misc.py @@ -0,0 +1,463 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +import math +import json +from collections import defaultdict, deque +from pathlib import Path +import numpy as np + +import torch +import torch.distributed as dist +from torch import inf + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, max_iter=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable) + space_fmt = ':' + str(len(str(len_iterable))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for it,obj in enumerate(iterable): + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len_iterable - 1: + eta_seconds = iter_time.global_avg * (len_iterable - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len_iterable, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len_iterable, eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + if max_iter and it >= max_iter: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len_iterable)) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or 
(get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + nodist = args.nodist if hasattr(args,'nodist') else False + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ and not nodist: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self, enabled=True): + self._scaler = torch.cuda.amp.GradScaler(enabled=enabled) + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) 
+ device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + + + +def save_model(args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None): + output_dir = Path(args.output_dir) + if fname is None: fname = str(epoch) + checkpoint_path = output_dir / ('checkpoint-%s.pth' % fname) + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': loss_scaler.state_dict(), + 'args': args, + 'epoch': epoch, + } + if best_so_far is not None: to_save['best_so_far'] = best_so_far + print(f'>> Saving model to {checkpoint_path} ...') + save_on_master(to_save, checkpoint_path) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + args.start_epoch = 0 + best_so_far = None + if args.resume is not None: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + print("Resume checkpoint %s" % args.resume) + model_without_ddp.load_state_dict(checkpoint['model'], strict=False) + args.start_epoch = checkpoint['epoch'] + 1 + optimizer.load_state_dict(checkpoint['optimizer']) + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + if 'best_so_far' in checkpoint: + best_so_far = checkpoint['best_so_far'] + print(" & best_so_far={:g}".format(best_so_far)) + else: + print("") + print("With optim & sched! start_epoch={:d}".format(args.start_epoch), end='') + return best_so_far + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + +def _replace(text, src, tgt, rm=''): + """ Advanced string replacement. + Given a text: + - replace all elements in src by the corresponding element in tgt + - remove all elements in rm + """ + if len(tgt) == 1: + tgt = tgt * len(src) + assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len" + for s,t in zip(src, tgt): + text = text.replace(s,t) + for c in rm: + text = text.replace(c,'') + return text + +def filename( obj ): + """ transform a python obj or cmd into a proper filename. 
+ - \1 gets replaced by slash '/' + - \2 gets replaced by comma ',' + """ + if not isinstance(obj, str): + obj = repr(obj) + obj = str(obj).replace('()','') + obj = _replace(obj, '_,(*/\1\2','-__x%/,', rm=' )\'"') + assert all(len(s) < 256 for s in obj.split(os.sep)), 'filename too long (>256 characters):\n'+obj + return obj + +def _get_num_layer_for_vit(var_name, enc_depth, dec_depth): + if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("enc_blocks"): + layer_id = int(var_name.split('.')[1]) + return layer_id + 1 + elif var_name.startswith('decoder_embed') or var_name.startswith('enc_norm'): # part of the last black + return enc_depth + elif var_name.startswith('dec_blocks'): + layer_id = int(var_name.split('.')[1]) + return enc_depth + layer_id + 1 + elif var_name.startswith('dec_norm'): # part of the last block + return enc_depth + dec_depth + elif any(var_name.startswith(k) for k in ['head','prediction_head']): + return enc_depth + dec_depth + 1 + else: + raise NotImplementedError(var_name) + +def get_parameter_groups(model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]): + parameter_group_names = {} + parameter_group_vars = {} + enc_depth, dec_depth = None, None + # prepare layer decay values + assert layer_decay==1.0 or 0.> wrote {fpath}') + + print(f'Loaded {len(list_subscenes)} sub-scenes') + + # separate scenes + list_scenes = defaultdict(list) + for scene in list_subscenes: + scene, id = os.path.split(scene) + list_scenes[scene].append(id) + + list_scenes = list(list_scenes.items()) + print(f'from {len(list_scenes)} scenes in total') + + np.random.shuffle(list_scenes) + train_scenes = list_scenes[len(list_scenes)//10:] + val_scenes = list_scenes[:len(list_scenes)//10] + + def write_scene_list(scenes, n, fpath): + sub_scenes = [os.path.join(scene, id) for scene, ids in scenes for id in ids] + np.random.shuffle(sub_scenes) + + if len(sub_scenes) < n: + return + + with open(fpath, 'w') as f: + f.write('\n'.join(sub_scenes[:n])) + print(f'>> wrote {fpath}') + + for n in n_scenes: + write_scene_list(train_scenes, n, os.path.join(habitat_root, f'Habitat_{n}_scenes_train.txt')) + write_scene_list(val_scenes, n//10, os.path.join(habitat_root, f'Habitat_{n//10}_scenes_val.txt')) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--root", required=True) + parser.add_argument("--n_scenes", nargs='+', default=[1_000, 1_000_000], type=int) + + args = parser.parse_args() + find_all_scenes(args.root, args.n_scenes) diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/__init__.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
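The scene listing script above derives its train/val split at the scene level rather than the sub-scene level: sub-scenes are grouped by their parent scene, the scene list is shuffled, the first tenth becomes validation and the rest training, and only then are the Habitat_{n}_scenes_*.txt lists written. A condensed sketch of that grouping and split, using made-up scene names:

import os
from collections import defaultdict
import numpy as np

subscenes = [f'scene{i:02d}/00000' for i in range(20)]   # toy sub-scene paths

by_scene = defaultdict(list)
for sub in subscenes:
    scene, idx = os.path.split(sub)
    by_scene[scene].append(idx)

scene_list = list(by_scene.items())
np.random.shuffle(scene_list)
val_scenes = scene_list[:len(scene_list) // 10]    # first 10% of scenes -> val
train_scenes = scene_list[len(scene_list) // 10:]  # remaining 90% -> train

train_subscenes = [os.path.join(scene, idx) for scene, ids in train_scenes for idx in ids]
print(len(train_scenes), len(val_scenes), len(train_subscenes))   # 18 2 18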
diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..4a31f1174a234b900ecaa76705fa271baf8a5669 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py @@ -0,0 +1,170 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Render environment maps from 3D meshes using the Habitat Sim simulator. +# -------------------------------------------------------- +import numpy as np +import habitat_sim +import math +from habitat_renderer import projections + +# OpenCV to habitat camera convention transformation +R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0) + +CUBEMAP_FACE_LABELS = ["left", "front", "right", "back", "up", "down"] +# Expressed while considering Habitat coordinates systems +CUBEMAP_FACE_ORIENTATIONS_ROTVEC = [ + [0, math.pi / 2, 0], # Left + [0, 0, 0], # Front + [0, - math.pi / 2, 0], # Right + [0, math.pi, 0], # Back + [math.pi / 2, 0, 0], # Up + [-math.pi / 2, 0, 0],] # Down + +class NoNaviguableSpaceError(RuntimeError): + def __init__(self, *args): + super().__init__(*args) + +class HabitatEnvironmentMapRenderer: + def __init__(self, + scene, + navmesh, + scene_dataset_config_file, + render_equirectangular=False, + equirectangular_resolution=(512, 1024), + render_cubemap=False, + cubemap_resolution=(512, 512), + render_depth=False, + gpu_id=0): + self.scene = scene + self.navmesh = navmesh + self.scene_dataset_config_file = scene_dataset_config_file + self.gpu_id = gpu_id + + self.render_equirectangular = render_equirectangular + self.equirectangular_resolution = equirectangular_resolution + self.equirectangular_projection = projections.EquirectangularProjection(*equirectangular_resolution) + # 3D unit ray associated to each pixel of the equirectangular map + equirectangular_rays = projections.get_projection_rays(self.equirectangular_projection) + # Not needed, but just in case. + equirectangular_rays /= np.linalg.norm(equirectangular_rays, axis=-1, keepdims=True) + # Depth map created by Habitat are produced by warping a cubemap, + # so the values do not correspond to distance to the center and need some scaling. 
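# Note on the scaling below: for a unit ray r, the warped-cubemap depth stores
# the z-depth along the cube face the ray falls on, i.e. distance * max(|r_x|, |r_y|, |r_z|).
# Multiplying by 1 / max(|r_x|, |r_y|, |r_z|) therefore converts that depth back
# into the Euclidean distance from the viewpoint along the ray.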
+ self.equirectangular_depth_scale_factors = 1.0 / np.max(np.abs(equirectangular_rays), axis=-1) + + self.render_cubemap = render_cubemap + self.cubemap_resolution = cubemap_resolution + + self.render_depth = render_depth + + self.seed = None + self._lazy_initialization() + + def _lazy_initialization(self): + # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly + if self.seed == None: + # Re-seed numpy generator + np.random.seed() + self.seed = np.random.randint(2**32-1) + sim_cfg = habitat_sim.SimulatorConfiguration() + sim_cfg.scene_id = self.scene + if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "": + sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file + sim_cfg.random_seed = self.seed + sim_cfg.load_semantic_mesh = False + sim_cfg.gpu_device_id = self.gpu_id + + sensor_specifications = [] + + # Add cubemaps + if self.render_cubemap: + for face_id, orientation in enumerate(CUBEMAP_FACE_ORIENTATIONS_ROTVEC): + rgb_sensor_spec = habitat_sim.CameraSensorSpec() + rgb_sensor_spec.uuid = f"color_cubemap_{CUBEMAP_FACE_LABELS[face_id]}" + rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR + rgb_sensor_spec.resolution = self.cubemap_resolution + rgb_sensor_spec.hfov = 90 + rgb_sensor_spec.position = [0.0, 0.0, 0.0] + rgb_sensor_spec.orientation = orientation + sensor_specifications.append(rgb_sensor_spec) + + if self.render_depth: + depth_sensor_spec = habitat_sim.CameraSensorSpec() + depth_sensor_spec.uuid = f"depth_cubemap_{CUBEMAP_FACE_LABELS[face_id]}" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.cubemap_resolution + depth_sensor_spec.hfov = 90 + depth_sensor_spec.position = [0.0, 0.0, 0.0] + depth_sensor_spec.orientation = orientation + sensor_specifications.append(depth_sensor_spec) + + # Add equirectangular map + if self.render_equirectangular: + rgb_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec() + rgb_sensor_spec.uuid = "color_equirectangular" + rgb_sensor_spec.resolution = self.equirectangular_resolution + rgb_sensor_spec.position = [0.0, 0.0, 0.0] + sensor_specifications.append(rgb_sensor_spec) + + if self.render_depth: + depth_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec() + depth_sensor_spec.uuid = "depth_equirectangular" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.equirectangular_resolution + depth_sensor_spec.position = [0.0, 0.0, 0.0] + depth_sensor_spec.orientation + sensor_specifications.append(depth_sensor_spec) + + agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=sensor_specifications) + + cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg]) + self.sim = habitat_sim.Simulator(cfg) + if self.navmesh is not None and self.navmesh != "": + # Use pre-computed navmesh (the one generated automatically does some weird stuffs like going on top of the roof) + # See https://youtu.be/kunFMRJAu2U?t=1522 regarding navmeshes + self.sim.pathfinder.load_nav_mesh(self.navmesh) + + # Check that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + # Try to compute a navmesh + navmesh_settings = habitat_sim.NavMeshSettings() + navmesh_settings.set_defaults() + self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True) + + # Check that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})") + 
+ self.agent = self.sim.initialize_agent(agent_id=0) + + def close(self): + if hasattr(self, 'sim'): + self.sim.close() + + def __del__(self): + self.close() + + def render_viewpoint(self, viewpoint_position): + agent_state = habitat_sim.AgentState() + agent_state.position = viewpoint_position + # agent_state.rotation = viewpoint_orientation + self.agent.set_state(agent_state) + viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0) + + try: + # Depth map values have been obtained using cubemap rendering internally, + # so they do not really correspond to distance to the viewpoint in practice + # and they need some scaling + viewpoint_observations["depth_equirectangular"] *= self.equirectangular_depth_scale_factors + except KeyError: + pass + + data = dict(observations=viewpoint_observations, position=viewpoint_position) + return data + + def up_direction(self): + return np.asarray(habitat_sim.geo.UP).tolist() + + def R_cam_to_world(self): + return R_OPENCV2HABITAT.tolist() diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b86238b44a5cdd7a2e30b9d64773c2388f9711c3 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py @@ -0,0 +1,93 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Generate pairs of crops from a dataset of environment maps. +# -------------------------------------------------------- +import os +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # noqa +import cv2 +import collections +from habitat_renderer import projections, projections_conversions +from habitat_renderer.habitat_sim_envmaps_renderer import HabitatEnvironmentMapRenderer + +ViewpointData = collections.namedtuple("ViewpointData", ["colormap", "distancemap", "pointmap", "position"]) + +class HabitatMultiviewCrops: + def __init__(self, + scene, + navmesh, + scene_dataset_config_file, + equirectangular_resolution=(400, 800), + crop_resolution=(240, 320), + pixel_jittering_iterations=5, + jittering_noise_level=1.0): + self.crop_resolution = crop_resolution + + self.pixel_jittering_iterations = pixel_jittering_iterations + self.jittering_noise_level = jittering_noise_level + + # Instanciate the low resolution habitat sim renderer + self.lowres_envmap_renderer = HabitatEnvironmentMapRenderer(scene=scene, + navmesh=navmesh, + scene_dataset_config_file=scene_dataset_config_file, + equirectangular_resolution=equirectangular_resolution, + render_depth=True, + render_equirectangular=True) + self.R_cam_to_world = np.asarray(self.lowres_envmap_renderer.R_cam_to_world()) + self.up_direction = np.asarray(self.lowres_envmap_renderer.up_direction()) + + # Projection applied by each environment map + self.envmap_height, self.envmap_width = self.lowres_envmap_renderer.equirectangular_resolution + base_projection = projections.EquirectangularProjection(self.envmap_height, self.envmap_width) + self.envmap_projection = projections.RotatedProjection(base_projection, self.R_cam_to_world.T) + # 3D Rays map associated to each envmap + self.envmap_rays = projections.get_projection_rays(self.envmap_projection) + + def compute_pointmap(self, distancemap, 
position): + # Point cloud associated to each ray + return self.envmap_rays * distancemap[:, :, None] + position + + def render_viewpoint_data(self, position): + data = self.lowres_envmap_renderer.render_viewpoint(np.asarray(position)) + colormap = data['observations']['color_equirectangular'][..., :3] # Ignore the alpha channel + distancemap = data['observations']['depth_equirectangular'] + pointmap = self.compute_pointmap(distancemap, position) + return ViewpointData(colormap=colormap, distancemap=distancemap, pointmap=pointmap, position=position) + + def extract_cropped_camera(self, projection, color_image, distancemap, pointmap, voxelmap=None): + remapper = projections_conversions.RemapProjection(input_projection=self.envmap_projection, output_projection=projection, + pixel_jittering_iterations=self.pixel_jittering_iterations, jittering_noise_level=self.jittering_noise_level) + cropped_color_image = remapper.convert( + color_image, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False) + cropped_distancemap = remapper.convert( + distancemap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, single_map=True) + cropped_pointmap = remapper.convert(pointmap, interpolation=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_WRAP, single_map=True) + cropped_voxelmap = (None if voxelmap is None else + remapper.convert(voxelmap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, single_map=True)) + # Convert the distance map into a depth map + cropped_depthmap = np.asarray( + cropped_distancemap / np.linalg.norm(remapper.output_rays, axis=-1), dtype=cropped_distancemap.dtype) + + return cropped_color_image, cropped_depthmap, cropped_pointmap, cropped_voxelmap + +def perspective_projection_to_dict(persp_projection, position): + """ + Serialization-like function.""" + camera_params = dict(camera_intrinsics=projections.colmap_to_opencv_intrinsics(persp_projection.base_projection.K).tolist(), + size=(persp_projection.base_projection.width, persp_projection.base_projection.height), + R_cam2world=persp_projection.R_to_base_projection.T.tolist(), + t_cam2world=position) + return camera_params + + +def dict_to_perspective_projection(camera_params): + K = projections.opencv_to_colmap_intrinsics(np.asarray(camera_params["camera_intrinsics"])) + size = camera_params["size"] + R_cam2world = np.asarray(camera_params["R_cam2world"]) + projection = projections.PerspectiveProjection(K, height=size[1], width=size[0]) + projection = projections.RotatedProjection(projection, R_to_base_projection=R_cam2world.T) + position = camera_params["t_cam2world"] + return projection, position \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections.py new file mode 100644 index 0000000000000000000000000000000000000000..4db1f79d23e23a8ba144b4357c4d4daf10cf8fab --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections.py @@ -0,0 +1,151 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Various 3D/2D projection utils, useful to sample virtual cameras. 
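compute_pointmap above lifts an environment map of per-pixel distances into a 3D point map by scaling each pixel's unit ray by its distance and adding the camera position. A minimal numpy sketch of that formula on synthetic inputs; the shapes and values are invented for illustration.

import numpy as np

H, W = 4, 8
rays = np.random.randn(H, W, 3).astype(np.float32)
rays /= np.linalg.norm(rays, axis=-1, keepdims=True)      # unit rays, one per pixel
distance = np.full((H, W), 2.5, dtype=np.float32)         # per-pixel distance to the surface
position = np.array([1.0, 0.0, -3.0], dtype=np.float32)   # camera centre in world coordinates

pointmap = rays * distance[:, :, None] + position         # same formula as compute_pointmap
assert np.allclose(np.linalg.norm(pointmap - position, axis=-1), 2.5, atol=1e-5)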
+# -------------------------------------------------------- +import numpy as np + +class EquirectangularProjection: + """ + Convention for the central pixel of the equirectangular map similar to OpenCV perspective model: + +X from left to right + +Y from top to bottom + +Z going outside the camera + EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5)) + """ + + def __init__(self, height, width): + self.height = height + self.width = width + self.u_scaling = (2 * np.pi) / self.width + self.v_scaling = np.pi / self.height + + def unproject(self, u, v): + """ + Args: + u, v: 2D coordinates + Returns: + unnormalized 3D rays. + """ + longitude = self.u_scaling * u - np.pi + minus_latitude = self.v_scaling * v - np.pi/2 + + cos_latitude = np.cos(minus_latitude) + x, z = np.sin(longitude) * cos_latitude, np.cos(longitude) * cos_latitude + y = np.sin(minus_latitude) + + rays = np.stack([x, y, z], axis=-1) + return rays + + def project(self, rays): + """ + Args: + rays: Bx3 array of 3D rays. + Returns: + u, v: tuple of 2D coordinates. + """ + rays = rays / np.linalg.norm(rays, axis=-1, keepdims=True) + x, y, z = [rays[..., i] for i in range(3)] + + longitude = np.arctan2(x, z) + minus_latitude = np.arcsin(y) + + u = (longitude + np.pi) * (1.0 / self.u_scaling) + v = (minus_latitude + np.pi/2) * (1.0 / self.v_scaling) + return u, v + + +class PerspectiveProjection: + """ + OpenCV convention: + World space: + +X from left to right + +Y from top to bottom + +Z going outside the camera + Pixel space: + +u from left to right + +v from top to bottom + EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5)). + """ + + def __init__(self, K, height, width): + self.height = height + self.width = width + self.K = K + self.Kinv = np.linalg.inv(K) + + def project(self, rays): + uv_homogeneous = np.einsum("ik, ...k -> ...i", self.K, rays) + uv = uv_homogeneous[..., :2] / uv_homogeneous[..., 2, None] + return uv[..., 0], uv[..., 1] + + def unproject(self, u, v): + uv_homogeneous = np.stack((u, v, np.ones_like(u)), axis=-1) + rays = np.einsum("ik, ...k -> ...i", self.Kinv, uv_homogeneous) + return rays + + +class RotatedProjection: + def __init__(self, base_projection, R_to_base_projection): + self.base_projection = base_projection + self.R_to_base_projection = R_to_base_projection + + @property + def width(self): + return self.base_projection.width + + @property + def height(self): + return self.base_projection.height + + def project(self, rays): + if self.R_to_base_projection is not None: + rays = np.einsum("ik, ...k -> ...i", self.R_to_base_projection, rays) + return self.base_projection.project(rays) + + def unproject(self, u, v): + rays = self.base_projection.unproject(u, v) + if self.R_to_base_projection is not None: + rays = np.einsum("ik, ...k -> ...i", self.R_to_base_projection.T, rays) + return rays + +def get_projection_rays(projection, noise_level=0): + """ + Return a 2D map of 3D rays corresponding to the projection. + If noise_level > 0, add some jittering noise to these rays. 
+ """ + grid_u, grid_v = np.meshgrid(0.5 + np.arange(projection.width), 0.5 + np.arange(projection.height)) + if noise_level > 0: + grid_u += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_u.shape), projection.width) + grid_v += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_v.shape), projection.height) + return projection.unproject(grid_u, grid_v) + +def compute_camera_intrinsics(height, width, hfov): + f = width/2 / np.tan(hfov/2 * np.pi/180) + cu, cv = width/2, height/2 + return f, cu, cv + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + return K + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcfed4066bbac62fa4254ea6417bf429b098b75 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py @@ -0,0 +1,45 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Remap data from one projection to an other +# -------------------------------------------------------- +import numpy as np +import cv2 +from habitat_renderer import projections + +class RemapProjection: + def __init__(self, input_projection, output_projection, pixel_jittering_iterations=0, jittering_noise_level=0): + """ + Some naive random jittering can be introduced in the remapping to mitigate aliasing artecfacts. 
+ """ + assert jittering_noise_level >= 0 + assert pixel_jittering_iterations >= 0 + + maps = [] + # Initial map + self.output_rays = projections.get_projection_rays(output_projection) + map_u, map_v = input_projection.project(self.output_rays) + map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) + maps.append((map_u, map_v)) + + for _ in range(pixel_jittering_iterations): + # Define multiple mappings using some coordinates jittering to mitigate aliasing effects + crop_rays = projections.get_projection_rays(output_projection, jittering_noise_level) + map_u, map_v = input_projection.project(crop_rays) + map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) + maps.append((map_u, map_v)) + self.maps = maps + + def convert(self, img, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False): + remapped = [] + for map_u, map_v in self.maps: + res = cv2.remap(img, map_u, map_v, interpolation=interpolation, borderMode=borderMode) + remapped.append(res) + if single_map: + break + if len(remapped) == 1: + res = remapped[0] + else: + res = np.asarray(np.mean(remapped, axis=0), dtype=img.dtype) + return res diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/preprocess_habitat.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/preprocess_habitat.py new file mode 100644 index 0000000000000000000000000000000000000000..cacbe2467a8e9629c2472b0e05fc0cf8326367e2 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/habitat/preprocess_habitat.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# main executable for preprocessing habitat +# export METADATA_DIR="/path/to/habitat/5views_v1_512x512_metadata" +# export SCENES_DIR="/path/to/habitat/data/scene_datasets/" +# export OUTPUT_DIR="data/habitat_processed" +# export PYTHONPATH=$(pwd) +# python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16 +# -------------------------------------------------------- +import os +import glob +import json +import os + +import PIL.Image +import json +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # noqa +import cv2 +from habitat_renderer import multiview_crop_generator +from tqdm import tqdm + + +def preprocess_metadata(metadata_filename, + scenes_dir, + output_dir, + crop_resolution=[512, 512], + equirectangular_resolution=None, + fix_existing_dataset=False): + # Load data + with open(metadata_filename, "r") as f: + metadata = json.load(f) + + if metadata["scene_dataset_config_file"] == "": + scene = os.path.join(scenes_dir, metadata["scene"]) + scene_dataset_config_file = "" + else: + scene = metadata["scene"] + scene_dataset_config_file = os.path.join(scenes_dir, metadata["scene_dataset_config_file"]) + navmesh = None + + # Use 4 times the crop size as resolution for rendering the environment map. + max_res = max(crop_resolution) + + if equirectangular_resolution == None: + # Use 4 times the crop size as resolution for rendering the environment map. 
+ max_res = max(crop_resolution) + equirectangular_resolution = (4*max_res, 8*max_res) + + print("equirectangular_resolution:", equirectangular_resolution) + + if os.path.exists(output_dir) and not fix_existing_dataset: + raise FileExistsError(output_dir) + + # Lazy initialization + highres_dataset = None + + for batch_label, batch in tqdm(metadata["view_batches"].items()): + for view_label, view_params in batch.items(): + + assert view_params["size"] == crop_resolution + label = f"{batch_label}_{view_label}" + + output_camera_params_filename = os.path.join(output_dir, f"{label}_camera_params.json") + if fix_existing_dataset and os.path.isfile(output_camera_params_filename): + # Skip generation if we are fixing a dataset and the corresponding output file already exists + continue + + # Lazy initialization + if highres_dataset is None: + highres_dataset = multiview_crop_generator.HabitatMultiviewCrops(scene=scene, + navmesh=navmesh, + scene_dataset_config_file=scene_dataset_config_file, + equirectangular_resolution=equirectangular_resolution, + crop_resolution=crop_resolution,) + os.makedirs(output_dir, exist_ok=bool(fix_existing_dataset)) + + # Generate a higher resolution crop + original_projection, position = multiview_crop_generator.dict_to_perspective_projection(view_params) + # Render an envmap at the given position + viewpoint_data = highres_dataset.render_viewpoint_data(position) + + projection = original_projection + colormap, depthmap, pointmap, _ = highres_dataset.extract_cropped_camera( + projection, viewpoint_data.colormap, viewpoint_data.distancemap, viewpoint_data.pointmap) + + camera_params = multiview_crop_generator.perspective_projection_to_dict(projection, position) + + # Color image + PIL.Image.fromarray(colormap).save(os.path.join(output_dir, f"{label}.jpeg")) + # Depth image + cv2.imwrite(os.path.join(output_dir, f"{label}_depth.exr"), + depthmap, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + # Camera parameters + with open(output_camera_params_filename, "w") as f: + json.dump(camera_params, f) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_dir", required=True) + parser.add_argument("--scenes_dir", required=True) + parser.add_argument("--output_dir", required=True) + parser.add_argument("--metadata_filename", default="") + + args = parser.parse_args() + + if args.metadata_filename == "": + # Walk through the metadata dir to generate commandlines + for filename in glob.iglob(os.path.join(args.metadata_dir, "**/metadata.json"), recursive=True): + output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(filename), args.metadata_dir)) + if not os.path.exists(output_dir): + commandline = f"python {__file__} --metadata_filename={filename} --metadata_dir={args.metadata_dir} --scenes_dir={args.scenes_dir} --output_dir={output_dir}" + print(commandline) + else: + preprocess_metadata(metadata_filename=args.metadata_filename, + scenes_dir=args.scenes_dir, + output_dir=args.output_dir) diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/path_to_root.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/path_to_root.py new file mode 100644 index 0000000000000000000000000000000000000000..6e076a17a408d0a9e043fbda2d73f1592e7cb71a --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/path_to_root.py @@ -0,0 +1,13 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
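# (Editorial sketch, not part of the diff.) The habitat preprocessing script above stores depth
# as half-float EXR through OpenCV; this is a standalone version of that write/read round trip,
# with a hypothetical file name. OPENCV_IO_ENABLE_OPENEXR must be set before cv2 is imported.
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2
import numpy as np

depthmap = (np.random.rand(384, 512) * 10).astype(np.float32)
cv2.imwrite("example_depth.exr", depthmap, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
restored = cv2.imread("example_depth.exr", cv2.IMREAD_UNCHANGED)
# half precision keeps roughly 3 significant decimal digits, enough for metric depth in metres
assert np.allclose(depthmap, restored, rtol=1e-3, atol=1e-3)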
+# +# -------------------------------------------------------- +# DUSt3R repo root import +# -------------------------------------------------------- + +import sys +import os.path as path +HERE_PATH = path.normpath(path.dirname(__file__)) +DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../')) +# workaround for sibling import +sys.path.insert(0, DUST3R_REPO_PATH) diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_arkitscenes.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_arkitscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..5dbc103a82d646293e1d81f5132683e2b08cd879 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_arkitscenes.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to pre-process the arkitscenes dataset. +# Usage: +# python3 datasets_preprocess/preprocess_arkitscenes.py --arkitscenes_dir /path/to/arkitscenes --precomputed_pairs /path/to/arkitscenes_pairs +# -------------------------------------------------------- +import os +import json +import os.path as osp +import decimal +import argparse +import math +from bisect import bisect_left +from PIL import Image +import numpy as np +import quaternion +from scipy import interpolate +import cv2 + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--arkitscenes_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/arkitscenes_processed') + return parser + + +def value_to_decimal(value, decimal_places): + decimal.getcontext().rounding = decimal.ROUND_HALF_UP # define rounding method + return decimal.Decimal(str(float(value))).quantize(decimal.Decimal('1e-{}'.format(decimal_places))) + + +def closest(value, sorted_list): + index = bisect_left(sorted_list, value) + if index == 0: + return sorted_list[0] + elif index == len(sorted_list): + return sorted_list[-1] + else: + value_before = sorted_list[index - 1] + value_after = sorted_list[index] + if value_after - value < value - value_before: + return value_after + else: + return value_before + + +def get_up_vectors(pose_device_to_world): + return np.matmul(pose_device_to_world, np.array([[0.0], [-1.0], [0.0], [0.0]])) + + +def get_right_vectors(pose_device_to_world): + return np.matmul(pose_device_to_world, np.array([[1.0], [0.0], [0.0], [0.0]])) + + +def read_traj(traj_path): + quaternions = [] + poses = [] + timestamps = [] + poses_p_to_w = [] + with open(traj_path) as f: + traj_lines = f.readlines() + for line in traj_lines: + tokens = line.split() + assert len(tokens) == 7 + traj_timestamp = float(tokens[0]) + + timestamps_decimal_value = value_to_decimal(traj_timestamp, 3) + timestamps.append(float(timestamps_decimal_value)) # for spline interpolation + + angle_axis = [float(tokens[1]), float(tokens[2]), float(tokens[3])] + r_w_to_p, _ = cv2.Rodrigues(np.asarray(angle_axis)) + t_w_to_p = np.asarray([float(tokens[4]), float(tokens[5]), float(tokens[6])]) + + pose_w_to_p = np.eye(4) + pose_w_to_p[:3, :3] = r_w_to_p + pose_w_to_p[:3, 3] = t_w_to_p + + pose_p_to_w = np.linalg.inv(pose_w_to_p) + + r_p_to_w_as_quat = quaternion.from_rotation_matrix(pose_p_to_w[:3, :3]) + t_p_to_w = pose_p_to_w[:3, 3] + poses_p_to_w.append(pose_p_to_w) + poses.append(t_p_to_w) + 
quaternions.append(r_p_to_w_as_quat) + return timestamps, poses, quaternions, poses_p_to_w + + +def main(rootdir, pairsdir, outdir): + os.makedirs(outdir, exist_ok=True) + + subdirs = ['Test', 'Training'] + for subdir in subdirs: + if not osp.isdir(osp.join(rootdir, subdir)): + continue + # STEP 1: list all scenes + outsubdir = osp.join(outdir, subdir) + os.makedirs(outsubdir, exist_ok=True) + listfile = osp.join(pairsdir, subdir, 'scene_list.json') + with open(listfile, 'r') as f: + scene_dirs = json.load(f) + + valid_scenes = [] + for scene_subdir in scene_dirs: + out_scene_subdir = osp.join(outsubdir, scene_subdir) + os.makedirs(out_scene_subdir, exist_ok=True) + + scene_dir = osp.join(rootdir, subdir, scene_subdir) + depth_dir = osp.join(scene_dir, 'lowres_depth') + rgb_dir = osp.join(scene_dir, 'vga_wide') + intrinsics_dir = osp.join(scene_dir, 'vga_wide_intrinsics') + traj_path = osp.join(scene_dir, 'lowres_wide.traj') + + # STEP 2: read selected_pairs.npz + selected_pairs_path = osp.join(pairsdir, subdir, scene_subdir, 'selected_pairs.npz') + selected_npz = np.load(selected_pairs_path) + selection, pairs = selected_npz['selection'], selected_npz['pairs'] + selected_sky_direction_scene = str(selected_npz['sky_direction_scene'][0]) + if len(selection) == 0 or len(pairs) == 0: + # not a valid scene + continue + valid_scenes.append(scene_subdir) + + # STEP 3: parse the scene and export the list of valid (K, pose, rgb, depth) and convert images + scene_metadata_path = osp.join(out_scene_subdir, 'scene_metadata.npz') + if osp.isfile(scene_metadata_path): + continue + else: + print(f'parsing {scene_subdir}') + # loads traj + timestamps, poses, quaternions, poses_cam_to_world = read_traj(traj_path) + + poses = np.array(poses) + quaternions = np.array(quaternions, dtype=np.quaternion) + quaternions = quaternion.unflip_rotors(quaternions) + timestamps = np.array(timestamps) + + selected_images = [(basename, basename.split(".png")[0].split("_")[1]) for basename in selection] + timestamps_selected = [float(frame_id) for _, frame_id in selected_images] + + sky_direction_scene, trajectories, intrinsics, images = convert_scene_metadata(scene_subdir, + intrinsics_dir, + timestamps, + quaternions, + poses, + poses_cam_to_world, + selected_images, + timestamps_selected) + assert selected_sky_direction_scene == sky_direction_scene + + os.makedirs(os.path.join(out_scene_subdir, 'vga_wide'), exist_ok=True) + os.makedirs(os.path.join(out_scene_subdir, 'lowres_depth'), exist_ok=True) + assert isinstance(sky_direction_scene, str) + for basename in images: + img_out = os.path.join(out_scene_subdir, 'vga_wide', basename.replace('.png', '.jpg')) + depth_out = os.path.join(out_scene_subdir, 'lowres_depth', basename) + if osp.isfile(img_out) and osp.isfile(depth_out): + continue + + vga_wide_path = osp.join(rgb_dir, basename) + depth_path = osp.join(depth_dir, basename) + + img = Image.open(vga_wide_path) + depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) + + # rotate the image + if sky_direction_scene == 'RIGHT': + try: + img = img.transpose(Image.Transpose.ROTATE_90) + except Exception: + img = img.transpose(Image.ROTATE_90) + depth = cv2.rotate(depth, cv2.ROTATE_90_COUNTERCLOCKWISE) + elif sky_direction_scene == 'LEFT': + try: + img = img.transpose(Image.Transpose.ROTATE_270) + except Exception: + img = img.transpose(Image.ROTATE_270) + depth = cv2.rotate(depth, cv2.ROTATE_90_CLOCKWISE) + elif sky_direction_scene == 'DOWN': + try: + img = img.transpose(Image.Transpose.ROTATE_180) + except Exception: + 
img = img.transpose(Image.ROTATE_180) + depth = cv2.rotate(depth, cv2.ROTATE_180) + + W, H = img.size + if not osp.isfile(img_out): + img.save(img_out) + + depth = cv2.resize(depth, (W, H), interpolation=cv2.INTER_NEAREST_EXACT) + if not osp.isfile(depth_out): # avoid destroying the base dataset when you mess up the paths + cv2.imwrite(depth_out, depth) + + # save at the end + np.savez(scene_metadata_path, + trajectories=trajectories, + intrinsics=intrinsics, + images=images, + pairs=pairs) + + outlistfile = osp.join(outsubdir, 'scene_list.json') + with open(outlistfile, 'w') as f: + json.dump(valid_scenes, f) + + # STEP 5: concat all scene_metadata.npz into a single file + scene_data = {} + for scene_subdir in valid_scenes: + scene_metadata_path = osp.join(outsubdir, scene_subdir, 'scene_metadata.npz') + with np.load(scene_metadata_path) as data: + trajectories = data['trajectories'] + intrinsics = data['intrinsics'] + images = data['images'] + pairs = data['pairs'] + scene_data[scene_subdir] = {'trajectories': trajectories, + 'intrinsics': intrinsics, + 'images': images, + 'pairs': pairs} + offset = 0 + counts = [] + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + pairs = [] + for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()): + num_imgs = data['images'].shape[0] + img_pairs = data['pairs'] + + scenes.append(scene_subdir) + sceneids.extend([scene_idx] * num_imgs) + + images.append(data['images']) + + K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0) + K[:, 0, 0] = [fx for _, _, fx, _, _, _ in data['intrinsics']] + K[:, 1, 1] = [fy for _, _, _, fy, _, _ in data['intrinsics']] + K[:, 0, 2] = [hw for _, _, _, _, hw, _ in data['intrinsics']] + K[:, 1, 2] = [hh for _, _, _, _, _, hh in data['intrinsics']] + + intrinsics.append(K) + trajectories.append(data['trajectories']) + + # offset pairs + img_pairs[:, 0:2] += offset + pairs.append(img_pairs) + counts.append(offset) + + offset += num_imgs + + images = np.concatenate(images, axis=0) + intrinsics = np.concatenate(intrinsics, axis=0) + trajectories = np.concatenate(trajectories, axis=0) + pairs = np.concatenate(pairs, axis=0) + np.savez(osp.join(outsubdir, 'all_metadata.npz'), + counts=counts, + scenes=scenes, + sceneids=sceneids, + images=images, + intrinsics=intrinsics, + trajectories=trajectories, + pairs=pairs) + + +def convert_scene_metadata(scene_subdir, intrinsics_dir, + timestamps, quaternions, poses, poses_cam_to_world, + selected_images, timestamps_selected): + # find scene orientation + sky_direction_scene, rotated_to_cam = find_scene_orientation(poses_cam_to_world) + + # find/compute pose for selected timestamps + # most images have a valid timestamp / exact pose associated + timestamps_selected = np.array(timestamps_selected) + spline = interpolate.interp1d(timestamps, poses, kind='linear', axis=0) + interpolated_rotations = quaternion.squad(quaternions, timestamps, timestamps_selected) + interpolated_positions = spline(timestamps_selected) + + trajectories = [] + intrinsics = [] + images = [] + for i, (basename, frame_id) in enumerate(selected_images): + intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{frame_id}.pincam") + if not osp.exists(intrinsic_fn): + intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{float(frame_id) - 0.001:.3f}.pincam") + if not osp.exists(intrinsic_fn): + intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{float(frame_id) + 0.001:.3f}.pincam") + assert osp.exists(intrinsic_fn) + w, h, fx, fy, hw, hh = np.loadtxt(intrinsic_fn) 
# PINHOLE + + pose = np.eye(4) + pose[:3, :3] = quaternion.as_rotation_matrix(interpolated_rotations[i]) + pose[:3, 3] = interpolated_positions[i] + + images.append(basename) + if sky_direction_scene == 'RIGHT' or sky_direction_scene == 'LEFT': + intrinsics.append([h, w, fy, fx, hh, hw]) # swapped intrinsics + else: + intrinsics.append([w, h, fx, fy, hw, hh]) + trajectories.append(pose @ rotated_to_cam) # pose_cam_to_world @ rotated_to_cam = rotated(cam) to world + + return sky_direction_scene, trajectories, intrinsics, images + + +def find_scene_orientation(poses_cam_to_world): + if len(poses_cam_to_world) > 0: + up_vector = sum(get_up_vectors(p) for p in poses_cam_to_world) / len(poses_cam_to_world) + right_vector = sum(get_right_vectors(p) for p in poses_cam_to_world) / len(poses_cam_to_world) + up_world = np.array([[0.0], [0.0], [1.0], [0.0]]) + else: + up_vector = np.array([[0.0], [-1.0], [0.0], [0.0]]) + right_vector = np.array([[1.0], [0.0], [0.0], [0.0]]) + up_world = np.array([[0.0], [0.0], [1.0], [0.0]]) + + # value between 0, 180 + device_up_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world), + up_vector), -1.0, 1.0)).item() * 180.0 / np.pi + device_right_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world), + right_vector), -1.0, 1.0)).item() * 180.0 / np.pi + + up_closest_to_90 = abs(device_up_to_world_up_angle - 90.0) < abs(device_right_to_world_up_angle - 90.0) + if up_closest_to_90: + assert abs(device_up_to_world_up_angle - 90.0) < 45.0 + # LEFT + if device_right_to_world_up_angle > 90.0: + sky_direction_scene = 'LEFT' + cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi / 2.0]) + else: + # note that in metadata.csv RIGHT does not exist, but again it's not accurate... + # well, turns out there are scenes oriented like this + # for example Training/41124801 + sky_direction_scene = 'RIGHT' + cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, -math.pi / 2.0]) + else: + # right is close to 90 + assert abs(device_right_to_world_up_angle - 90.0) < 45.0 + if device_up_to_world_up_angle > 90.0: + sky_direction_scene = 'DOWN' + cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi]) + else: + sky_direction_scene = 'UP' + cam_to_rotated_q = quaternion.quaternion(1, 0, 0, 0) + cam_to_rotated = np.eye(4) + cam_to_rotated[:3, :3] = quaternion.as_rotation_matrix(cam_to_rotated_q) + rotated_to_cam = np.linalg.inv(cam_to_rotated) + return sky_direction_scene, rotated_to_cam + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.arkitscenes_dir, args.precomputed_pairs, args.output_dir) diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_blendedMVS.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_blendedMVS.py new file mode 100644 index 0000000000000000000000000000000000000000..d22793793c1219ebb1b3ba8eff51226c2b13f657 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_blendedMVS.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
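# (Editorial sketch, not part of the diff.) convert_scene_metadata() above interpolates an ARKit
# trajectory at the selected frame timestamps: quaternion.squad() for rotations and a linear
# spline for positions. A minimal standalone version with made-up keyframes, assuming the same
# numpy-quaternion and scipy packages the script imports.
import numpy as np
import quaternion
from scipy import interpolate

t_key = np.array([0.0, 1.0, 2.0, 3.0])
rot_key = quaternion.from_rotation_vector(0.1 * np.random.randn(4, 3))  # small random rotations
pos_key = np.random.randn(4, 3)

t_query = np.array([0.25, 1.5, 2.75])
rot_query = quaternion.squad(rot_key, t_key, t_query)                         # smooth on SO(3)
pos_query = interpolate.interp1d(t_key, pos_key, kind='linear', axis=0)(t_query)

poses = np.tile(np.eye(4), (len(t_query), 1, 1))                              # 4x4 poses
poses[:, :3, :3] = quaternion.as_rotation_matrix(rot_query)
poses[:, :3, 3] = pos_query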
+# +# -------------------------------------------------------- +# Preprocessing code for the BlendedMVS dataset +# dataset at https://github.com/YoYo000/BlendedMVS +# 1) Download BlendedMVS.zip +# 2) Download BlendedMVS+.zip +# 3) Download BlendedMVS++.zip +# 4) Unzip everything in the same /path/to/tmp/blendedMVS/ directory +# 5) python datasets_preprocess/preprocess_blendedMVS.py --blendedmvs_dir /path/to/tmp/blendedMVS/ +# -------------------------------------------------------- +import os +import os.path as osp +import re +from tqdm import tqdm +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +import path_to_root # noqa +from dust3r.utils.parallel import parallel_threads +from dust3r.datasets.utils import cropping # noqa + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--blendedmvs_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/blendedmvs_processed') + return parser + + +def main(db_root, pairs_path, output_dir): + print('>> Listing all sequences') + sequences = [f for f in os.listdir(db_root) if len(f) == 24] + # should find 502 scenes + assert sequences, f'did not found any sequences at {db_root}' + print(f' (found {len(sequences)} sequences)') + + for i, seq in enumerate(tqdm(sequences)): + out_dir = osp.join(output_dir, seq) + os.makedirs(out_dir, exist_ok=True) + + # generate the crops + root = osp.join(db_root, seq) + cam_dir = osp.join(root, 'cams') + func_args = [(root, f[:-8], out_dir) for f in os.listdir(cam_dir) if not f.startswith('pair')] + parallel_threads(load_crop_and_save, func_args, star_args=True, leave=False) + + # verify that all pairs are there + pairs = np.load(pairs_path) + for seqh, seql, img1, img2, score in tqdm(pairs): + for view_index in [img1, img2]: + impath = osp.join(output_dir, f"{seqh:08x}{seql:016x}", f"{view_index:08n}.jpg") + assert osp.isfile(impath), f'missing image at {impath=}' + + print(f'>> Done, saved everything in {output_dir}/') + + +def load_crop_and_save(root, img, out_dir): + if osp.isfile(osp.join(out_dir, img + '.npz')): + return # already done + + # load everything + intrinsics_in, R_camin2world, t_camin2world = _load_pose(osp.join(root, 'cams', img + '_cam.txt')) + color_image_in = cv2.cvtColor(cv2.imread(osp.join(root, 'blended_images', img + + '.jpg'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + depthmap_in = load_pfm_file(osp.join(root, 'rendered_depth_maps', img + '.pfm')) + + # do the crop + H, W = color_image_in.shape[:2] + assert H * 4 == W * 3 + image, depthmap, intrinsics_out, R_in2out = _crop_image(intrinsics_in, color_image_in, depthmap_in, (512, 384)) + + # write everything + image.save(osp.join(out_dir, img + '.jpg'), quality=80) + cv2.imwrite(osp.join(out_dir, img + '.exr'), depthmap) + + # New camera parameters + R_camout2world = R_camin2world @ R_in2out.T + t_camout2world = t_camin2world + np.savez(osp.join(out_dir, img + '.npz'), intrinsics=intrinsics_out, + R_cam2world=R_camout2world, t_cam2world=t_camout2world) + + +def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(800, 800)): + image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( + color_image_in, depthmap_in, intrinsics_in, resolution_out) + R_in2out = np.eye(3) + return image, depthmap, intrinsics_out, R_in2out + + +def _load_pose(path, ret_44=False): + f = open(path) + RT = np.loadtxt(f, skiprows=1, max_rows=4, dtype=np.float32) + assert RT.shape == (4, 4) + 
RT = np.linalg.inv(RT) # world2cam to cam2world + + K = np.loadtxt(f, skiprows=2, max_rows=3, dtype=np.float32) + assert K.shape == (3, 3) + + if ret_44: + return K, RT + return K, RT[:3, :3], RT[:3, 3] # , depth_uint8_to_f32 + + +def load_pfm_file(file_path): + with open(file_path, 'rb') as file: + header = file.readline().decode('UTF-8').strip() + + if header == 'PF': + is_color = True + elif header == 'Pf': + is_color = False + else: + raise ValueError('The provided file is not a valid PFM file.') + + dimensions = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('UTF-8')) + if dimensions: + img_width, img_height = map(int, dimensions.groups()) + else: + raise ValueError('Invalid PFM header format.') + + endian_scale = float(file.readline().decode('UTF-8').strip()) + if endian_scale < 0: + dtype = '= img_size * 3/4, and max dimension will be >= img_size")) + return parser + + +def convert_ndc_to_pinhole(focal_length, principal_point, image_size): + focal_length = np.array(focal_length) + principal_point = np.array(principal_point) + image_size_wh = np.array([image_size[1], image_size[0]]) + half_image_size = image_size_wh / 2 + rescale = half_image_size.min() + principal_point_px = half_image_size - principal_point * rescale + focal_length_px = focal_length * rescale + fx, fy = focal_length_px[0], focal_length_px[1] + cx, cy = principal_point_px[0], principal_point_px[1] + K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32) + return K + + +def opencv_from_cameras_projection(R, T, focal, p0, image_size): + R = torch.from_numpy(R)[None, :, :] + T = torch.from_numpy(T)[None, :] + focal = torch.from_numpy(focal)[None, :] + p0 = torch.from_numpy(p0)[None, :] + image_size = torch.from_numpy(image_size)[None, :] + + R_pytorch3d = R.clone() + T_pytorch3d = T.clone() + focal_pytorch3d = focal + p0_pytorch3d = p0 + T_pytorch3d[:, :2] *= -1 + R_pytorch3d[:, :, :2] *= -1 + tvec = T_pytorch3d + R = R_pytorch3d.permute(0, 2, 1) + + # Retype the image_size correctly and flip to width, height. + image_size_wh = image_size.to(R).flip(dims=(1,)) + + # NDC to screen conversion. 
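    # (Editorial note.) PyTorch3D expresses focal length and principal point in NDC units,
    # scaled by half the *shorter* image side and centred on the image with axes flipped
    # relative to OpenCV; the lines below rescale by min(W, H) / 2, mirror the principal point
    # around the image centre, and assemble a pixel-space camera matrix.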
+ scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0 + scale = scale.expand(-1, 2) + c0 = image_size_wh / 2.0 + + principal_point = -p0_pytorch3d * scale + c0 + focal_length = focal_pytorch3d * scale + + camera_matrix = torch.zeros_like(R) + camera_matrix[:, :2, 2] = principal_point + camera_matrix[:, 2, 2] = 1.0 + camera_matrix[:, 0, 0] = focal_length[:, 0] + camera_matrix[:, 1, 1] = focal_length[:, 1] + return R[0], tvec[0], camera_matrix[0] + + +def get_set_list(category_dir, split, is_single_sequence_subset=False): + listfiles = os.listdir(osp.join(category_dir, "set_lists")) + if is_single_sequence_subset: + # not all objects have manyview_dev + subset_list_files = [f for f in listfiles if "manyview_dev" in f] + else: + subset_list_files = [f for f in listfiles if f"fewview_train" in f] + + sequences_all = [] + for subset_list_file in subset_list_files: + with open(osp.join(category_dir, "set_lists", subset_list_file)) as f: + subset_lists_data = json.load(f) + sequences_all.extend(subset_lists_data[split]) + + return sequences_all + + +def prepare_sequences(category, co3d_dir, output_dir, img_size, split, min_quality, max_num_sequences_per_object, + seed, is_single_sequence_subset=False): + random.seed(seed) + category_dir = osp.join(co3d_dir, category) + category_output_dir = osp.join(output_dir, category) + sequences_all = get_set_list(category_dir, split, is_single_sequence_subset) + sequences_numbers = sorted(set(seq_name for seq_name, _, _ in sequences_all)) + + frame_file = osp.join(category_dir, "frame_annotations.jgz") + sequence_file = osp.join(category_dir, "sequence_annotations.jgz") + + with gzip.open(frame_file, "r") as fin: + frame_data = json.loads(fin.read()) + with gzip.open(sequence_file, "r") as fin: + sequence_data = json.loads(fin.read()) + + frame_data_processed = {} + for f_data in frame_data: + sequence_name = f_data["sequence_name"] + frame_data_processed.setdefault(sequence_name, {})[f_data["frame_number"]] = f_data + + good_quality_sequences = set() + for seq_data in sequence_data: + if seq_data["viewpoint_quality_score"] > min_quality: + good_quality_sequences.add(seq_data["sequence_name"]) + + sequences_numbers = [seq_name for seq_name in sequences_numbers if seq_name in good_quality_sequences] + if len(sequences_numbers) < max_num_sequences_per_object: + selected_sequences_numbers = sequences_numbers + else: + selected_sequences_numbers = random.sample(sequences_numbers, max_num_sequences_per_object) + + selected_sequences_numbers_dict = {seq_name: [] for seq_name in selected_sequences_numbers} + sequences_all = [(seq_name, frame_number, filepath) + for seq_name, frame_number, filepath in sequences_all + if seq_name in selected_sequences_numbers_dict] + + for seq_name, frame_number, filepath in tqdm(sequences_all): + frame_idx = int(filepath.split('/')[-1][5:-4]) + selected_sequences_numbers_dict[seq_name].append(frame_idx) + mask_path = filepath.replace("images", "masks").replace(".jpg", ".png") + frame_data = frame_data_processed[seq_name][frame_number] + focal_length = frame_data["viewpoint"]["focal_length"] + principal_point = frame_data["viewpoint"]["principal_point"] + image_size = frame_data["image"]["size"] + K = convert_ndc_to_pinhole(focal_length, principal_point, image_size) + R, tvec, camera_intrinsics = opencv_from_cameras_projection(np.array(frame_data["viewpoint"]["R"]), + np.array(frame_data["viewpoint"]["T"]), + np.array(focal_length), + np.array(principal_point), + np.array(image_size)) + + frame_data = 
frame_data_processed[seq_name][frame_number] + depth_path = os.path.join(co3d_dir, frame_data["depth"]["path"]) + assert frame_data["depth"]["scale_adjustment"] == 1.0 + image_path = os.path.join(co3d_dir, filepath) + mask_path_full = os.path.join(co3d_dir, mask_path) + + input_rgb_image = PIL.Image.open(image_path).convert('RGB') + input_mask = plt.imread(mask_path_full) + + with PIL.Image.open(depth_path) as depth_pil: + # the image is stored with 16-bit depth but PIL reads it as I (32 bit). + # we cast it to uint16, then reinterpret as float16, then cast to float32 + input_depthmap = ( + np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0]))) + depth_mask = np.stack((input_depthmap, input_mask), axis=-1) + H, W = input_depthmap.shape + + camera_intrinsics = camera_intrinsics.numpy() + cx, cy = camera_intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap( + input_rgb_image, depth_mask, camera_intrinsics, crop_bbox) + + # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384 + scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + if max(output_resolution) < img_size: + # let's put the max dimension to img_size + scale_final = (img_size / max(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap( + input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution) + input_depthmap = depth_mask[:, :, 0] + input_mask = depth_mask[:, :, 1] + + # generate and adjust camera pose + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = R + camera_pose[:3, 3] = tvec + camera_pose = np.linalg.inv(camera_pose) + + # save crop images and depth, metadata + save_img_path = os.path.join(output_dir, filepath) + save_depth_path = os.path.join(output_dir, frame_data["depth"]["path"]) + save_mask_path = os.path.join(output_dir, mask_path) + os.makedirs(os.path.split(save_img_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True) + + input_rgb_image.save(save_img_path) + scaled_depth_map = (input_depthmap / np.max(input_depthmap) * 65535).astype(np.uint16) + cv2.imwrite(save_depth_path, scaled_depth_map) + cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8)) + + save_meta_path = save_img_path.replace('jpg', 'npz') + np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics, + camera_pose=camera_pose, maximum_depth=np.max(input_depthmap)) + + return selected_sequences_numbers_dict + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + assert args.co3d_dir != args.output_dir + if args.category is None: + if args.single_sequence_subset: + categories = SINGLE_SEQUENCE_CATEGORIES + else: + categories = CATEGORIES + else: + categories = [args.category] + os.makedirs(args.output_dir, exist_ok=True) + + for split in ['train', 'test']: + selected_sequences_path = os.path.join(args.output_dir, 
f'selected_seqs_{split}.json') + if os.path.isfile(selected_sequences_path): + continue + + all_selected_sequences = {} + for category in categories: + category_output_dir = osp.join(args.output_dir, category) + os.makedirs(category_output_dir, exist_ok=True) + category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(category_selected_sequences_path): + with open(category_selected_sequences_path, 'r') as fid: + category_selected_sequences = json.load(fid) + else: + print(f"Processing {split} - category = {category}") + category_selected_sequences = prepare_sequences( + category=category, + co3d_dir=args.co3d_dir, + output_dir=args.output_dir, + img_size=args.img_size, + split=split, + min_quality=args.min_quality, + max_num_sequences_per_object=args.num_sequences_per_object, + seed=args.seed + CATEGORIES_IDX[category], + is_single_sequence_subset=args.single_sequence_subset + ) + with open(category_selected_sequences_path, 'w') as file: + json.dump(category_selected_sequences, file) + + all_selected_sequences[category] = category_selected_sequences + with open(selected_sequences_path, 'w') as file: + json.dump(all_selected_sequences, file) diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_megadepth.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..b07c0c5dff0cfd828f9ce4fd204cf2eaa22487f1 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_megadepth.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Preprocessing code for the MegaDepth dataset +# dataset at https://www.cs.cornell.edu/projects/megadepth/ +# -------------------------------------------------------- +import os +import os.path as osp +import collections +from tqdm import tqdm +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 +import h5py + +import path_to_root # noqa +from dust3r.utils.parallel import parallel_threads +from dust3r.datasets.utils import cropping # noqa + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--megadepth_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/megadepth_processed') + return parser + + +def main(db_root, pairs_path, output_dir): + os.makedirs(output_dir, exist_ok=True) + + # load all pairs + data = np.load(pairs_path, allow_pickle=True) + scenes = data['scenes'] + images = data['images'] + pairs = data['pairs'] + + # enumerate all unique images + todo = collections.defaultdict(set) + for scene, im1, im2, score in pairs: + todo[scene].add(im1) + todo[scene].add(im2) + + # for each scene, load intrinsics and then parallel crops + for scene, im_idxs in tqdm(todo.items(), desc='Overall'): + scene, subscene = scenes[scene].split() + out_dir = osp.join(output_dir, scene, subscene) + os.makedirs(out_dir, exist_ok=True) + + # load all camera params + _, pose_w2cam, intrinsics = _load_kpts_and_poses(db_root, scene, subscene, intrinsics=True) + + in_dir = osp.join(db_root, scene, 'dense' + subscene) + args = [(in_dir, img, intrinsics[img], pose_w2cam[img], out_dir) + for img in [images[im_id] for im_id in im_idxs]] + 
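        # (Editorial note.) One work item per unique image referenced by the pairs:
        # resize_one_image() below loads the image and its .h5 depth, replaces the
        # pre-rectification intrinsics with an undistorted camera matrix
        # (cv2.getOptimalNewCameraMatrix), downscales image and depth together, and writes a
        # .jpg / .exr / .npz triplet; star_args=True unpacks each tuple into positional arguments.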
parallel_threads(resize_one_image, args, star_args=True, front_num=0, leave=False, desc=f'{scene}/{subscene}') + + # save pairs + print('Done! prepared all pairs in', output_dir) + + +def resize_one_image(root, tag, K_pre_rectif, pose_w2cam, out_dir): + if osp.isfile(osp.join(out_dir, tag + '.npz')): + return + + # load image + img = cv2.cvtColor(cv2.imread(osp.join(root, 'imgs', tag), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + H, W = img.shape[:2] + + # load depth + with h5py.File(osp.join(root, 'depths', osp.splitext(tag)[0] + '.h5'), 'r') as hd5: + depthmap = np.asarray(hd5['depth']) + + # rectify = undistort the intrinsics + imsize_pre, K_pre, distortion = K_pre_rectif + imsize_post = img.shape[1::-1] + K_post = cv2.getOptimalNewCameraMatrix(K_pre, distortion, imsize_pre, alpha=0, + newImgSize=imsize_post, centerPrincipalPoint=True)[0] + + # downscale + img_out, depthmap_out, intrinsics_out, R_in2out = _downscale_image(K_post, img, depthmap, resolution_out=(800, 600)) + + # write everything + img_out.save(osp.join(out_dir, tag + '.jpg'), quality=90) + cv2.imwrite(osp.join(out_dir, tag + '.exr'), depthmap_out) + + camout2world = np.linalg.inv(pose_w2cam) + camout2world[:3, :3] = camout2world[:3, :3] @ R_in2out.T + np.savez(osp.join(out_dir, tag + '.npz'), intrinsics=intrinsics_out, cam2world=camout2world) + + +def _downscale_image(camera_intrinsics, image, depthmap, resolution_out=(512, 384)): + H, W = image.shape[:2] + resolution_out = sorted(resolution_out)[::+1 if W < H else -1] + + image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( + image, depthmap, camera_intrinsics, resolution_out, force=False) + R_in2out = np.eye(3) + + return image, depthmap, intrinsics_out, R_in2out + + +def _load_kpts_and_poses(root, scene_id, subscene, z_only=False, intrinsics=False): + if intrinsics: + with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'cameras.txt'), 'r') as f: + raw = f.readlines()[3:] # skip the header + + camera_intrinsics = {} + for camera in raw: + camera = camera.split(' ') + width, height, focal, cx, cy, k0 = [float(elem) for elem in camera[2:]] + K = np.eye(3) + K[0, 0] = focal + K[1, 1] = focal + K[0, 2] = cx + K[1, 2] = cy + camera_intrinsics[int(camera[0])] = ((int(width), int(height)), K, (k0, 0, 0, 0)) + + with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'images.txt'), 'r') as f: + raw = f.read().splitlines()[4:] # skip the header + + extract_pose = colmap_raw_pose_to_principal_axis if z_only else colmap_raw_pose_to_RT + + poses = {} + points3D_idxs = {} + camera = [] + + for image, points in zip(raw[:: 2], raw[1:: 2]): + image = image.split(' ') + points = points.split(' ') + + image_id = image[-1] + camera.append(int(image[-2])) + + # find the principal axis + raw_pose = [float(elem) for elem in image[1: -2]] + poses[image_id] = extract_pose(raw_pose) + + current_points3D_idxs = {int(i) for i in points[2:: 3] if i != '-1'} + assert -1 not in current_points3D_idxs, bb() + points3D_idxs[image_id] = current_points3D_idxs + + if intrinsics: + image_intrinsics = {im_id: camera_intrinsics[cam] for im_id, cam in zip(poses, camera)} + return points3D_idxs, poses, image_intrinsics + else: + return points3D_idxs, poses + + +def colmap_raw_pose_to_principal_axis(image_pose): + qvec = image_pose[: 4] + qvec = qvec / np.linalg.norm(qvec) + w, x, y, z = qvec + z_axis = np.float32([ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ]) + return z_axis + + +def colmap_raw_pose_to_RT(image_pose): + qvec = 
image_pose[: 4] + qvec = qvec / np.linalg.norm(qvec) + w, x, y, z = qvec + R = np.array([ + [ + 1 - 2 * y * y - 2 * z * z, + 2 * x * y - 2 * z * w, + 2 * x * z + 2 * y * w + ], + [ + 2 * x * y + 2 * z * w, + 1 - 2 * x * x - 2 * z * z, + 2 * y * z - 2 * x * w + ], + [ + 2 * x * z - 2 * y * w, + 2 * y * z + 2 * x * w, + 1 - 2 * x * x - 2 * y * y + ] + ]) + # principal_axis.append(R[2, :]) + t = image_pose[4: 7] + # World-to-Camera pose + current_pose = np.eye(4) + current_pose[: 3, : 3] = R + current_pose[: 3, 3] = t + return current_pose + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.megadepth_dir, args.precomputed_pairs, args.output_dir) diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_scannetpp.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_scannetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..34e26dc9474df16cf0736f71248d01b7853d4786 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_scannetpp.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to pre-process the scannet++ dataset. +# Usage: +# python3 datasets_preprocess/preprocess_scannetpp.py --scannetpp_dir /path/to/scannetpp --precomputed_pairs /path/to/scannetpp_pairs --pyopengl-platform egl +# -------------------------------------------------------- +import os +import argparse +import os.path as osp +import re +from tqdm import tqdm +import json +from scipy.spatial.transform import Rotation +import pyrender +import trimesh +import trimesh.exchange.ply +import numpy as np +import cv2 +import PIL.Image as Image + +from dust3r.datasets.utils.cropping import rescale_image_depthmap +import dust3r.utils.geometry as geometry + +inv = np.linalg.inv +norm = np.linalg.norm +REGEXPR_DSLR = re.compile(r'^DSC(?P\d+).JPG$') +REGEXPR_IPHONE = re.compile(r'frame_(?P\d+).jpg$') + +DEBUG_VIZ = None # 'iou' +if DEBUG_VIZ is not None: + import matplotlib.pyplot as plt # noqa + + +OPENGL_TO_OPENCV = np.float32([[1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1]]) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--scannetpp_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/scannetpp_processed') + parser.add_argument('--target_resolution', default=920, type=int, help="images resolution") + parser.add_argument('--pyopengl-platform', type=str, default='', help='PyOpenGL env variable') + return parser + + +def pose_from_qwxyz_txyz(elems): + qw, qx, qy, qz, tx, ty, tz = map(float, elems) + pose = np.eye(4) + pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix() + pose[:3, 3] = (tx, ty, tz) + return np.linalg.inv(pose) # returns cam2world + + +def get_frame_number(name, cam_type='dslr'): + if cam_type == 'dslr': + regex_expr = REGEXPR_DSLR + elif cam_type == 'iphone': + regex_expr = REGEXPR_IPHONE + else: + raise NotImplementedError(f'wrong {cam_type=} for get_frame_number') + matches = re.match(regex_expr, name) + return matches['frameid'] + + +def load_sfm(sfm_dir, cam_type='dslr'): + # load cameras + with open(osp.join(sfm_dir, 'cameras.txt'), 'r') as f: + raw = f.read().splitlines()[3:] # skip header + + intrinsics = {} + for camera in tqdm(raw, position=1, leave=False): 
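        # (Editorial note.) Each cameras.txt line is "CAMERA_ID MODEL WIDTH HEIGHT PARAMS...";
        # the model name is kept as a string and the numeric parameters are parsed as floats,
        # so undistort_images() can later branch on OPENCV_FISHEYE vs. pinhole-style models.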
+ camera = camera.split(' ') + intrinsics[int(camera[0])] = [camera[1]] + [float(cam) for cam in camera[2:]] + + # load images + with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + img_idx = {} + img_infos = {} + for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2, position=1, leave=False): + image = image.split(' ') + points = points.split(' ') + + idx = image[0] + img_name = image[-1] + assert img_name not in img_idx, 'duplicate db image: ' + img_name + img_idx[img_name] = idx # register image name + + current_points2D = {int(i): (float(x), float(y)) + for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'} + img_infos[idx] = dict(intrinsics=intrinsics[int(image[-2])], + path=img_name, + frame_id=get_frame_number(img_name, cam_type), + cam_to_world=pose_from_qwxyz_txyz(image[1: -2]), + sparse_pts2d=current_points2D) + + # load 3D points + with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + points3D = {} + observations = {idx: [] for idx in img_infos.keys()} + for point in tqdm(raw, position=1, leave=False): + point = point.split() + point_3d_idx = int(point[0]) + points3D[point_3d_idx] = tuple(map(float, point[1:4])) + if len(point) > 8: + for idx, point_2d_idx in zip(point[8::2], point[9::2]): + observations[idx].append((point_3d_idx, int(point_2d_idx))) + + return img_idx, img_infos, points3D, observations + + +def subsample_img_infos(img_infos, num_images, allowed_name_subset=None): + img_infos_val = [(idx, val) for idx, val in img_infos.items()] + if allowed_name_subset is not None: + img_infos_val = [(idx, val) for idx, val in img_infos_val if val['path'] in allowed_name_subset] + + if len(img_infos_val) > num_images: + img_infos_val = sorted(img_infos_val, key=lambda x: x[1]['frame_id']) + kept_idx = np.round(np.linspace(0, len(img_infos_val) - 1, num_images)).astype(int).tolist() + img_infos_val = [img_infos_val[idx] for idx in kept_idx] + return {idx: val for idx, val in img_infos_val} + + +def undistort_images(intrinsics, rgb, mask): + camera_type = intrinsics[0] + + width = int(intrinsics[1]) + height = int(intrinsics[2]) + fx = intrinsics[3] + fy = intrinsics[4] + cx = intrinsics[5] + cy = intrinsics[6] + distortion = np.array(intrinsics[7:]) + + K = np.zeros([3, 3]) + K[0, 0] = fx + K[0, 2] = cx + K[1, 1] = fy + K[1, 2] = cy + K[2, 2] = 1 + + K = geometry.colmap_to_opencv_intrinsics(K) + if camera_type == "OPENCV_FISHEYE": + assert len(distortion) == 4 + + new_K = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify( + K, + distortion, + (width, height), + np.eye(3), + balance=0.0, + ) + # Make the cx and cy to be the center of the image + new_K[0, 2] = width / 2.0 + new_K[1, 2] = height / 2.0 + + map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1) + else: + new_K, _ = cv2.getOptimalNewCameraMatrix(K, distortion, (width, height), 1, (width, height), True) + map1, map2 = cv2.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1) + + undistorted_image = cv2.remap(rgb, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) + undistorted_mask = cv2.remap(mask, map1, map2, interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, borderValue=255) + K = geometry.opencv_to_colmap_intrinsics(K) 
+ return width, height, new_K, undistorted_image, undistorted_mask + + +def process_scenes(root, pairsdir, output_dir, target_resolution): + os.makedirs(output_dir, exist_ok=True) + + # default values from + # https://github.com/scannetpp/scannetpp/blob/main/common/configs/render.yml + znear = 0.05 + zfar = 20.0 + + listfile = osp.join(pairsdir, 'scene_list.json') + with open(listfile, 'r') as f: + scenes = json.load(f) + + # for each of these, we will select some dslr images and some iphone images + # we will undistort them and render their depth + renderer = pyrender.OffscreenRenderer(0, 0) + for scene in tqdm(scenes, position=0, leave=True): + data_dir = os.path.join(root, 'data', scene) + dir_dslr = os.path.join(data_dir, 'dslr') + dir_iphone = os.path.join(data_dir, 'iphone') + dir_scans = os.path.join(data_dir, 'scans') + + assert os.path.isdir(data_dir) and os.path.isdir(dir_dslr) \ + and os.path.isdir(dir_iphone) and os.path.isdir(dir_scans) + + output_dir_scene = os.path.join(output_dir, scene) + scene_metadata_path = osp.join(output_dir_scene, 'scene_metadata.npz') + if osp.isfile(scene_metadata_path): + continue + + pairs_dir_scene = os.path.join(pairsdir, scene) + pairs_dir_scene_selected_pairs = os.path.join(pairs_dir_scene, 'selected_pairs.npz') + assert osp.isfile(pairs_dir_scene_selected_pairs) + selected_npz = np.load(pairs_dir_scene_selected_pairs) + selection, pairs = selected_npz['selection'], selected_npz['pairs'] + + # set up the output paths + output_dir_scene_rgb = os.path.join(output_dir_scene, 'images') + output_dir_scene_depth = os.path.join(output_dir_scene, 'depth') + os.makedirs(output_dir_scene_rgb, exist_ok=True) + os.makedirs(output_dir_scene_depth, exist_ok=True) + + ply_path = os.path.join(dir_scans, 'mesh_aligned_0.05.ply') + + sfm_dir_dslr = os.path.join(dir_dslr, 'colmap') + rgb_dir_dslr = os.path.join(dir_dslr, 'resized_images') + mask_dir_dslr = os.path.join(dir_dslr, 'resized_anon_masks') + + sfm_dir_iphone = os.path.join(dir_iphone, 'colmap') + rgb_dir_iphone = os.path.join(dir_iphone, 'rgb') + mask_dir_iphone = os.path.join(dir_iphone, 'rgb_masks') + + # load the mesh + with open(ply_path, 'rb') as f: + mesh_kwargs = trimesh.exchange.ply.load_ply(f) + mesh_scene = trimesh.Trimesh(**mesh_kwargs) + + # read colmap reconstruction, we will only use the intrinsics and pose here + img_idx_dslr, img_infos_dslr, points3D_dslr, observations_dslr = load_sfm(sfm_dir_dslr, cam_type='dslr') + dslr_paths = { + "in_colmap": sfm_dir_dslr, + "in_rgb": rgb_dir_dslr, + "in_mask": mask_dir_dslr, + } + + img_idx_iphone, img_infos_iphone, points3D_iphone, observations_iphone = load_sfm( + sfm_dir_iphone, cam_type='iphone') + iphone_paths = { + "in_colmap": sfm_dir_iphone, + "in_rgb": rgb_dir_iphone, + "in_mask": mask_dir_iphone, + } + + mesh = pyrender.Mesh.from_trimesh(mesh_scene, smooth=False) + pyrender_scene = pyrender.Scene() + pyrender_scene.add(mesh) + + selection_dslr = [imgname + '.JPG' for imgname in selection if imgname.startswith('DSC')] + selection_iphone = [imgname + '.jpg' for imgname in selection if imgname.startswith('frame_')] + + # resize the image to a more manageable size and render depth + for selection_cam, img_idx, img_infos, paths_data in [(selection_dslr, img_idx_dslr, img_infos_dslr, dslr_paths), + (selection_iphone, img_idx_iphone, img_infos_iphone, iphone_paths)]: + rgb_dir = paths_data['in_rgb'] + mask_dir = paths_data['in_mask'] + for imgname in tqdm(selection_cam, position=1, leave=False): + imgidx = img_idx[imgname] + img_infos_idx = 
img_infos[imgidx] + rgb = np.array(Image.open(os.path.join(rgb_dir, img_infos_idx['path']))) + mask = np.array(Image.open(os.path.join(mask_dir, img_infos_idx['path'][:-3] + 'png'))) + + _, _, K, rgb, mask = undistort_images(img_infos_idx['intrinsics'], rgb, mask) + + # rescale_image_depthmap assumes opencv intrinsics + intrinsics = geometry.colmap_to_opencv_intrinsics(K) + image, mask, intrinsics = rescale_image_depthmap( + rgb, mask, intrinsics, (target_resolution, target_resolution * 3.0 / 4)) + + W, H = image.size + intrinsics = geometry.opencv_to_colmap_intrinsics(intrinsics) + + # update inpace img_infos_idx + img_infos_idx['intrinsics'] = intrinsics + rgb_outpath = os.path.join(output_dir_scene_rgb, img_infos_idx['path'][:-3] + 'jpg') + image.save(rgb_outpath) + + depth_outpath = os.path.join(output_dir_scene_depth, img_infos_idx['path'][:-3] + 'png') + # render depth image + renderer.viewport_width, renderer.viewport_height = W, H + fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 1], intrinsics[0, 2], intrinsics[1, 2] + camera = pyrender.camera.IntrinsicsCamera(fx, fy, cx, cy, znear=znear, zfar=zfar) + camera_node = pyrender_scene.add(camera, pose=img_infos_idx['cam_to_world'] @ OPENGL_TO_OPENCV) + + depth = renderer.render(pyrender_scene, flags=pyrender.RenderFlags.DEPTH_ONLY) + pyrender_scene.remove_node(camera_node) # dont forget to remove camera + + depth = (depth * 1000).astype('uint16') + # invalidate depth from mask before saving + depth_mask = (mask < 255) + depth[depth_mask] = 0 + Image.fromarray(depth).save(depth_outpath) + + trajectories = [] + intrinsics = [] + for imgname in selection: + if imgname.startswith('DSC'): + imgidx = img_idx_dslr[imgname + '.JPG'] + img_infos_idx = img_infos_dslr[imgidx] + elif imgname.startswith('frame_'): + imgidx = img_idx_iphone[imgname + '.jpg'] + img_infos_idx = img_infos_iphone[imgidx] + else: + raise ValueError('invalid image name') + + intrinsics.append(img_infos_idx['intrinsics']) + trajectories.append(img_infos_idx['cam_to_world']) + + intrinsics = np.stack(intrinsics, axis=0) + trajectories = np.stack(trajectories, axis=0) + # save metadata for this scene + np.savez(scene_metadata_path, + trajectories=trajectories, + intrinsics=intrinsics, + images=selection, + pairs=pairs) + + del img_infos + del pyrender_scene + + # concat all scene_metadata.npz into a single file + scene_data = {} + for scene_subdir in scenes: + scene_metadata_path = osp.join(output_dir, scene_subdir, 'scene_metadata.npz') + with np.load(scene_metadata_path) as data: + trajectories = data['trajectories'] + intrinsics = data['intrinsics'] + images = data['images'] + pairs = data['pairs'] + scene_data[scene_subdir] = {'trajectories': trajectories, + 'intrinsics': intrinsics, + 'images': images, + 'pairs': pairs} + + offset = 0 + counts = [] + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + pairs = [] + for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()): + num_imgs = data['images'].shape[0] + img_pairs = data['pairs'] + + scenes.append(scene_subdir) + sceneids.extend([scene_idx] * num_imgs) + + images.append(data['images']) + + intrinsics.append(data['intrinsics']) + trajectories.append(data['trajectories']) + + # offset pairs + img_pairs[:, 0:2] += offset + pairs.append(img_pairs) + counts.append(offset) + + offset += num_imgs + + images = np.concatenate(images, axis=0) + intrinsics = np.concatenate(intrinsics, axis=0) + trajectories = np.concatenate(trajectories, axis=0) + pairs = np.concatenate(pairs, 
axis=0) + np.savez(osp.join(output_dir, 'all_metadata.npz'), + counts=counts, + scenes=scenes, + sceneids=sceneids, + images=images, + intrinsics=intrinsics, + trajectories=trajectories, + pairs=pairs) + print('all done') + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + if args.pyopengl_platform.strip(): + os.environ['PYOPENGL_PLATFORM'] = args.pyopengl_platform + process_scenes(args.scannetpp_dir, args.precomputed_pairs, args.output_dir, args.target_resolution) diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_staticthings3d.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_staticthings3d.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3eec16321c14b12291699f1fee492b5a7d8b1c --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_staticthings3d.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Preprocessing code for the StaticThings3D dataset +# dataset at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d +# 1) Download StaticThings3D in /path/to/StaticThings3D/ +# with the script at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/scripts/download_staticthings3d.sh +# --> depths.tar.bz2 frames_finalpass.tar.bz2 poses.tar.bz2 frames_cleanpass.tar.bz2 intrinsics.tar.bz2 +# 2) unzip everything in the same /path/to/StaticThings3D/ directory +# 5) python datasets_preprocess/preprocess_staticthings3d.py --StaticThings3D_dir /path/to/tmp/StaticThings3D/ +# -------------------------------------------------------- +import os +import os.path as osp +import re +from tqdm import tqdm +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +import path_to_root # noqa +from dust3r.utils.parallel import parallel_threads +from dust3r.datasets.utils import cropping # noqa + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--StaticThings3D_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/staticthings3d_processed') + return parser + + +def main(db_root, pairs_path, output_dir): + all_scenes = _list_all_scenes(db_root) + + # crop images + args = [(db_root, osp.join(split, subsplit, seq), camera, f'{n:04d}', output_dir) + for split, subsplit, seq in all_scenes for camera in ['left', 'right'] for n in range(6, 16)] + parallel_threads(load_crop_and_save, args, star_args=True, front_num=1) + + # verify that all images are there + CAM = {b'l': 'left', b'r': 'right'} + pairs = np.load(pairs_path) + for scene, seq, cam1, im1, cam2, im2 in tqdm(pairs): + seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}') + for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]: + for ext in ['clean', 'final']: + impath = osp.join(output_dir, seq_path, cam, f"{idx:04n}_{ext}.jpg") + assert osp.isfile(impath), f'missing an image at {impath=}' + + print(f'>> Saved all data to {output_dir}!') + + +def load_crop_and_save(db_root, relpath_, camera, num, out_dir): + relpath = osp.join(relpath_, camera, num) + if osp.isfile(osp.join(out_dir, relpath + '.npz')): + return + os.makedirs(osp.join(out_dir, relpath_, camera), exist_ok=True) + + # load everything + intrinsics_in = 
readFloat(osp.join(db_root, 'intrinsics', relpath_, num + '.float3')) + cam2world = np.linalg.inv(readFloat(osp.join(db_root, 'poses', relpath + '.float3'))) + depthmap_in = readFloat(osp.join(db_root, 'depths', relpath + '.float3')) + img_clean = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_cleanpass', + relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + img_final = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_finalpass', + relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + + # do the crop + assert img_clean.shape[:2] == (540, 960) + assert img_final.shape[:2] == (540, 960) + (clean_out, final_out), depthmap, intrinsics_out, R_in2out = _crop_image( + intrinsics_in, (img_clean, img_final), depthmap_in, (512, 384)) + + # write everything + clean_out.save(osp.join(out_dir, relpath + '_clean.jpg'), quality=80) + final_out.save(osp.join(out_dir, relpath + '_final.jpg'), quality=80) + cv2.imwrite(osp.join(out_dir, relpath + '.exr'), depthmap) + + # New camera parameters + cam2world[:3, :3] = cam2world[:3, :3] @ R_in2out.T + np.savez(osp.join(out_dir, relpath + '.npz'), intrinsics=intrinsics_out, cam2world=cam2world) + + +def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(512, 512)): + image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( + color_image_in, depthmap_in, intrinsics_in, resolution_out) + R_in2out = np.eye(3) + return image, depthmap, intrinsics_out, R_in2out + + +def _list_all_scenes(path): + print('>> Listing all scenes') + + res = [] + for split in ['TRAIN']: + for subsplit in 'ABC': + for seq in os.listdir(osp.join(path, 'intrinsics', split, subsplit)): + res.append((split, subsplit, seq)) + print(f' (found ({len(res)}) scenes)') + assert res, f'Did not find anything at {path=}' + return res + + +def readFloat(name): + with open(name, 'rb') as f: + if (f.readline().decode("utf-8")) != 'float\n': + raise Exception('float file %s did not contain keyword' % name) + + dim = int(f.readline()) + + dims = [] + count = 1 + for i in range(0, dim): + d = int(f.readline()) + dims.append(d) + count *= d + + dims = list(reversed(dims)) + data = np.fromfile(f, np.float32, count).reshape(dims) + return data # Hxw or CxHxW NxCxHxW + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.StaticThings3D_dir, args.precomputed_pairs, args.output_dir) diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_waymo.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..203f337330a7e06e61d2fb9dd99647063967922d --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_waymo.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# Preprocessing code for the WayMo Open dataset +# dataset at https://github.com/waymo-research/waymo-open-dataset +# 1) Accept the license +# 2) download all training/*.tfrecord files from Perception Dataset, version 1.4.2 +# 3) put all .tfrecord files in '/path/to/waymo_dir' +# 4) install the waymo_open_dataset package with +# `python3 -m pip install gcsfs waymo-open-dataset-tf-2-12-0==1.6.4` +# 5) execute this script as `python preprocess_waymo.py --waymo_dir /path/to/waymo_dir` +# -------------------------------------------------------- +import sys +import os +import os.path as osp +import shutil +import json +from tqdm import tqdm +import PIL.Image +import numpy as np +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +import tensorflow.compat.v1 as tf +tf.enable_eager_execution() + +import path_to_root # noqa +from dust3r.utils.geometry import geotrf, inv +from dust3r.utils.image import imread_cv2 +from dust3r.utils.parallel import parallel_processes as parallel_map +from dust3r.datasets.utils import cropping +from dust3r.viz import show_raw_pointcloud + + +def get_parser(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--waymo_dir', required=True) + parser.add_argument('--precomputed_pairs', required=True) + parser.add_argument('--output_dir', default='data/waymo_processed') + parser.add_argument('--workers', type=int, default=1) + return parser + + +def main(waymo_root, pairs_path, output_dir, workers=1): + extract_frames(waymo_root, output_dir, workers=workers) + make_crops(output_dir, workers=args.workers) + + # make sure all pairs are there + with np.load(pairs_path) as data: + scenes = data['scenes'] + frames = data['frames'] + pairs = data['pairs'] # (array of (scene_id, img1_id, img2_id) + + for scene_id, im1_id, im2_id in pairs: + for im_id in (im1_id, im2_id): + path = osp.join(output_dir, scenes[scene_id], frames[im_id] + '.jpg') + assert osp.isfile(path), f'Missing a file at {path=}\nDid you download all .tfrecord files?' + + shutil.rmtree(osp.join(output_dir, 'tmp')) + print('Done! 
all data generated at', output_dir) + + +def _list_sequences(db_root): + print('>> Looking for sequences in', db_root) + res = sorted(f for f in os.listdir(db_root) if f.endswith('.tfrecord')) + print(f' found {len(res)} sequences') + return res + + +def extract_frames(db_root, output_dir, workers=8): + sequences = _list_sequences(db_root) + output_dir = osp.join(output_dir, 'tmp') + print('>> outputing result to', output_dir) + args = [(db_root, output_dir, seq) for seq in sequences] + parallel_map(process_one_seq, args, star_args=True, workers=workers) + + +def process_one_seq(db_root, output_dir, seq): + out_dir = osp.join(output_dir, seq) + os.makedirs(out_dir, exist_ok=True) + calib_path = osp.join(out_dir, 'calib.json') + if osp.isfile(calib_path): + return + + try: + with tf.device('/CPU:0'): + calib, frames = extract_frames_one_seq(osp.join(db_root, seq)) + except RuntimeError: + print(f'/!\\ Error with sequence {seq} /!\\', file=sys.stderr) + return # nothing is saved + + for f, (frame_name, views) in enumerate(tqdm(frames, leave=False)): + for cam_idx, view in views.items(): + img = PIL.Image.fromarray(view.pop('img')) + img.save(osp.join(out_dir, f'{f:05d}_{cam_idx}.jpg')) + np.savez(osp.join(out_dir, f'{f:05d}_{cam_idx}.npz'), **view) + + with open(calib_path, 'w') as f: + json.dump(calib, f) + + +def extract_frames_one_seq(filename): + from waymo_open_dataset import dataset_pb2 as open_dataset + from waymo_open_dataset.utils import frame_utils + + print('>> Opening', filename) + dataset = tf.data.TFRecordDataset(filename, compression_type='') + + calib = None + frames = [] + + for data in tqdm(dataset, leave=False): + frame = open_dataset.Frame() + frame.ParseFromString(bytearray(data.numpy())) + + content = frame_utils.parse_range_image_and_camera_projection(frame) + range_images, camera_projections, _, range_image_top_pose = content + + views = {} + frames.append((frame.context.name, views)) + + # once in a sequence, read camera calibration info + if calib is None: + calib = [] + for cam in frame.context.camera_calibrations: + calib.append((cam.name, + dict(width=cam.width, + height=cam.height, + intrinsics=list(cam.intrinsic), + extrinsics=list(cam.extrinsic.transform)))) + + # convert LIDAR to pointcloud + points, cp_points = frame_utils.convert_range_image_to_point_cloud( + frame, + range_images, + camera_projections, + range_image_top_pose) + + # 3d points in vehicle frame. + points_all = np.concatenate(points, axis=0) + cp_points_all = np.concatenate(cp_points, axis=0) + + # The distance between lidar points and vehicle frame origin. 
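+ # note: the layout of cp_points, as inferred from the indexing just below (an
+ # assumption, not the official Waymo spec): channel 0 holds the camera name/id
+ # that each lidar return projects into, and channels 1:3 hold its (x, y) pixel
+ # coordinates in that camera.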
+ cp_points_all_tensor = tf.constant(cp_points_all, dtype=tf.int32) + + for i, image in enumerate(frame.images): + # select relevant 3D points for this view + mask = tf.equal(cp_points_all_tensor[..., 0], image.name) + cp_points_msk_tensor = tf.cast(tf.gather_nd(cp_points_all_tensor, tf.where(mask)), dtype=tf.float32) + + pose = np.asarray(image.pose.transform).reshape(4, 4) + timestamp = image.pose_timestamp + + rgb = tf.image.decode_jpeg(image.image).numpy() + + pix = cp_points_msk_tensor[..., 1:3].numpy().round().astype(np.int16) + pts3d = points_all[mask.numpy()] + + views[image.name] = dict(img=rgb, pose=pose, pixels=pix, pts3d=pts3d, timestamp=timestamp) + + if not 'show full point cloud': + show_raw_pointcloud([v['pts3d'] for v in views.values()], [v['img'] for v in views.values()]) + + return calib, frames + + +def make_crops(output_dir, workers=16, **kw): + tmp_dir = osp.join(output_dir, 'tmp') + sequences = _list_sequences(tmp_dir) + args = [(tmp_dir, output_dir, seq) for seq in sequences] + parallel_map(crop_one_seq, args, star_args=True, workers=workers, front_num=0) + + +def crop_one_seq(input_dir, output_dir, seq, resolution=512): + seq_dir = osp.join(input_dir, seq) + out_dir = osp.join(output_dir, seq) + if osp.isfile(osp.join(out_dir, '00100_1.jpg')): + return + os.makedirs(out_dir, exist_ok=True) + + # load calibration file + try: + with open(osp.join(seq_dir, 'calib.json')) as f: + calib = json.load(f) + except IOError: + print(f'/!\\ Error: Missing calib.json in sequence {seq} /!\\', file=sys.stderr) + return + + axes_transformation = np.array([ + [0, -1, 0, 0], + [0, 0, -1, 0], + [1, 0, 0, 0], + [0, 0, 0, 1]]) + + cam_K = {} + cam_distortion = {} + cam_res = {} + cam_to_car = {} + for cam_idx, cam_info in calib: + cam_idx = str(cam_idx) + cam_res[cam_idx] = (W, H) = (cam_info['width'], cam_info['height']) + f1, f2, cx, cy, k1, k2, p1, p2, k3 = cam_info['intrinsics'] + cam_K[cam_idx] = np.asarray([(f1, 0, cx), (0, f2, cy), (0, 0, 1)]) + cam_distortion[cam_idx] = np.asarray([k1, k2, p1, p2, k3]) + cam_to_car[cam_idx] = np.asarray(cam_info['extrinsics']).reshape(4, 4) # cam-to-vehicle + + frames = sorted(f[:-3] for f in os.listdir(seq_dir) if f.endswith('.jpg')) + + # from dust3r.viz import SceneViz + # viz = SceneViz() + + for frame in tqdm(frames, leave=False): + cam_idx = frame[-2] # cam index + assert cam_idx in '12345', f'bad {cam_idx=} in {frame=}' + data = np.load(osp.join(seq_dir, frame + 'npz')) + car_to_world = data['pose'] + W, H = cam_res[cam_idx] + + # load depthmap + pos2d = data['pixels'].round().astype(np.uint16) + x, y = pos2d.T + pts3d = data['pts3d'] # already in the car frame + pts3d = geotrf(axes_transformation @ inv(cam_to_car[cam_idx]), pts3d) + # X=LEFT_RIGHT y=ALTITUDE z=DEPTH + + # load image + image = imread_cv2(osp.join(seq_dir, frame + 'jpg')) + + # downscale image + output_resolution = (resolution, 1) if W > H else (1, resolution) + image, _, intrinsics2 = cropping.rescale_image_depthmap(image, None, cam_K[cam_idx], output_resolution) + image.save(osp.join(out_dir, frame + 'jpg'), quality=80) + + # save as an EXR file? 
yes it's smaller (and easier to load) + W, H = image.size + depthmap = np.zeros((H, W), dtype=np.float32) + pos2d = geotrf(intrinsics2 @ inv(cam_K[cam_idx]), pos2d).round().astype(np.int16) + x, y = pos2d.T + depthmap[y.clip(min=0, max=H - 1), x.clip(min=0, max=W - 1)] = pts3d[:, 2] + cv2.imwrite(osp.join(out_dir, frame + 'exr'), depthmap) + + # save camera parametes + cam2world = car_to_world @ cam_to_car[cam_idx] @ inv(axes_transformation) + np.savez(osp.join(out_dir, frame + 'npz'), intrinsics=intrinsics2, + cam2world=cam2world, distortion=cam_distortion[cam_idx]) + + # viz.add_rgbd(np.asarray(image), depthmap, intrinsics2, cam2world) + # viz.show() + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args.waymo_dir, args.precomputed_pairs, args.output_dir, workers=args.workers) diff --git a/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_wildrgbd.py b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_wildrgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..ff3f0f7abb7d9ef43bba6a7c6cd6f4e652a8f510 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/datasets_preprocess/preprocess_wildrgbd.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Script to pre-process the WildRGB-D dataset. +# Usage: +# python3 datasets_preprocess/preprocess_wildrgbd.py --wildrgbd_dir /path/to/wildrgbd +# -------------------------------------------------------- + +import argparse +import random +import json +import os +import os.path as osp + +import PIL.Image +import numpy as np +import cv2 + +from tqdm.auto import tqdm +import matplotlib.pyplot as plt + +import path_to_root # noqa +import dust3r.datasets.utils.cropping as cropping # noqa +from dust3r.utils.image import imread_cv2 + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", type=str, default="data/wildrgbd_processed") + parser.add_argument("--wildrgbd_dir", type=str, required=True) + parser.add_argument("--train_num_sequences_per_object", type=int, default=50) + parser.add_argument("--test_num_sequences_per_object", type=int, default=10) + parser.add_argument("--num_frames", type=int, default=100) + parser.add_argument("--seed", type=int, default=42) + + parser.add_argument("--img_size", type=int, default=512, + help=("lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size")) + return parser + + +def get_set_list(category_dir, split): + listfiles = ["camera_eval_list.json", "nvs_list.json"] + + sequences_all = {s: {k: set() for k in listfiles} for s in ['train', 'val']} + for listfile in listfiles: + with open(osp.join(category_dir, listfile)) as f: + subset_lists_data = json.load(f) + for s in ['train', 'val']: + sequences_all[s][listfile].update(subset_lists_data[s]) + train_intersection = set.intersection(*list(sequences_all['train'].values())) + if split == "train": + return train_intersection + else: + all_seqs = set.union(*list(sequences_all['train'].values()), *list(sequences_all['val'].values())) + return all_seqs.difference(train_intersection) + + +def prepare_sequences(category, wildrgbd_dir, output_dir, img_size, split, max_num_sequences_per_object, + output_num_frames, seed): + random.seed(seed) + category_dir = osp.join(wildrgbd_dir, category) + category_output_dir = osp.join(output_dir, 
category) + sequences_all = get_set_list(category_dir, split) + sequences_all = sorted(sequences_all) + + sequences_all_tmp = [] + for seq_name in sequences_all: + scene_dir = osp.join(wildrgbd_dir, category_dir, seq_name) + if not os.path.isdir(scene_dir): + print(f'{scene_dir} does not exist, skipped') + continue + sequences_all_tmp.append(seq_name) + sequences_all = sequences_all_tmp + if len(sequences_all) <= max_num_sequences_per_object: + selected_sequences = sequences_all + else: + selected_sequences = random.sample(sequences_all, max_num_sequences_per_object) + + selected_sequences_numbers_dict = {} + for seq_name in tqdm(selected_sequences, leave=False): + scene_dir = osp.join(category_dir, seq_name) + scene_output_dir = osp.join(category_output_dir, seq_name) + with open(osp.join(scene_dir, 'metadata'), 'r') as f: + metadata = json.load(f) + + K = np.array(metadata["K"]).reshape(3, 3).T + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + w, h = metadata["w"], metadata["h"] + + camera_intrinsics = np.array( + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + ) + camera_to_world_path = os.path.join(scene_dir, 'cam_poses.txt') + camera_to_world_content = np.genfromtxt(camera_to_world_path) + camera_to_world = camera_to_world_content[:, 1:].reshape(-1, 4, 4) + + frame_idx = camera_to_world_content[:, 0] + num_frames = frame_idx.shape[0] + assert num_frames >= output_num_frames + assert np.all(frame_idx == np.arange(num_frames)) + + # selected_sequences_numbers_dict[seq_name] = num_frames + + selected_frames = np.round(np.linspace(0, num_frames - 1, output_num_frames)).astype(int).tolist() + selected_sequences_numbers_dict[seq_name] = selected_frames + + for frame_id in tqdm(selected_frames): + depth_path = os.path.join(scene_dir, 'depth', f'{frame_id:0>5d}.png') + masks_path = os.path.join(scene_dir, 'masks', f'{frame_id:0>5d}.png') + rgb_path = os.path.join(scene_dir, 'rgb', f'{frame_id:0>5d}.png') + + input_rgb_image = PIL.Image.open(rgb_path).convert('RGB') + input_mask = plt.imread(masks_path) + input_depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float64) + depth_mask = np.stack((input_depthmap, input_mask), axis=-1) + H, W = input_depthmap.shape + + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = int(cx - min_margin_x), int(cy - min_margin_y) + r, b = int(cx + min_margin_x), int(cy + min_margin_y) + crop_bbox = (l, t, r, b) + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap( + input_rgb_image, depth_mask, camera_intrinsics, crop_bbox) + + # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384 + scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + if max(output_resolution) < img_size: + # let's put the max dimension to img_size + scale_final = (img_size / max(H, W)) + 1e-8 + output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) + + input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap( + input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution) + input_depthmap = depth_mask[:, :, 0] + input_mask = depth_mask[:, :, 1] + + camera_pose = camera_to_world[frame_id] + + # save crop images and depth, metadata + save_img_path = os.path.join(scene_output_dir, 'rgb', f'{frame_id:0>5d}.jpg') + save_depth_path = os.path.join(scene_output_dir, 'depth', 
f'{frame_id:0>5d}.png') + save_mask_path = os.path.join(scene_output_dir, 'masks', f'{frame_id:0>5d}.png') + os.makedirs(os.path.split(save_img_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True) + os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True) + + input_rgb_image.save(save_img_path) + cv2.imwrite(save_depth_path, input_depthmap.astype(np.uint16)) + cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8)) + + save_meta_path = os.path.join(scene_output_dir, 'metadata', f'{frame_id:0>5d}.npz') + os.makedirs(os.path.split(save_meta_path)[0], exist_ok=True) + np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics, + camera_pose=camera_pose) + + return selected_sequences_numbers_dict + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + assert args.wildrgbd_dir != args.output_dir + + categories = sorted([ + dirname for dirname in os.listdir(args.wildrgbd_dir) + if os.path.isdir(os.path.join(args.wildrgbd_dir, dirname, 'scenes')) + ]) + + os.makedirs(args.output_dir, exist_ok=True) + + splits_num_sequences_per_object = [args.train_num_sequences_per_object, args.test_num_sequences_per_object] + for split, num_sequences_per_object in zip(['train', 'test'], splits_num_sequences_per_object): + selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(selected_sequences_path): + continue + all_selected_sequences = {} + for category in categories: + category_output_dir = osp.join(args.output_dir, category) + os.makedirs(category_output_dir, exist_ok=True) + category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json') + if os.path.isfile(category_selected_sequences_path): + with open(category_selected_sequences_path, 'r') as fid: + category_selected_sequences = json.load(fid) + else: + print(f"Processing {split} - category = {category}") + category_selected_sequences = prepare_sequences( + category=category, + wildrgbd_dir=args.wildrgbd_dir, + output_dir=args.output_dir, + img_size=args.img_size, + split=split, + max_num_sequences_per_object=num_sequences_per_object, + output_num_frames=args.num_frames, + seed=args.seed + int("category".encode('ascii').hex(), 16), + ) + with open(category_selected_sequences_path, 'w') as file: + json.dump(category_selected_sequences, file) + + all_selected_sequences[category] = category_selected_sequences + with open(selected_sequences_path, 'w') as file: + json.dump(all_selected_sequences, file) diff --git a/imcui/third_party/mast3r/dust3r/demo.py b/imcui/third_party/mast3r/dust3r/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..326c6e5a49d5d352b4afb5445cee5d22571c3bdd --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/demo.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# dust3r gradio demo executable +# -------------------------------------------------------- +import os +import torch +import tempfile + +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.demo import get_args_parser, main_demo, set_print_with_timestamp + +import matplotlib.pyplot as pl +pl.ion() + +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + set_print_with_timestamp() + + if args.tmp_dir is not None: + tmp_path = args.tmp_dir + os.makedirs(tmp_path, exist_ok=True) + tempfile.tempdir = tmp_path + + if args.server_name is not None: + server_name = args.server_name + else: + server_name = '0.0.0.0' if args.local_network else '127.0.0.1' + + if args.weights is not None: + weights_path = args.weights + else: + weights_path = "naver/" + args.model_name + model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) + + # dust3r will write the 3D model inside tmpdirname + with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname: + if not args.silent: + print('Outputing stuff in', tmpdirname) + main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent) diff --git a/imcui/third_party/mast3r/dust3r/docker/docker-compose-cpu.yml b/imcui/third_party/mast3r/dust3r/docker/docker-compose-cpu.yml new file mode 100644 index 0000000000000000000000000000000000000000..2015fd771e8b6246d288c03a38f6fbb3f17dff20 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/docker/docker-compose-cpu.yml @@ -0,0 +1,16 @@ +version: '3.8' +services: + dust3r-demo: + build: + context: ./files + dockerfile: cpu.Dockerfile + ports: + - "7860:7860" + volumes: + - ./files/checkpoints:/dust3r/checkpoints + environment: + - DEVICE=cpu + - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth} + cap_add: + - IPC_LOCK + - SYS_RESOURCE diff --git a/imcui/third_party/mast3r/dust3r/docker/docker-compose-cuda.yml b/imcui/third_party/mast3r/dust3r/docker/docker-compose-cuda.yml new file mode 100644 index 0000000000000000000000000000000000000000..85710af953d669fe618273de6ce3a062a7a84cca --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/docker/docker-compose-cuda.yml @@ -0,0 +1,23 @@ +version: '3.8' +services: + dust3r-demo: + build: + context: ./files + dockerfile: cuda.Dockerfile + ports: + - "7860:7860" + environment: + - DEVICE=cuda + - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth} + volumes: + - ./files/checkpoints:/dust3r/checkpoints + cap_add: + - IPC_LOCK + - SYS_RESOURCE + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] diff --git a/imcui/third_party/mast3r/dust3r/dust3r/__init__.py b/imcui/third_party/mast3r/dust3r/dust3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
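# A minimal end-to-end sketch of how the demo entry point above and the
# global_aligner wrapper introduced just below fit together when used
# programmatically. The checkpoint name mirrors the docker-compose default;
# load_images, make_pairs and inference are assumed to follow the upstream
# dust3r README and are not part of this patch, so treat the whole snippet as
# illustrative rather than authoritative.
import torch

from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images    # assumed helper (upstream README)
from dust3r.image_pairs import make_pairs     # assumed helper (upstream README)
from dust3r.inference import inference        # assumed helper (upstream README)
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AsymmetricCroCo3DStereo.from_pretrained(
    'naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt').to(device)

# two (or more) input views; the paths are placeholders
images = load_images(['view1.jpg', 'view2.jpg'], size=512)
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
output = inference(pairs, model, device, batch_size=1)

# global alignment over all pairwise predictions (PairViewer is the
# no-optimization fallback when only a single image pair is available)
scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
loss = scene.compute_global_alignment(init='mst', niter=300, schedule='cosine', lr=0.01)

pts3d = scene.get_pts3d()      # per-image point maps in the world frame
poses = scene.get_im_poses()   # cam-to-world poses
focals = scene.get_focals()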
diff --git a/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/__init__.py b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..faf5cd279a317c1efb9ba947682992c0949c1bdc --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/__init__.py @@ -0,0 +1,33 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# global alignment optimization wrapper function +# -------------------------------------------------------- +from enum import Enum + +from .optimizer import PointCloudOptimizer +from .modular_optimizer import ModularPointCloudOptimizer +from .pair_viewer import PairViewer + + +class GlobalAlignerMode(Enum): + PointCloudOptimizer = "PointCloudOptimizer" + ModularPointCloudOptimizer = "ModularPointCloudOptimizer" + PairViewer = "PairViewer" + + +def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw): + # extract all inputs + view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()] + # build the optimizer + if mode == GlobalAlignerMode.PointCloudOptimizer: + net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) + elif mode == GlobalAlignerMode.ModularPointCloudOptimizer: + net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) + elif mode == GlobalAlignerMode.PairViewer: + net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device) + else: + raise NotImplementedError(f'Unknown mode {mode}') + + return net diff --git a/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/base_opt.py b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/base_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..4d36e05bfca80509bced20add7c067987d538951 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/base_opt.py @@ -0,0 +1,405 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Base class for the global alignement procedure +# -------------------------------------------------------- +from copy import deepcopy + +import numpy as np +import torch +import torch.nn as nn +import roma +from copy import deepcopy +import tqdm + +from dust3r.utils.geometry import inv, geotrf +from dust3r.utils.device import to_numpy +from dust3r.utils.image import rgb +from dust3r.viz import SceneViz, segment_sky, auto_cam_size +from dust3r.optim_factory import adjust_learning_rate_by_lr + +from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p, + cosine_schedule, linear_schedule, get_conf_trf) +import dust3r.cloud_opt.init_im_poses as init_fun + + +class BasePCOptimizer (nn.Module): + """ Optimize a global scene, given a list of pairwise observations. 
+ Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, **kwargs): + if len(args) == 1 and len(kwargs) == 0: + other = deepcopy(args[0]) + attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes + min_conf_thr conf_thr conf_i conf_j im_conf + base_scale norm_pw_scale POSE_DIM pw_poses + pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose'''.split() + self.__dict__.update({k: other[k] for k in attrs}) + else: + self._init_from_views(*args, **kwargs) + + def _init_from_views(self, view1, view2, pred1, pred2, + dist='l1', + conf='log', + min_conf_thr=3, + base_scale=0.5, + allow_pw_adaptors=False, + pw_break=20, + rand_pose=torch.randn, + iterationsCount=None, + verbose=True): + super().__init__() + if not isinstance(view1['idx'], list): + view1['idx'] = view1['idx'].tolist() + if not isinstance(view2['idx'], list): + view2['idx'] = view2['idx'].tolist() + self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])] + self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges} + self.dist = ALL_DISTS[dist] + self.verbose = verbose + + self.n_imgs = self._check_edges() + + # input data + pred1_pts = pred1['pts3d'] + pred2_pts = pred2['pts3d_in_other_view'] + self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)}) + self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)}) + self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts) + + # work in log-scale with conf + pred1_conf = pred1['conf'] + pred2_conf = pred2['conf'] + self.min_conf_thr = min_conf_thr + self.conf_trf = get_conf_trf(conf) + + self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)}) + self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)}) + self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf) + for i in range(len(self.im_conf)): + self.im_conf[i].requires_grad = False + + # pairwise pose parameters + self.base_scale = base_scale + self.norm_pw_scale = True + self.pw_break = pw_break + self.POSE_DIM = 7 + self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM))) # pairwise poses + self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2))) # slight xy/z adaptation + self.pw_adaptors.requires_grad_(allow_pw_adaptors) + self.has_im_poses = False + self.rand_pose = rand_pose + + # possibly store images for show_pointcloud + self.imgs = None + if 'img' in view1 and 'img' in view2: + imgs = [torch.zeros((3,)+hw) for hw in self.imshapes] + for v in range(len(self.edges)): + idx = view1['idx'][v] + imgs[idx] = view1['img'][v] + idx = view2['idx'][v] + imgs[idx] = view2['img'][v] + self.imgs = rgb(imgs) + + @property + def n_edges(self): + return len(self.edges) + + @property + def str_edges(self): + return [edge_str(i, j) for i, j in self.edges] + + @property + def imsizes(self): + return [(w, h) for h, w in self.imshapes] + + @property + def device(self): + return next(iter(self.parameters())).device + + def state_dict(self, trainable=True): + all_params = super().state_dict() + return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable} + + def load_state_dict(self, data): + return super().load_state_dict(self.state_dict(trainable=False) | data) + + def _check_edges(self): + indices = sorted({i for edge in self.edges for i in edge}) + assert indices == list(range(len(indices))), 'bad pair indices: missing values ' 
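+ # the edges must reference a contiguous index range 0..n-1, so the number of
+ # distinct indices is exactly the number of images in the scene graph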
+ return len(indices) + + @torch.no_grad() + def _compute_img_conf(self, pred1_conf, pred2_conf): + im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes]) + for e, (i, j) in enumerate(self.edges): + im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e]) + im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e]) + return im_conf + + def get_adaptors(self): + adapt = self.pw_adaptors + adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1) # (scale_xy, scale_xy, scale_z) + if self.norm_pw_scale: # normalize so that the product == 1 + adapt = adapt - adapt.mean(dim=1, keepdim=True) + return (adapt / self.pw_break).exp() + + def _get_poses(self, poses): + # normalize rotation + Q = poses[:, :4] + T = signed_expm1(poses[:, 4:7]) + RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous() + return RT + + def _set_pose(self, poses, idx, R, T=None, scale=None, force=False): + # all poses == cam-to-world + pose = poses[idx] + if not (pose.requires_grad or force): + return pose + + if R.shape == (4, 4): + assert T is None + T = R[:3, 3] + R = R[:3, :3] + + if R is not None: + pose.data[0:4] = roma.rotmat_to_unitquat(R) + if T is not None: + pose.data[4:7] = signed_log1p(T / (scale or 1)) # translation is function of scale + + if scale is not None: + assert poses.shape[-1] in (8, 13) + pose.data[-1] = np.log(float(scale)) + return pose + + def get_pw_norm_scale_factor(self): + if self.norm_pw_scale: + # normalize scales so that things cannot go south + # we want that exp(scale) ~= self.base_scale + return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp() + else: + return 1 # don't norm scale for known poses + + def get_pw_scale(self): + scale = self.pw_poses[:, -1].exp() # (n_edges,) + scale = scale * self.get_pw_norm_scale_factor() + return scale + + def get_pw_poses(self): # cam to world + RT = self._get_poses(self.pw_poses) + scaled_RT = RT.clone() + scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1) # scale the rotation AND translation + return scaled_RT + + def get_masks(self): + return [(conf > self.min_conf_thr) for conf in self.im_conf] + + def depth_to_pts3d(self): + raise NotImplementedError() + + def get_pts3d(self, raw=False): + res = self.depth_to_pts3d() + if not raw: + res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def _set_focal(self, idx, focal, force=False): + raise NotImplementedError() + + def get_focals(self): + raise NotImplementedError() + + def get_known_focal_mask(self): + raise NotImplementedError() + + def get_principal_points(self): + raise NotImplementedError() + + def get_conf(self, mode=None): + trf = self.conf_trf if mode is None else get_conf_trf(mode) + return [trf(c) for c in self.im_conf] + + def get_im_poses(self): + raise NotImplementedError() + + def _set_depthmap(self, idx, depth, force=False): + raise NotImplementedError() + + def get_depthmaps(self, raw=False): + raise NotImplementedError() + + def clean_pointcloud(self, **kw): + cams = inv(self.get_im_poses()) + K = self.get_intrinsics() + depthmaps = self.get_depthmaps() + all_pts3d = self.get_pts3d() + + new_im_confs = clean_pointcloud(self.im_conf, K, cams, depthmaps, all_pts3d, **kw) + + for i, new_conf in enumerate(new_im_confs): + self.im_conf[i].data[:] = new_conf + return self + + def forward(self, ret_details=False): + pw_poses = self.get_pw_poses() # cam-to-world + pw_adapt = self.get_adaptors() + proj_pts3d = self.get_pts3d() + # pre-compute pixel weights + weight_i = {i_j: self.conf_trf(c) for i_j, c in 
self.conf_i.items()} + weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()} + + loss = 0 + if ret_details: + details = -torch.ones((self.n_imgs, self.n_imgs)) + + for e, (i, j) in enumerate(self.edges): + i_j = edge_str(i, j) + # distance in image i and j + aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j]) + aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j]) + li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean() + lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean() + loss = loss + li + lj + + if ret_details: + details[i, j] = li + lj + loss /= self.n_edges # average over all pairs + + if ret_details: + return loss, details + return loss + + @torch.cuda.amp.autocast(enabled=False) + def compute_global_alignment(self, init=None, niter_PnP=10, **kw): + if init is None: + pass + elif init == 'msp' or init == 'mst': + init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP) + elif init == 'known_poses': + init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr, + niter_PnP=niter_PnP) + else: + raise ValueError(f'bad value for {init=}') + + return global_alignment_loop(self, **kw) + + @torch.no_grad() + def mask_sky(self): + res = deepcopy(self) + for i in range(self.n_imgs): + sky = segment_sky(self.imgs[i]) + res.im_conf[i][sky] = 0 + return res + + def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw): + viz = SceneViz() + if self.imgs is None: + colors = np.random.randint(0, 256, size=(self.n_imgs, 3)) + colors = list(map(tuple, colors.tolist())) + for n in range(self.n_imgs): + viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n]) + else: + viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks()) + colors = np.random.randint(256, size=(self.n_imgs, 3)) + + # camera poses + im_poses = to_numpy(self.get_im_poses()) + if cam_size is None: + cam_size = auto_cam_size(im_poses) + viz.add_cameras(im_poses, self.get_focals(), colors=colors, + images=self.imgs, imsizes=self.imsizes, cam_size=cam_size) + if show_pw_cams: + pw_poses = self.get_pw_poses() + viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size) + + if show_pw_pts3d: + pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)] + viz.add_pointcloud(pts, (128, 0, 128)) + + viz.show(**kw) + return viz + + +def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6): + params = [p for p in net.parameters() if p.requires_grad] + if not params: + return net + + verbose = net.verbose + if verbose: + print('Global alignement - optimizing for:') + print([name for name, value in net.named_parameters() if value.requires_grad]) + + lr_base = lr + optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9)) + + loss = float('inf') + if verbose: + with tqdm.tqdm(total=niter) as bar: + while bar.n < bar.total: + loss, lr = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule) + bar.set_postfix_str(f'{lr=:g} loss={loss:g}') + bar.update() + else: + for n in range(niter): + loss, _ = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule) + return loss + + +def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule): + t = cur_iter / niter + if schedule == 'cosine': + lr = cosine_schedule(t, lr_base, lr_min) + elif schedule == 'linear': + lr = linear_schedule(t, lr_base, lr_min) + else: + raise ValueError(f'bad lr {schedule=}') + 
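+ # with t = cur_iter / niter in [0, 1], both schedules decay the learning rate
+ # from lr_base down to lr_min over the run; the scheduled lr is applied to the
+ # optimizer, followed by one standard forward/backward/step update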
adjust_learning_rate_by_lr(optimizer, lr) + optimizer.zero_grad() + loss = net() + loss.backward() + optimizer.step() + + return float(loss), lr + + +@torch.no_grad() +def clean_pointcloud( im_confs, K, cams, depthmaps, all_pts3d, + tol=0.001, bad_conf=0, dbg=()): + """ Method: + 1) express all 3d points in each camera coordinate frame + 2) if they're in front of a depthmap --> then lower their confidence + """ + assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d) + assert 0 <= tol < 1 + res = [c.clone() for c in im_confs] + + # reshape appropriately + all_pts3d = [p.view(*c.shape,3) for p,c in zip(all_pts3d, im_confs)] + depthmaps = [d.view(*c.shape) for d,c in zip(depthmaps, im_confs)] + + for i, pts3d in enumerate(all_pts3d): + for j in range(len(all_pts3d)): + if i == j: continue + + # project 3dpts in other view + proj = geotrf(cams[j], pts3d) + proj_depth = proj[:,:,2] + u,v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1) + + # check which points are actually in the visible cone + H, W = im_confs[j].shape + msk_i = (proj_depth > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H) + msk_j = v[msk_i], u[msk_i] + + # find bad points = those in front but less confident + bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]) & (res[i][msk_i] < res[j][msk_j]) + + bad_msk_i = msk_i.clone() + bad_msk_i[msk_i] = bad_points + res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf) + + return res diff --git a/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/commons.py b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..3be9f855a69ea18c82dcc8e5769e0149a59649bd --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/commons.py @@ -0,0 +1,90 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# utility functions for global alignment +# -------------------------------------------------------- +import torch +import torch.nn as nn +import numpy as np + + +def edge_str(i, j): + return f'{i}_{j}' + + +def i_j_ij(ij): + return edge_str(*ij), ij + + +def edge_conf(conf_i, conf_j, edge): + return float(conf_i[edge].mean() * conf_j[edge].mean()) + + +def compute_edge_scores(edges, conf_i, conf_j): + return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges} + + +def NoGradParamDict(x): + assert isinstance(x, dict) + return nn.ParameterDict(x).requires_grad_(False) + + +def get_imshapes(edges, pred_i, pred_j): + n_imgs = max(max(e) for e in edges) + 1 + imshapes = [None] * n_imgs + for e, (i, j) in enumerate(edges): + shape_i = tuple(pred_i[e].shape[0:2]) + shape_j = tuple(pred_j[e].shape[0:2]) + if imshapes[i]: + assert imshapes[i] == shape_i, f'incorrect shape for image {i}' + if imshapes[j]: + assert imshapes[j] == shape_j, f'incorrect shape for image {j}' + imshapes[i] = shape_i + imshapes[j] = shape_j + return imshapes + + +def get_conf_trf(mode): + if mode == 'log': + def conf_trf(x): return x.log() + elif mode == 'sqrt': + def conf_trf(x): return x.sqrt() + elif mode == 'm1': + def conf_trf(x): return x-1 + elif mode in ('id', 'none'): + def conf_trf(x): return x + else: + raise ValueError(f'bad mode for {mode=}') + return conf_trf + + +def l2_dist(a, b, weight): + return ((a - b).square().sum(dim=-1) * weight) + + +def l1_dist(a, b, weight): + return ((a - b).norm(dim=-1) * weight) + + +ALL_DISTS = dict(l1=l1_dist, l2=l2_dist) + + +def signed_log1p(x): + sign = torch.sign(x) + return sign * torch.log1p(torch.abs(x)) + + +def signed_expm1(x): + sign = torch.sign(x) + return sign * torch.expm1(torch.abs(x)) + + +def cosine_schedule(t, lr_start, lr_end): + assert 0 <= t <= 1 + return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2 + + +def linear_schedule(t, lr_start, lr_end): + assert 0 <= t <= 1 + return lr_start + (lr_end - lr_start) * t diff --git a/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/init_im_poses.py b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/init_im_poses.py new file mode 100644 index 0000000000000000000000000000000000000000..7887c5cde27115273601e704b81ca0b0301f3715 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/init_im_poses.py @@ -0,0 +1,316 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# Initialization functions for global alignment +# -------------------------------------------------------- +from functools import cache + +import numpy as np +import scipy.sparse as sp +import torch +import cv2 +import roma +from tqdm import tqdm + +from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses +from dust3r.post_process import estimate_focal_knowing_depth +from dust3r.viz import to_numpy + +from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores + + +@torch.no_grad() +def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3): + device = self.device + + # indices of known poses + nkp, known_poses_msk, known_poses = get_known_poses(self) + assert nkp == self.n_imgs, 'not all poses are known' + + # get all focals + nkf, _, im_focals = get_known_focals(self) + assert nkf == self.n_imgs + im_pp = self.get_principal_points() + + best_depthmaps = {} + # init all pairwise poses + for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)): + i_j = edge_str(i, j) + + # find relative pose for this pair + P1 = torch.eye(4, device=device) + msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1) + _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()), + pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP) + + # align the two predicted camera with the two gt cameras + s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]]) + # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1 + # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3]) + self._set_pose(self.pw_poses, e, R, T, scale=s) + + # remember if this is a good depthmap + score = float(self.conf_i[i_j].mean()) + if score > best_depthmaps.get(i, (0,))[0]: + best_depthmaps[i] = score, i_j, s + + # init all image poses + for n in range(self.n_imgs): + assert known_poses_msk[n] + _, i_j, scale = best_depthmaps[n] + depth = self.pred_i[i_j][:, :, 2] + self._set_depthmap(n, depth * scale) + + +@torch.no_grad() +def init_minimum_spanning_tree(self, **kw): + """ Init all camera poses (image-wise and pairwise poses) given + an initial set of pairwise estimations. 
+ """ + device = self.device + pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges, + self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr, + device, has_im_poses=self.has_im_poses, verbose=self.verbose, + **kw) + + return init_from_pts3d(self, pts3d, im_focals, im_poses) + + +def init_from_pts3d(self, pts3d, im_focals, im_poses): + # init poses + nkp, known_poses_msk, known_poses = get_known_poses(self) + if nkp == 1: + raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose") + elif nkp > 1: + # global rigid SE3 alignment + s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk]) + trf = sRT_to_4x4(s, R, T, device=known_poses.device) + + # rotate everything + im_poses = trf @ im_poses + im_poses[:, :3, :3] /= s # undo scaling on the rotation part + for img_pts3d in pts3d: + img_pts3d[:] = geotrf(trf, img_pts3d) + + # set all pairwise poses + for e, (i, j) in enumerate(self.edges): + i_j = edge_str(i, j) + # compute transform that goes from cam to world + s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]) + self._set_pose(self.pw_poses, e, R, T, scale=s) + + # take into account the scale normalization + s_factor = self.get_pw_norm_scale_factor() + im_poses[:, :3, 3] *= s_factor # apply downscaling factor + for img_pts3d in pts3d: + img_pts3d *= s_factor + + # init all image poses + if self.has_im_poses: + for i in range(self.n_imgs): + cam2world = im_poses[i] + depth = geotrf(inv(cam2world), pts3d[i])[..., 2] + self._set_depthmap(i, depth) + self._set_pose(self.im_poses, i, cam2world) + if im_focals[i] is not None: + self._set_focal(i, im_focals[i]) + + if self.verbose: + print(' init loss =', float(self())) + + +def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr, + device, has_im_poses=True, niter_PnP=10, verbose=True): + n_imgs = len(imshapes) + sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j)) + msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo() + + # temp variable to store 3d points + pts3d = [None] * len(imshapes) + + todo = sorted(zip(-msp.data, msp.row, msp.col)) # sorted edges + im_poses = [None] * n_imgs + im_focals = [None] * n_imgs + + # init with strongest edge + score, i, j = todo.pop() + if verbose: + print(f' init edge ({i}*,{j}*) {score=}') + i_j = edge_str(i, j) + pts3d[i] = pred_i[i_j].clone() + pts3d[j] = pred_j[i_j].clone() + done = {i, j} + if has_im_poses: + im_poses[i] = torch.eye(4, device=device) + im_focals[i] = estimate_focal(pred_i[i_j]) + + # set initial pointcloud based on pairwise graph + msp_edges = [(i, j)] + while todo: + # each time, predict the next one + score, i, j = todo.pop() + + if im_focals[i] is None: + im_focals[i] = estimate_focal(pred_i[i_j]) + + if i in done: + if verbose: + print(f' init edge ({i},{j}*) {score=}') + assert j not in done + # align pred[i] with pts3d[i], and then set j accordingly + i_j = edge_str(i, j) + s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j]) + trf = sRT_to_4x4(s, R, T, device) + pts3d[j] = geotrf(trf, pred_j[i_j]) + done.add(j) + msp_edges.append((i, j)) + + if has_im_poses and im_poses[i] is None: + im_poses[i] = sRT_to_4x4(1, R, T, device) + + elif j in done: + if verbose: + print(f' init edge ({i}*,{j}) {score=}') + assert i not in done + i_j = edge_str(i, j) + s, R, T = rigid_points_registration(pred_j[i_j], 
pts3d[j], conf=conf_j[i_j]) + trf = sRT_to_4x4(s, R, T, device) + pts3d[i] = geotrf(trf, pred_i[i_j]) + done.add(i) + msp_edges.append((i, j)) + + if has_im_poses and im_poses[i] is None: + im_poses[i] = sRT_to_4x4(1, R, T, device) + else: + # let's try again later + todo.insert(0, (score, i, j)) + + if has_im_poses: + # complete all missing informations + pair_scores = list(sparse_graph.values()) # already negative scores: less is best + edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)] + for i, j in edges_from_best_to_worse.tolist(): + if im_focals[i] is None: + im_focals[i] = estimate_focal(pred_i[edge_str(i, j)]) + + for i in range(n_imgs): + if im_poses[i] is None: + msk = im_conf[i] > min_conf_thr + res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP) + if res: + im_focals[i], im_poses[i] = res + if im_poses[i] is None: + im_poses[i] = torch.eye(4, device=device) + im_poses = torch.stack(im_poses) + else: + im_poses = im_focals = None + + return pts3d, msp_edges, im_focals, im_poses + + +def dict_to_sparse_graph(dic): + n_imgs = max(max(e) for e in dic) + 1 + res = sp.dok_array((n_imgs, n_imgs)) + for edge, value in dic.items(): + res[edge] = value + return res + + +def rigid_points_registration(pts1, pts2, conf): + R, T, s = roma.rigid_points_registration( + pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True) + return s, R, T # return un-scaled (R, T) + + +def sRT_to_4x4(scale, R, T, device): + trf = torch.eye(4, device=device) + trf[:3, :3] = R * scale + trf[:3, 3] = T.ravel() # doesn't need scaling + return trf + + +def estimate_focal(pts3d_i, pp=None): + if pp is None: + H, W, THREE = pts3d_i.shape + assert THREE == 3 + pp = torch.tensor((W/2, H/2), device=pts3d_i.device) + focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel() + return float(focal) + + +@cache +def pixel_grid(H, W): + return np.mgrid[:W, :H].T.astype(np.float32) + + +def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10): + # extract camera poses and focals with RANSAC-PnP + if msk.sum() < 4: + return None # we need at least 4 points for PnP + pts3d, msk = map(to_numpy, (pts3d, msk)) + + H, W, THREE = pts3d.shape + assert THREE == 3 + pixels = pixel_grid(H, W) + + if focal is None: + S = max(W, H) + tentative_focals = np.geomspace(S/2, S*3, 21) + else: + tentative_focals = [focal] + + if pp is None: + pp = (W/2, H/2) + else: + pp = to_numpy(pp) + + best = 0, + for focal in tentative_focals: + K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) + + success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None, + iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP) + if not success: + continue + + score = len(inliers) + if success and score > best[0]: + best = score, R, T, focal + + if not best[0]: + return None + + _, R, T, best_focal = best + R = cv2.Rodrigues(R)[0] # world to cam + R, T = map(torch.from_numpy, (R, T)) + return best_focal, inv(sRT_to_4x4(1, R, T, device)) # cam to world + + +def get_known_poses(self): + if self.has_im_poses: + known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses]) + known_poses = self.get_im_poses() + return known_poses_msk.sum(), known_poses_msk, known_poses + else: + return 0, None, None + + +def get_known_focals(self): + if self.has_im_poses: + known_focal_msk = self.get_known_focal_mask() + known_focals = self.get_focals() + return 
known_focal_msk.sum(), known_focal_msk, known_focals + else: + return 0, None, None + + +def align_multiple_poses(src_poses, target_poses): + N = len(src_poses) + assert src_poses.shape == target_poses.shape == (N, 4, 4) + + def center_and_z(poses): + eps = get_med_dist_between_poses(poses) / 100 + return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2])) + R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True) + return s, R, T diff --git a/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/modular_optimizer.py b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/modular_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d06464b40276684385c18b9195be1491c6f47f07 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/modular_optimizer.py @@ -0,0 +1,145 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Slower implementation of the global alignment that allows to freeze partial poses/intrinsics +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import geotrf +from dust3r.utils.device import to_cpu, to_numpy +from dust3r.utils.geometry import depthmap_to_pts3d + + +class ModularPointCloudOptimizer (BasePCOptimizer): + """ Optimize a global scene, given a list of pairwise observations. + Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics) + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs): + super().__init__(*args, **kwargs) + self.has_im_poses = True # by definition of this class + self.focal_brake = focal_brake + + # adding thing to optimize + self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth) + self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses + default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes] + self.im_focals = nn.ParameterList(torch.FloatTensor([f, f] if fx_and_fy else [ + f]) for f in default_focals) # camera intrinsics + self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics + self.im_pp.requires_grad_(optimize_pp) + + def preset_pose(self, known_poses, pose_msk=None): # cam-to-world + if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2: + known_poses = [known_poses] + for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses): + if self.verbose: + print(f' (setting pose #{idx} = {pose[:3,3]})') + self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True)) + + # normalize scale if there's less than 1 known pose + n_known_poses = sum((p.requires_grad is False) for p in self.im_poses) + self.norm_pw_scale = (n_known_poses <= 1) + + def preset_intrinsics(self, known_intrinsics, msk=None): + if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2: + known_intrinsics = [known_intrinsics] + for K in known_intrinsics: + assert K.shape == (3, 3) + self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk) + self.preset_principal_point([K[:2, 2] for K in 
known_intrinsics], msk) + + def preset_focal(self, known_focals, msk=None): + for idx, focal in zip(self._get_msk_indices(msk), known_focals): + if self.verbose: + print(f' (setting focal #{idx} = {focal})') + self._no_grad(self._set_focal(idx, focal, force=True)) + + def preset_principal_point(self, known_pp, msk=None): + for idx, pp in zip(self._get_msk_indices(msk), known_pp): + if self.verbose: + print(f' (setting principal point #{idx} = {pp})') + self._no_grad(self._set_principal_point(idx, pp, force=True)) + + def _no_grad(self, tensor): + return tensor.requires_grad_(False) + + def _get_msk_indices(self, msk): + if msk is None: + return range(self.n_imgs) + elif isinstance(msk, int): + return [msk] + elif isinstance(msk, (tuple, list)): + return self._get_msk_indices(np.array(msk)) + elif msk.dtype in (bool, torch.bool, np.bool_): + assert len(msk) == self.n_imgs + return np.where(msk)[0] + elif np.issubdtype(msk.dtype, np.integer): + return msk + else: + raise ValueError(f'bad {msk=}') + + def _set_focal(self, idx, focal, force=False): + param = self.im_focals[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = self.focal_brake * np.log(focal) + return param + + def get_focals(self): + log_focals = torch.stack(list(self.im_focals), dim=0) + return (log_focals / self.focal_brake).exp() + + def _set_principal_point(self, idx, pp, force=False): + param = self.im_pp[idx] + H, W = self.imshapes[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10 + return param + + def get_principal_points(self): + return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)]) + + def get_intrinsics(self): + K = torch.zeros((self.n_imgs, 3, 3), device=self.device) + focals = self.get_focals().view(self.n_imgs, -1) + K[:, 0, 0] = focals[:, 0] + K[:, 1, 1] = focals[:, -1] + K[:, :2, 2] = self.get_principal_points() + K[:, 2, 2] = 1 + return K + + def get_im_poses(self): # cam to world + cam2world = self._get_poses(torch.stack(list(self.im_poses))) + return cam2world + + def _set_depthmap(self, idx, depth, force=False): + param = self.im_depthmaps[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def get_depthmaps(self): + return [d.exp() for d in self.im_depthmaps] + + def depth_to_pts3d(self): + # Get depths and projection params if not provided + focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps() + + # convert focal to (1,2,H,W) constant field + def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i]) + # get pointmaps in camera frame + rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])] + # project to world frame + return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)] + + def get_pts3d(self): + return self.depth_to_pts3d() diff --git a/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/optimizer.py b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..42e48613e55faa4ede5a366d1c0bfc4d18ffae4f --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/optimizer.py @@ -0,0 +1,248 @@ +# Copyright (C) 2024-present Naver Corporation. 
All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Main class for the implementation of the global alignment +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import xy_grid, geotrf +from dust3r.utils.device import to_cpu, to_numpy + + +class PointCloudOptimizer(BasePCOptimizer): + """ Optimize a global scene, given a list of pairwise observations. + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs): + super().__init__(*args, **kwargs) + + self.has_im_poses = True # by definition of this class + self.focal_break = focal_break + + # adding thing to optimize + self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth) + self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses + self.im_focals = nn.ParameterList(torch.FloatTensor( + [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes) # camera intrinsics + self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics + self.im_pp.requires_grad_(optimize_pp) + + self.imshape = self.imshapes[0] + im_areas = [h*w for h, w in self.imshapes] + self.max_area = max(im_areas) + + # adding thing to optimize + self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area) + self.im_poses = ParameterStack(self.im_poses, is_param=True) + self.im_focals = ParameterStack(self.im_focals, is_param=True) + self.im_pp = ParameterStack(self.im_pp, is_param=True) + self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes])) + self.register_buffer('_grid', ParameterStack( + [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area)) + + # pre-compute pixel weights + self.register_buffer('_weight_i', ParameterStack( + [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area)) + self.register_buffer('_weight_j', ParameterStack( + [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area)) + + # precompute aa + self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area)) + self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area)) + self.register_buffer('_ei', torch.tensor([i for i, j in self.edges])) + self.register_buffer('_ej', torch.tensor([j for i, j in self.edges])) + self.total_area_i = sum([im_areas[i] for i, j in self.edges]) + self.total_area_j = sum([im_areas[j] for i, j in self.edges]) + + def _check_all_imgs_are_selected(self, msk): + assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!' 
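The constructor above stacks per-image depthmaps, poses and intrinsics into single tensors even though the images can have different resolutions: every per-image tensor is raveled over H,W and zero-padded to the largest image area before stacking (see `ParameterStack`/`_ravel_hw` further down in this file). A minimal, self-contained sketch of that padding idea, with made-up shapes rather than anything taken from the diff:

```python
import torch

def ravel_and_pad(t, fill):
    # flatten H,W into a single axis, then zero-pad to a common length `fill`
    flat = t.view((t.shape[0] * t.shape[1],) + t.shape[2:])
    if len(flat) < fill:
        flat = torch.cat((flat, flat.new_zeros((fill - len(flat),) + flat.shape[1:])))
    return flat

imshapes = [(96, 128), (128, 128)]           # two images with different sizes (illustrative)
max_area = max(h * w for h, w in imshapes)   # 16384
depthmaps = [torch.rand(h, w) for h, w in imshapes]
stacked = torch.stack([ravel_and_pad(d, max_area) for d in depthmaps])
print(stacked.shape)                         # torch.Size([2, 16384])
```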
+ + def preset_pose(self, known_poses, pose_msk=None): # cam-to-world + self._check_all_imgs_are_selected(pose_msk) + + if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2: + known_poses = [known_poses] + for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses): + if self.verbose: + print(f' (setting pose #{idx} = {pose[:3,3]})') + self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose))) + + # normalize scale if there's less than 1 known pose + n_known_poses = sum((p.requires_grad is False) for p in self.im_poses) + self.norm_pw_scale = (n_known_poses <= 1) + + self.im_poses.requires_grad_(False) + self.norm_pw_scale = False + + def preset_focal(self, known_focals, msk=None): + self._check_all_imgs_are_selected(msk) + + for idx, focal in zip(self._get_msk_indices(msk), known_focals): + if self.verbose: + print(f' (setting focal #{idx} = {focal})') + self._no_grad(self._set_focal(idx, focal)) + + self.im_focals.requires_grad_(False) + + def preset_principal_point(self, known_pp, msk=None): + self._check_all_imgs_are_selected(msk) + + for idx, pp in zip(self._get_msk_indices(msk), known_pp): + if self.verbose: + print(f' (setting principal point #{idx} = {pp})') + self._no_grad(self._set_principal_point(idx, pp)) + + self.im_pp.requires_grad_(False) + + def _get_msk_indices(self, msk): + if msk is None: + return range(self.n_imgs) + elif isinstance(msk, int): + return [msk] + elif isinstance(msk, (tuple, list)): + return self._get_msk_indices(np.array(msk)) + elif msk.dtype in (bool, torch.bool, np.bool_): + assert len(msk) == self.n_imgs + return np.where(msk)[0] + elif np.issubdtype(msk.dtype, np.integer): + return msk + else: + raise ValueError(f'bad {msk=}') + + def _no_grad(self, tensor): + assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs' + + def _set_focal(self, idx, focal, force=False): + param = self.im_focals[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = self.focal_break * np.log(focal) + return param + + def get_focals(self): + log_focals = torch.stack(list(self.im_focals), dim=0) + return (log_focals / self.focal_break).exp() + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.im_focals]) + + def _set_principal_point(self, idx, pp, force=False): + param = self.im_pp[idx] + H, W = self.imshapes[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10 + return param + + def get_principal_points(self): + return self._pp + 10 * self.im_pp + + def get_intrinsics(self): + K = torch.zeros((self.n_imgs, 3, 3), device=self.device) + focals = self.get_focals().flatten() + K[:, 0, 0] = K[:, 1, 1] = focals + K[:, :2, 2] = self.get_principal_points() + K[:, 2, 2] = 1 + return K + + def get_im_poses(self): # cam to world + cam2world = self._get_poses(self.im_poses) + return cam2world + + def _set_depthmap(self, idx, depth, force=False): + depth = _ravel_hw(depth, self.max_area) + + param = self.im_depthmaps[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def get_depthmaps(self, raw=False): + res = self.im_depthmaps.exp() + if not raw: + res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def depth_to_pts3d(self): + # Get depths and projection params if not provided + 
focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps(raw=True) + + # get pointmaps in camera frame + rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp) + # project to world frame + return geotrf(im_poses, rel_ptmaps) + + def get_pts3d(self, raw=False): + res = self.depth_to_pts3d() + if not raw: + res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def forward(self): + pw_poses = self.get_pw_poses() # cam-to-world + pw_adapt = self.get_adaptors().unsqueeze(1) + proj_pts3d = self.get_pts3d(raw=True) + + # rotate pairwise prediction according to pw_poses + aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i) + aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j) + + # compute the less + li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i + lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j + + return li + lj + + +def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp): + pp = pp.unsqueeze(1) + focal = focal.unsqueeze(1) + assert focal.shape == (len(depth), 1, 1) + assert pp.shape == (len(depth), 1, 2) + assert pixel_grid.shape == depth.shape + (2,) + depth = depth.unsqueeze(-1) + return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1) + + +def ParameterStack(params, keys=None, is_param=None, fill=0): + if keys is not None: + params = [params[k] for k in keys] + + if fill > 0: + params = [_ravel_hw(p, fill) for p in params] + + requires_grad = params[0].requires_grad + assert all(p.requires_grad == requires_grad for p in params) + + params = torch.stack(list(params)).float().detach() + if is_param or requires_grad: + params = nn.Parameter(params) + params.requires_grad_(requires_grad) + return params + + +def _ravel_hw(tensor, fill=0): + # ravel H,W + tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]) + + if len(tensor) < fill: + tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:]))) + return tensor + + +def acceptable_focal_range(H, W, minf=0.5, maxf=3.5): + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515 + return minf*focal_base, maxf*focal_base + + +def apply_mask(img, msk): + img = img.copy() + img[msk] = 0 + return img diff --git a/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/pair_viewer.py b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/pair_viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..62ae3b9a5fbca8b96711de051d9d6597830bd488 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/cloud_opt/pair_viewer.py @@ -0,0 +1,127 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dummy optimizer for visualizing pairs +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn +import cv2 + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates +from dust3r.cloud_opt.commons import edge_str +from dust3r.post_process import estimate_focal_knowing_depth + + +class PairViewer (BasePCOptimizer): + """ + This a Dummy Optimizer. 
+ To use only when the goal is to visualize the results for a pair of images (with is_symmetrized) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.is_symmetrized and self.n_edges == 2 + self.has_im_poses = True + + # compute all parameters directly from raw input + self.focals = [] + self.pp = [] + rel_poses = [] + confs = [] + for i in range(self.n_imgs): + conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean()) + if self.verbose: + print(f' - {conf=:.3} for edge {i}-{1-i}') + confs.append(conf) + + H, W = self.imshapes[i] + pts3d = self.pred_i[edge_str(i, 1-i)] + pp = torch.tensor((W/2, H/2)) + focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld')) + self.focals.append(focal) + self.pp.append(pp) + + # estimate the pose of pts1 in image 2 + pixels = np.mgrid[:W, :H].T.astype(np.float32) + pts3d = self.pred_j[edge_str(1-i, i)].numpy() + assert pts3d.shape[:2] == (H, W) + msk = self.get_masks()[i].numpy() + K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) + + try: + res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None, + iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP) + success, R, T, inliers = res + assert success + + R = cv2.Rodrigues(R)[0] # world to cam + pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world + except: + pose = np.eye(4) + rel_poses.append(torch.from_numpy(pose.astype(np.float32))) + + # let's use the pair with the most confidence + if confs[0] > confs[1]: + # ptcloud is expressed in camera1 + self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1 + self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]] + else: + # ptcloud is expressed in camera2 + self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2 + self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]] + + self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False) + self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False) + self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False) + self.depth = nn.ParameterList(self.depth) + for p in self.parameters(): + p.requires_grad = False + + def _set_depthmap(self, idx, depth, force=False): + if self.verbose: + print('_set_depthmap is ignored in PairViewer') + return + + def get_depthmaps(self, raw=False): + depth = [d.to(self.device) for d in self.depth] + return depth + + def _set_focal(self, idx, focal, force=False): + self.focals[idx] = focal + + def get_focals(self): + return self.focals + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.focals]) + + def get_principal_points(self): + return self.pp + + def get_intrinsics(self): + focals = self.get_focals() + pps = self.get_principal_points() + K = torch.zeros((len(focals), 3, 3), device=self.device) + for i in range(len(focals)): + K[i, 0, 0] = K[i, 1, 1] = focals[i] + K[i, :2, 2] = pps[i] + K[i, 2, 2] = 1 + return K + + def get_im_poses(self): + return self.im_poses + + def depth_to_pts3d(self): + pts3d = [] + for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()): + pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(), + intrinsics.cpu().numpy(), + im_pose.cpu().numpy()) + pts3d.append(torch.from_numpy(pts).to(device=self.device)) + return pts3d + + def forward(self): + return float('nan') diff --git 
a/imcui/third_party/mast3r/dust3r/dust3r/datasets/__init__.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2123d09ec2840ab5ee9ca43057c35f93233bde89 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/__init__.py @@ -0,0 +1,50 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +from .utils.transforms import * +from .base.batched_sampler import BatchedRandomSampler # noqa +from .arkitscenes import ARKitScenes # noqa +from .blendedmvs import BlendedMVS # noqa +from .co3d import Co3d # noqa +from .habitat import Habitat # noqa +from .megadepth import MegaDepth # noqa +from .scannetpp import ScanNetpp # noqa +from .staticthings3d import StaticThings3D # noqa +from .waymo import Waymo # noqa +from .wildrgbd import WildRGBD # noqa + + +def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True): + import torch + from croco.utils.misc import get_world_size, get_rank + + # pytorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + world_size = get_world_size() + rank = get_rank() + + try: + sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, + rank=rank, drop_last=drop_last) + except (AttributeError, NotImplementedError): + # not avail for this dataset + if torch.distributed.is_initialized(): + sampler = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last + ) + elif shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=drop_last, + ) + + return data_loader diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/arkitscenes.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/arkitscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..4fad51acdc18b82cd6a4d227de0dac3b25783e33 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/arkitscenes.py @@ -0,0 +1,102 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
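`get_data_loader` in `dust3r/datasets/__init__.py` above accepts either a dataset instance or a string that is `eval`-uated against the dataset classes imported in that module. A hedged usage sketch; the dataset spec and `ROOT` path are placeholders copied from the demo blocks elsewhere in this diff, and it assumes the preprocessed data and the `croco` package are actually available:

```python
from dust3r.datasets import get_data_loader

loader = get_data_loader(
    "ARKitScenes(split='train', ROOT='data/arkitscenes_processed', resolution=224, aug_crop=16)",
    batch_size=4,
    num_workers=2,
    shuffle=True,
)
for view1, view2 in loader:   # each batch is a pair of view dicts with batched fields
    print(view1['img'].shape, view1['true_shape'])
    break
```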
+# +# -------------------------------------------------------- +# Dataloader for preprocessed arkitscenes +# dataset at https://github.com/apple/ARKitScenes - Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license +# See datasets_preprocess/preprocess_arkitscenes.py +# -------------------------------------------------------- +import os.path as osp +import cv2 +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class ARKitScenes(BaseStereoViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + if split == "train": + self.split = "Training" + elif split == "test": + self.split = "Test" + else: + raise ValueError("") + + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + with np.load(osp.join(self.ROOT, split, 'all_metadata.npz')) as data: + self.scenes = data['scenes'] + self.sceneids = data['sceneids'] + self.images = data['images'] + self.intrinsics = data['intrinsics'].astype(np.float32) + self.trajectories = data['trajectories'].astype(np.float32) + self.pairs = data['pairs'][:, :2].astype(int) + + def __len__(self): + return len(self.pairs) + + def _get_views(self, idx, resolution, rng): + + image_idx1, image_idx2 = self.pairs[idx] + + views = [] + for view_idx in [image_idx1, image_idx2]: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(scene_dir, 'vga_wide', basename.replace('.png', '.jpg'))) + # Load depthmap + depthmap = imread_cv2(osp.join(scene_dir, 'lowres_depth', basename), cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) + + views.append(dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset='arkitscenes', + label=self.scenes[scene_id] + '_' + basename, + instance=f'{str(idx)}_{str(view_idx)}', + )) + + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = ARKitScenes(split='train', ROOT="data/arkitscenes_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git 
a/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/__init__.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/base_stereo_view_dataset.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/base_stereo_view_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..17390ca29d4437fc41f3c946b235888af9e4c888 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/base_stereo_view_dataset.py @@ -0,0 +1,220 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# base class for implementing datasets +# -------------------------------------------------------- +import PIL +import numpy as np +import torch + +from dust3r.datasets.base.easy_dataset import EasyDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates +import dust3r.datasets.utils.cropping as cropping + + +class BaseStereoViewDataset (EasyDataset): + """ Define all basic options. + + Usage: + class MyDataset (BaseStereoViewDataset): + def _get_views(self, idx, rng): + # overload here + views = [] + views.append(dict(img=, ...)) + return views + """ + + def __init__(self, *, # only keyword arguments + split=None, + resolution=None, # square_size or (width, height) or list of [(width,height), ...] 
+ transform=ImgNorm, + aug_crop=False, + seed=None): + self.num_views = 2 + self.split = split + self._set_resolutions(resolution) + + if isinstance(transform, str): + transform = eval(transform) + self.transform = transform + + self.aug_crop = aug_crop + self.seed = seed + + def __len__(self): + return len(self.scenes) + + def get_stats(self): + return f"{len(self)} pairs" + + def __repr__(self): + resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']' + return f"""{type(self).__name__}({self.get_stats()}, + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '') + + def _get_views(self, idx, resolution, rng): + raise NotImplementedError() + + def __getitem__(self, idx): + if isinstance(idx, tuple): + # the idx is specifying the aspect-ratio + idx, ar_idx = idx + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + + # set-up the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, '_rng'): + seed = torch.initial_seed() # this is different for each dataloader process + self._rng = np.random.default_rng(seed=seed) + + # over-loaded code + resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + views = self._get_views(idx, resolution, self._rng) + assert len(views) == self.num_views + + # check data-types + for v, view in enumerate(views): + assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view['idx'] = (idx, ar_idx, v) + + # encode the image + width, height = view['img'].size + view['true_shape'] = np.int32((height, width)) + view['img'] = self.transform(view['img']) + + assert 'camera_intrinsics' in view + if 'camera_pose' not in view: + view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}' + assert 'pts3d' not in view + assert 'valid_mask' not in view + assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}' + pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + view['pts3d'] = pts3d + view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1) + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view['camera_intrinsics'] + + # last thing done!
+ for view in views: + # transpose to make sure all views are the same size + transpose_to_landscape(view) + # this allows to check whether the RNG is is the same state each time + view['rng'] = int.from_bytes(self._rng.bytes(4), 'big') + return views + + def _set_resolutions(self, resolutions): + assert resolutions is not None, 'undefined resolution' + + if not isinstance(resolutions, list): + resolutions = [resolutions] + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int' + assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int' + assert width >= height + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None): + """ This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # downscale with lanczos interpolation so that image.size == resolution + # cropping centered on the principal point + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W-cx) + min_margin_y = min(cy, H-cy) + assert min_margin_x > W/5, f'Bad principal point in view={info}' + assert min_margin_y > H/5, f'Bad principal point in view={info}' + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) + + # transpose the resolution if necessary + W, H = image.size # new size + assert resolution[0] >= resolution[1] + if H > 1.1*W: + # image is portrait mode + resolution = resolution[::-1] + elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]: + # image is square, so we chose (portrait, landscape) randomly + if rng.integers(2): + resolution = resolution[::-1] + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + if self.aug_crop > 1: + target_resolution += rng.integers(0, self.aug_crop) + image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution) + + # actual cropping (if necessary) with bilinear interpolation + intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5) + crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution) + image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) + + return image, depthmap, intrinsics2 + + +def is_good_type(key, v): + """ returns (is_good, err_msg) + """ + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x + db = sel(view['dataset']) + label = sel(view['label']) + instance = sel(view['instance']) + return f"{db}/{label}/{instance}" + + +def transpose_to_landscape(view): + height, width = view['true_shape'] + + if width < height: + # rectify portrait to landscape 
+ assert view['img'].shape == (3, height, width) + view['img'] = view['img'].swapaxes(1, 2) + + assert view['valid_mask'].shape == (height, width) + view['valid_mask'] = view['valid_mask'].swapaxes(0, 1) + + assert view['depthmap'].shape == (height, width) + view['depthmap'] = view['depthmap'].swapaxes(0, 1) + + assert view['pts3d'].shape == (height, width, 3) + view['pts3d'] = view['pts3d'].swapaxes(0, 1) + + # transpose x and y pixels + view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]] diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/batched_sampler.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/batched_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..85f58a65d41bb8101159e032d5b0aac26a7cf1a1 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/batched_sampler.py @@ -0,0 +1,74 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Random sampling under a constraint +# -------------------------------------------------------- +import numpy as np +import torch + + +class BatchedRandomSampler: + """ Random sampling under a constraint: each sample in the batch has the same feature, + which is chosen randomly from a known pool of 'features' for each batch. + + For instance, the 'feature' could be the image aspect-ratio. + + The index returned is a tuple (sample_idx, feat_idx). + This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. + """ + + def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True): + self.batch_size = batch_size + self.pool_size = pool_size + + self.len_dataset = N = len(dataset) + self.total_size = round_by(N, batch_size*world_size) if drop_last else N + assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode' + + # distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + + def __len__(self): + return self.total_size // self.world_size + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + # prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used' + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # random feat_idxs (same across each batch) + n_batches = (self.total_size+self.batch_size-1) // self.batch_size + feat_idxs = rng.integers(self.pool_size, size=n_batches) + feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) + feat_idxs = feat_idxs.ravel()[:self.total_size] + + # put them together + idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) + + # Distributed sampler: we select a subset of batches + # make sure the slice for each node is aligned with batch_size + size_per_proc = self.batch_size * ((self.total_size + self.world_size * + self.batch_size-1) // (self.world_size * self.batch_size)) + idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc] + + yield from (tuple(idx) for idx in idxs) + + +def round_by(total, multiple, up=False): + if up: + total = total + multiple-1 + return (total//multiple) * multiple diff --git 
a/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/easy_dataset.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/easy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4939a88f02715a1f80be943ddb6d808e1be84db7 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/base/easy_dataset.py @@ -0,0 +1,157 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# A dataset base class that you can easily resize and combine. +# -------------------------------------------------------- +import numpy as np +from dust3r.datasets.base.batched_sampler import BatchedRandomSampler + + +class EasyDataset: + """ a dataset that you can easily resize and combine. + Examples: + --------- + 2 * dataset ==> duplicate each element 2x + + 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) + + dataset1 + dataset2 ==> concatenate datasets + """ + + def __add__(self, other): + return CatDataset([self, other]) + + def __rmul__(self, factor): + return MulDataset(factor, self) + + def __rmatmul__(self, factor): + return ResizedDataset(factor, self) + + def set_epoch(self, epoch): + pass # nothing to do by default + + def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True): + if not (shuffle): + raise NotImplementedError() # cannot deal yet + num_of_aspect_ratios = len(self._resolutions) + return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last) + + +class MulDataset (EasyDataset): + """ Artifically augmenting the size of a dataset. + """ + multiplicator: int + + def __init__(self, multiplicator, dataset): + assert isinstance(multiplicator, int) and multiplicator > 0 + self.multiplicator = multiplicator + self.dataset = dataset + + def __len__(self): + return self.multiplicator * len(self.dataset) + + def __repr__(self): + return f'{self.multiplicator}*{repr(self.dataset)}' + + def __getitem__(self, idx): + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[idx // self.multiplicator, other] + else: + return self.dataset[idx // self.multiplicator] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class ResizedDataset (EasyDataset): + """ Artifically changing the size of a dataset. 
+ """ + new_size: int + + def __init__(self, new_size, dataset): + assert isinstance(new_size, int) and new_size > 0 + self.new_size = new_size + self.dataset = dataset + + def __len__(self): + return self.new_size + + def __repr__(self): + size_str = str(self.new_size) + for i in range((len(size_str)-1) // 3): + sep = -4*i-3 + size_str = size_str[:sep] + '_' + size_str[sep:] + return f'{size_str} @ {repr(self.dataset)}' + + def set_epoch(self, epoch): + # this random shuffle only depends on the epoch + rng = np.random.default_rng(seed=epoch+777) + + # shuffle all indices + perm = rng.permutation(len(self.dataset)) + + # rotary extension until target size is met + shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset))) + self._idxs_mapping = shuffled_idxs[:self.new_size] + + assert len(self._idxs_mapping) == self.new_size + + def __getitem__(self, idx): + assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()' + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[self._idxs_mapping[idx], other] + else: + return self.dataset[self._idxs_mapping[idx]] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class CatDataset (EasyDataset): + """ Concatenation of several datasets + """ + + def __init__(self, datasets): + for dataset in datasets: + assert isinstance(dataset, EasyDataset) + self.datasets = datasets + self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) + + def __len__(self): + return self._cum_sizes[-1] + + def __repr__(self): + # remove uselessly long transform + return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets) + + def set_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_epoch(epoch) + + def __getitem__(self, idx): + other = None + if isinstance(idx, tuple): + idx, other = idx + + if not (0 <= idx < len(self)): + raise IndexError() + + db_idx = np.searchsorted(self._cum_sizes, idx, 'right') + dataset = self.datasets[db_idx] + new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) + + if other is not None: + new_idx = (new_idx, other) + return dataset[new_idx] + + @property + def _resolutions(self): + resolutions = self.datasets[0]._resolutions + for dataset in self.datasets[1:]: + assert tuple(dataset._resolutions) == tuple(resolutions) + return resolutions diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/blendedmvs.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/blendedmvs.py new file mode 100644 index 0000000000000000000000000000000000000000..93e68c28620cc47a7b1743834e45f82d576126d0 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/blendedmvs.py @@ -0,0 +1,104 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# Dataloader for preprocessed BlendedMVS +# dataset at https://github.com/YoYo000/BlendedMVS +# See datasets_preprocess/preprocess_blendedmvs.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class BlendedMVS (BaseStereoViewDataset): + """ Dataset of outdoor street scenes, 5 images each time + """ + + def __init__(self, *args, ROOT, split=None, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self._load_data(split) + + def _load_data(self, split): + pairs = np.load(osp.join(self.ROOT, 'blendedmvs_pairs.npy')) + if split is None: + selection = slice(None) + if split == 'train': + # select 90% of all scenes + selection = (pairs['seq_low'] % 10) > 0 + if split == 'val': + # select 10% of all scenes + selection = (pairs['seq_low'] % 10) == 0 + self.pairs = pairs[selection] + + # list of all scenes + self.scenes = np.unique(self.pairs['seq_low']) # low is unique enough + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs from {len(self.scenes)} scenes' + + def _get_views(self, pair_idx, resolution, rng): + seqh, seql, img1, img2, score = self.pairs[pair_idx] + + seq = f"{seqh:08x}{seql:016x}" + seq_path = osp.join(self.ROOT, seq) + + views = [] + + for view_index in [img1, img2]: + impath = f"{view_index:08n}" + image = imread_cv2(osp.join(seq_path, impath + ".jpg")) + depthmap = imread_cv2(osp.join(seq_path, impath + ".exr")) + camera_params = np.load(osp.join(seq_path, impath + ".npz")) + + intrinsics = np.float32(camera_params['intrinsics']) + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = camera_params['R_cam2world'] + camera_pose[:3, 3] = camera_params['t_cam2world'] + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath)) + + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='BlendedMVS', + label=osp.relpath(seq_path, self.ROOT), + instance=impath)) + + return views + + +if __name__ == '__main__': + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = BlendedMVS(split='train', ROOT="data/blendedmvs_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/co3d.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/co3d.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea5c8555d34b776e7a48396dcd0eecece713e34 --- /dev/null +++ 
b/imcui/third_party/mast3r/dust3r/dust3r/datasets/co3d.py @@ -0,0 +1,165 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed Co3d_v2 +# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International +# See datasets_preprocess/preprocess_co3d.py +# -------------------------------------------------------- +import os.path as osp +import json +import itertools +from collections import deque + +import cv2 +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class Co3d(BaseStereoViewDataset): + def __init__(self, mask_bg=True, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert mask_bg in (True, False, 'rand') + self.mask_bg = mask_bg + self.dataset_label = 'Co3d_v2' + + # load all scenes + with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f: + self.scenes = json.load(f) + self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0} + self.scenes = {(k, k2): v2 for k, v in self.scenes.items() + for k2, v2 in v.items()} + self.scene_list = list(self.scenes.keys()) + + # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees) + # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees + self.combinations = [(i, j) + for i, j in itertools.combinations(range(100), 2) + if 0 < abs(i - j) <= 30 and abs(i - j) % 5 == 0] + + self.invalidate = {scene: {} for scene in self.scene_list} + + def __len__(self): + return len(self.scene_list) * len(self.combinations) + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.npz') + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg') + + def _get_depthpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'depths', f'frame{view_idx:06n}.jpg.geometric.png') + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png') + + def _read_depthmap(self, depthpath, input_metadata): + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth']) + return depthmap + + def _get_views(self, idx, resolution, rng): + # choose a scene + obj, instance = self.scene_list[idx // len(self.combinations)] + image_pool = self.scenes[obj, instance] + im1_idx, im2_idx = self.combinations[idx % len(self.combinations)] + + # add a bit of randomness + last = len(image_pool) - 1 + + if resolution not in self.invalidate[obj, instance]: # flag invalid images + self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))] + + # decide now if we mask the bg + mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2)) + + views = [] + imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]] + imgs_idxs = deque(imgs_idxs) + while len(imgs_idxs) > 0: # some images (few) have zero depth + im_idx = imgs_idxs.pop() + + if self.invalidate[obj, instance][resolution][im_idx]: + # search for a valid image + random_direction = 2 * rng.choice(2) - 1 + for offset 
in range(1, len(image_pool)): + tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool) + if not self.invalidate[obj, instance][resolution][tentative_im_idx]: + im_idx = tentative_im_idx + break + + view_idx = image_pool[im_idx] + + impath = self._get_impath(obj, instance, view_idx) + depthpath = self._get_depthpath(obj, instance, view_idx) + + # load camera params + metadata_path = self._get_metadatapath(obj, instance, view_idx) + input_metadata = np.load(metadata_path) + camera_pose = input_metadata['camera_pose'].astype(np.float32) + intrinsics = input_metadata['camera_intrinsics'].astype(np.float32) + + # load image and depth + rgb_image = imread_cv2(impath) + depthmap = self._read_depthmap(depthpath, input_metadata) + + if mask_bg: + # load object mask + maskpath = self._get_maskpath(obj, instance, view_idx) + maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32) + maskmap = (maskmap / 255.0) > 0.1 + + # update the depthmap with mask + depthmap *= maskmap + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath) + + num_valid = (depthmap > 0.0).sum() + if num_valid == 0: + # problem, invalidate image and retry + self.invalidate[obj, instance][resolution][im_idx] = True + imgs_idxs.append(im_idx) + continue + + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=osp.join(obj, instance), + instance=osp.split(impath)[1], + )) + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/habitat.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/habitat.py new file mode 100644 index 0000000000000000000000000000000000000000..11ce8a0ffb2134387d5fb794df89834db3ea8c9f --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/habitat.py @@ -0,0 +1,107 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
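Like the other loaders here, `Co3d` inherits the `(sample_idx, aspect_ratio_idx)` indexing convention from `BaseStereoViewDataset`, which is how `BatchedRandomSampler` keeps every element of a batch at the same resolution. A small hedged sketch; the path and resolution list are placeholders, not values mandated by the diff:

```python
from dust3r.datasets import Co3d

dataset = Co3d(split='train', ROOT='data/co3d_subset_processed',
               resolution=[(512, 384), (512, 336), (512, 288)], aug_crop=16)

views = dataset[(0, 1)]   # sample #0, rendered at the 2nd listed resolution, here (512, 336)
for v in views:           # always two views per item
    print(v['dataset'], v['label'], v['true_shape'], v['img'].shape)
```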
+# +# -------------------------------------------------------- +# Dataloader for preprocessed habitat +# dataset at https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md +# See datasets_preprocess/habitat for more details +# -------------------------------------------------------- +import os.path as osp +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # noqa +import cv2 # noqa +import numpy as np +from PIL import Image +import json + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset + + +class Habitat(BaseStereoViewDataset): + def __init__(self, size, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert self.split is not None + # loading list of scenes + with open(osp.join(self.ROOT, f'Habitat_{size}_scenes_{self.split}.txt')) as f: + self.scenes = f.read().splitlines() + self.instances = list(range(1, 5)) + + def filter_scene(self, label, instance=None): + if instance: + subscene, instance = instance.split('_') + label += '/' + subscene + self.instances = [int(instance) - 1] + valid = np.bool_([scene.startswith(label) for scene in self.scenes]) + assert sum(valid), 'no scene was selected for {label=} {instance=}' + self.scenes = [scene for i, scene in enumerate(self.scenes) if valid[i]] + + def _get_views(self, idx, resolution, rng): + scene = self.scenes[idx] + data_path, key = osp.split(osp.join(self.ROOT, scene)) + views = [] + two_random_views = [0, rng.choice(self.instances)] # view 0 is connected with all other views + for view_index in two_random_views: + # load the view (and use the next one if this one's broken) + for ii in range(view_index, view_index + 5): + image, depthmap, intrinsics, camera_pose = self._load_one_view(data_path, key, ii % 5, resolution, rng) + if np.isfinite(camera_pose).all(): + break + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='Habitat', + label=osp.relpath(data_path, self.ROOT), + instance=f"{key}_{view_index}")) + return views + + def _load_one_view(self, data_path, key, view_index, resolution, rng): + view_index += 1 # file indices starts at 1 + impath = osp.join(data_path, f"{key}_{view_index}.jpeg") + image = Image.open(impath) + + depthmap_filename = osp.join(data_path, f"{key}_{view_index}_depth.exr") + depthmap = cv2.imread(depthmap_filename, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH) + + camera_params_filename = osp.join(data_path, f"{key}_{view_index}_camera_params.json") + with open(camera_params_filename, 'r') as f: + camera_params = json.load(f) + + intrinsics = np.float32(camera_params['camera_intrinsics']) + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = camera_params['R_cam2world'] + camera_pose[:3, 3] = camera_params['t_cam2world'] + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=impath) + return image, depthmap, intrinsics, camera_pose + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = Habitat(1_000_000, split='train', ROOT="data/habitat_processed", + resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = 
max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/megadepth.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..8131498b76d855e5293fe79b3686fc42bf87eea8 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/megadepth.py @@ -0,0 +1,123 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed MegaDepth +# dataset at https://www.cs.cornell.edu/projects/megadepth/ +# See datasets_preprocess/preprocess_megadepth.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class MegaDepth(BaseStereoViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data(self.split) + + if self.split is None: + pass + elif self.split == 'train': + self.select_scene(('0015', '0022'), opposite=True) + elif self.split == 'val': + self.select_scene(('0015', '0022')) + else: + raise ValueError(f'bad {self.split=}') + + def _load_data(self, split): + with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data: + self.all_scenes = data['scenes'] + self.all_images = data['images'] + self.pairs = data['pairs'] + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs from {len(self.all_scenes)} scenes' + + def select_scene(self, scene, *instances, opposite=False): + scenes = (scene,) if isinstance(scene, str) else tuple(scene) + scene_id = [s.startswith(scenes) for s in self.all_scenes] + assert any(scene_id), 'no scene found' + + valid = np.in1d(self.pairs['scene_id'], np.nonzero(scene_id)[0]) + if instances: + image_id = [i.startswith(instances) for i in self.all_images] + image_id = np.nonzero(image_id)[0] + assert len(image_id), 'no instance found' + # both together? 
+ if len(instances) == 2: + valid &= np.in1d(self.pairs['im1_id'], image_id) & np.in1d(self.pairs['im2_id'], image_id) + else: + valid &= np.in1d(self.pairs['im1_id'], image_id) | np.in1d(self.pairs['im2_id'], image_id) + + if opposite: + valid = ~valid + assert valid.any() + self.pairs = self.pairs[valid] + + def _get_views(self, pair_idx, resolution, rng): + scene_id, im1_id, im2_id, score = self.pairs[pair_idx] + + scene, subscene = self.all_scenes[scene_id].split() + seq_path = osp.join(self.ROOT, scene, subscene) + + views = [] + + for im_id in [im1_id, im2_id]: + img = self.all_images[im_id] + try: + image = imread_cv2(osp.join(seq_path, img + '.jpg')) + depthmap = imread_cv2(osp.join(seq_path, img + ".exr")) + camera_params = np.load(osp.join(seq_path, img + ".npz")) + except Exception as e: + raise OSError(f'cannot load {img}, got exception {e}') + + intrinsics = np.float32(camera_params['intrinsics']) + camera_pose = np.float32(camera_params['cam2world']) + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, img)) + + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='MegaDepth', + label=osp.relpath(seq_path, self.ROOT), + instance=img)) + + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = MegaDepth(split='train', ROOT="data/megadepth_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/scannetpp.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/scannetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..520deedd0eb8cba8663af941731d89e0b2e71a80 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/scannetpp.py @@ -0,0 +1,96 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
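The `split`/`select_scene` logic in `megadepth.py` above reserves scenes '0015' and '0022' for validation and trains on the rest, and `select_scene` can also be called directly to narrow the pair list. A hedged sketch, assuming the preprocessed MegaDepth data is present at the placeholder path:

```python
from dust3r.datasets import MegaDepth

train = MegaDepth(split='train', ROOT='data/megadepth_processed', resolution=224)
val = MegaDepth(split='val', ROOT='data/megadepth_processed', resolution=224)
print(train.get_stats())   # e.g. '<n_pairs> pairs from <n_scenes> scenes'
print(val.get_stats())

# the same mechanism can restrict the pair list to a single scene
solo = MegaDepth(split=None, ROOT='data/megadepth_processed', resolution=224)
solo.select_scene('0015')
print(len(solo))
```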
+# +# -------------------------------------------------------- +# Dataloader for preprocessed scannet++ +# dataset at https://github.com/scannetpp/scannetpp - non-commercial research and educational purposes +# https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf +# See datasets_preprocess/preprocess_scannetpp.py +# -------------------------------------------------------- +import os.path as osp +import cv2 +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class ScanNetpp(BaseStereoViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert self.split == 'train' + self.loaded_data = self._load_data() + + def _load_data(self): + with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data: + self.scenes = data['scenes'] + self.sceneids = data['sceneids'] + self.images = data['images'] + self.intrinsics = data['intrinsics'].astype(np.float32) + self.trajectories = data['trajectories'].astype(np.float32) + self.pairs = data['pairs'][:, :2].astype(int) + + def __len__(self): + return len(self.pairs) + + def _get_views(self, idx, resolution, rng): + + image_idx1, image_idx2 = self.pairs[idx] + + views = [] + for view_idx in [image_idx1, image_idx2]: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(scene_dir, 'images', basename + '.jpg')) + # Load depthmap + depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename + '.png'), cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) + + views.append(dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset='ScanNet++', + label=self.scenes[scene_id] + '_' + basename, + instance=f'{str(idx)}_{str(view_idx)}', + )) + return views + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = ScanNetpp(split='train', ROOT="data/scannetpp_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx*255, (1 - idx)*255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/staticthings3d.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/staticthings3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e7f70f0ee7bf8c8ab6bb1702aa2481f3d16df413 --- /dev/null +++ 
b/imcui/third_party/mast3r/dust3r/dust3r/datasets/staticthings3d.py @@ -0,0 +1,96 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed StaticThings3D +# dataset at https://github.com/lmb-freiburg/robustmvd/ +# See datasets_preprocess/preprocess_staticthings3d.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class StaticThings3D (BaseStereoViewDataset): + """ Dataset of indoor scenes, 5 images each time + """ + def __init__(self, ROOT, *args, mask_bg='rand', **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + + assert mask_bg in (True, False, 'rand') + self.mask_bg = mask_bg + + # loading all pairs + assert self.split is None + self.pairs = np.load(osp.join(ROOT, 'staticthings_pairs.npy')) + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs' + + def _get_views(self, pair_idx, resolution, rng): + scene, seq, cam1, im1, cam2, im2 = self.pairs[pair_idx] + seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}') + + views = [] + + mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2)) + + CAM = {b'l':'left', b'r':'right'} + for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]: + num = f"{idx:04n}" + img = num+"_clean.jpg" if rng.choice(2) else num+"_final.jpg" + image = imread_cv2(osp.join(self.ROOT, seq_path, cam, img)) + depthmap = imread_cv2(osp.join(self.ROOT, seq_path, cam, num+".exr")) + camera_params = np.load(osp.join(self.ROOT, seq_path, cam, num+".npz")) + + intrinsics = camera_params['intrinsics'] + camera_pose = camera_params['cam2world'] + + if mask_bg: + depthmap[depthmap > 200] = 0 + + image, depthmap, intrinsics = self._crop_resize_if_necessary(image, depthmap, intrinsics, resolution, rng, info=(seq_path,cam,img)) + + views.append(dict( + img = image, + depthmap = depthmap, + camera_pose = camera_pose, # cam2world + camera_intrinsics = intrinsics, + dataset = 'StaticThings3D', + label = seq_path, + instance = cam+'_'+img)) + + return views + + +if __name__ == '__main__': + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = StaticThings3D(ROOT="data/staticthings3d_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx*255, (1 - idx)*255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/utils/__init__.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/utils/cropping.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/utils/cropping.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb4eaa92d21d0ecb8473faa60e5fc13ddf317e3 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/utils/cropping.py @@ -0,0 +1,124 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# croppping utilities +# -------------------------------------------------------- +import PIL.Image +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa +import numpy as np # noqa +from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + + +class ImageList: + """ Convenience class to aply the same operation to a whole set of images. + """ + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + return len(self.images) + + def to_pil(self): + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes) + return sizes[0] + + def resize(self, *args, **kwargs): + return ImageList(self._dispatch('resize', *args, **kwargs)) + + def crop(self, *args, **kwargs): + return ImageList(self._dispatch('crop', *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True): + """ Jointly rescale a (image, depthmap) + so that (out_width, out_height) >= output_res + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W,H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + # can also use this with masks instead of depthmaps + assert tuple(depthmap.shape[:2]) == image.size[::-1] + + # define output resolution + assert output_resolution.shape == (2,) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + return (image.to_pil(), depthmap, camera_intrinsics) + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + # first rescale the image so that it contains the crop + image = image.resize(output_resolution, resample=lanczos if scale_final < 1 else bicubic) + if depthmap is not None: + depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final, + fy=scale_final, interpolation=cv2.INTER_NEAREST) + + # no offset here; simple rescaling + camera_intrinsics = camera_matrix_of_crop( + camera_intrinsics, input_resolution, output_resolution, 
scaling=scale_final) + + return image.to_pil(), depthmap, camera_intrinsics + + +def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None): + # Margins to offset the origin + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + # Generate new camera parameters + output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): + """ + Return a crop of the input view. + """ + image = ImageList(image) + l, t, r, b = crop_bbox + + image = image.crop((l, t, r, b)) + depthmap = depthmap[t:b, l:r] + + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= l + camera_intrinsics[1, 2] -= t + + return image.to_pil(), depthmap, camera_intrinsics + + +def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution): + out_width, out_height = output_resolution + l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) + crop_bbox = (l, t, l + out_width, t + out_height) + return crop_bbox diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/utils/transforms.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..eb34f2f01d3f8f829ba71a7e03e181bf18f72c25 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/utils/transforms.py @@ -0,0 +1,11 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUST3R default transforms +# -------------------------------------------------------- +import torchvision.transforms as tvf +from dust3r.utils.image import ImgNorm + +# define the standard image transforms +ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/waymo.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a135152cd8973532405b491450c22942dcd6ca --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/waymo.py @@ -0,0 +1,93 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
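The cropping utilities above keep the pinhole intrinsics consistent with the pixel operations: rescaling multiplies the focal lengths and principal point by the same factor, while cropping only shifts the principal point by the crop offset. A minimal sketch with a toy 3x3 intrinsics matrix (values are invented, and the OpenCV/COLMAP half-pixel convention handled by `colmap_to_opencv_intrinsics` is ignored here):

```python
import numpy as np

K = np.float32([[1000., 0., 320.],   # fx, 0, cx
                [0., 1000., 240.],   # 0, fy, cy
                [0., 0., 1.]])

# rescale the image by a factor s: focals and principal point scale together
s = 0.5
K_scaled = K.copy()
K_scaled[:2, :] *= s                 # fx, fy, cx, cy -> 500, 500, 160, 120

# crop a window with top-left corner (l, t): only the principal point moves
l, t = 40, 30
K_cropped = K_scaled.copy()
K_cropped[0, 2] -= l                 # cx -> 120
K_cropped[1, 2] -= t                 # cy -> 90

print(K_cropped)
```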
+# +# -------------------------------------------------------- +# Dataloader for preprocessed WayMo +# dataset at https://github.com/waymo-research/waymo-open-dataset +# See datasets_preprocess/preprocess_waymo.py +# -------------------------------------------------------- +import os.path as osp +import numpy as np + +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from dust3r.utils.image import imread_cv2 + + +class Waymo (BaseStereoViewDataset): + """ Dataset of outdoor street scenes, 5 images each time + """ + + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self._load_data() + + def _load_data(self): + with np.load(osp.join(self.ROOT, 'waymo_pairs.npz')) as data: + self.scenes = data['scenes'] + self.frames = data['frames'] + self.inv_frames = {frame: i for i, frame in enumerate(data['frames'])} + self.pairs = data['pairs'] # (array of (scene_id, img1_id, img2_id) + assert self.pairs[:, 0].max() == len(self.scenes) - 1 + + def __len__(self): + return len(self.pairs) + + def get_stats(self): + return f'{len(self)} pairs from {len(self.scenes)} scenes' + + def _get_views(self, pair_idx, resolution, rng): + seq, img1, img2 = self.pairs[pair_idx] + seq_path = osp.join(self.ROOT, self.scenes[seq]) + + views = [] + + for view_index in [img1, img2]: + impath = self.frames[view_index] + image = imread_cv2(osp.join(seq_path, impath + ".jpg")) + depthmap = imread_cv2(osp.join(seq_path, impath + ".exr")) + camera_params = np.load(osp.join(seq_path, impath + ".npz")) + + intrinsics = np.float32(camera_params['intrinsics']) + camera_pose = np.float32(camera_params['cam2world']) + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath)) + + views.append(dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset='Waymo', + label=osp.relpath(seq_path, self.ROOT), + instance=impath)) + + return views + + +if __name__ == '__main__': + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = Waymo(split='train', ROOT="data/megadepth_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(idx, view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/mast3r/dust3r/dust3r/datasets/wildrgbd.py b/imcui/third_party/mast3r/dust3r/dust3r/datasets/wildrgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..c41dd0b78402bf8ff1e62c6a50de338aa916e0af --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/datasets/wildrgbd.py @@ -0,0 +1,67 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
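The `__main__` blocks of these loaders visualize `pts3d` and `valid_mask`, which the base stereo dataset derives from the depthmap, the intrinsics and the cam2world pose. As a rough stand-alone illustration (not the actual `depthmap_to_pts3d` used by dust3r), pinhole back-projection into world coordinates looks like this:

```python
import numpy as np

def backproject(depthmap, K, cam2world):
    """Toy pinhole back-projection: depth (H,W), intrinsics K (3,3), pose (4,4)."""
    H, W = depthmap.shape
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z = depthmap
    x = (u - K[0, 2]) / K[0, 0] * z          # X = (u - cx) / fx * Z
    y = (v - K[1, 2]) / K[1, 1] * z          # Y = (v - cy) / fy * Z
    pts_cam = np.stack([x, y, z], axis=-1)   # (H, W, 3) in the camera frame
    # rigid transform to the world frame
    pts_world = pts_cam @ cam2world[:3, :3].T + cam2world[:3, 3]
    valid = z > 0                            # invalid depth is stored as 0
    return pts_world, valid

# hypothetical inputs
depth = np.full((480, 640), 2.0, np.float32)
K = np.float32([[600, 0, 320], [0, 600, 240], [0, 0, 1]])
pose = np.eye(4, dtype=np.float32)
pts3d, valid_mask = backproject(depth, K, pose)
print(pts3d.shape, valid_mask.mean())
```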
+# +# -------------------------------------------------------- +# Dataloader for preprocessed WildRGB-D +# dataset at https://github.com/wildrgbd/wildrgbd/ +# See datasets_preprocess/preprocess_wildrgbd.py +# -------------------------------------------------------- +import os.path as osp + +import cv2 +import numpy as np + +from dust3r.datasets.co3d import Co3d +from dust3r.utils.image import imread_cv2 + + +class WildRGBD(Co3d): + def __init__(self, mask_bg=True, *args, ROOT, **kwargs): + super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs) + self.dataset_label = 'WildRGBD' + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'metadata', f'{view_idx:0>5d}.npz') + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'rgb', f'{view_idx:0>5d}.jpg') + + def _get_depthpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'depth', f'{view_idx:0>5d}.png') + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, 'masks', f'{view_idx:0>5d}.png') + + def _read_depthmap(self, depthpath, input_metadata): + # We store depths in the depth scale of 1000. + # That is, when we load depth image and divide by 1000, we could get depth in meters. + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000.0 + return depthmap + + +if __name__ == "__main__": + from dust3r.datasets.base.base_stereo_view_dataset import view_name + from dust3r.viz import SceneViz, auto_cam_size + from dust3r.utils.image import rgb + + dataset = WildRGBD(split='train', ROOT="data/wildrgbd_processed", resolution=224, aug_crop=16) + + for idx in np.random.permutation(len(dataset)): + views = dataset[idx] + assert len(views) == 2 + print(view_name(views[0]), view_name(views[1])) + viz = SceneViz() + poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] + cam_size = max(auto_cam_size(poses), 0.001) + for view_idx in [0, 1]: + pts3d = views[view_idx]['pts3d'] + valid_mask = views[view_idx]['valid_mask'] + colors = rgb(views[view_idx]['img']) + viz.add_pointcloud(pts3d, colors, valid_mask) + viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], + focal=views[view_idx]['camera_intrinsics'][0, 0], + color=(idx * 255, (1 - idx) * 255, 0), + image=colors, + cam_size=cam_size) + viz.show() diff --git a/imcui/third_party/mast3r/dust3r/dust3r/demo.py b/imcui/third_party/mast3r/dust3r/dust3r/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..c491be097b71ec38ea981dadf4f456d6e9829d48 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/demo.py @@ -0,0 +1,283 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
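WildRGBD above only swaps the file layout and depth encoding relative to Co3d: frame indices are zero-padded to five digits and the depth PNGs store millimetres, converted to float metres on load. A small stand-alone sketch of that convention (the paths are hypothetical):

```python
import cv2
import numpy as np

view_idx = 42
rgb_path = f"data/wildrgbd_processed/chair/scene_001/rgb/{view_idx:0>5d}.jpg"
depth_path = f"data/wildrgbd_processed/chair/scene_001/depth/{view_idx:0>5d}.png"
print(rgb_path)   # .../rgb/00042.jpg

# 16-bit PNG depth in millimetres -> float32 metres, as in _read_depthmap above
depth_mm = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
if depth_mm is not None:                      # the file may not exist locally
    depth_m = depth_mm.astype(np.float32) / 1000.0
    print(depth_m.dtype, depth_m.max())
```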
+# +# -------------------------------------------------------- +# gradio demo +# -------------------------------------------------------- +import argparse +import math +import builtins +import datetime +import gradio +import os +import torch +import numpy as np +import functools +import trimesh +import copy +from scipy.spatial.transform import Rotation + +from dust3r.inference import inference +from dust3r.image_pairs import make_pairs +from dust3r.utils.image import load_images, rgb +from dust3r.utils.device import to_numpy +from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes +from dust3r.cloud_opt import global_aligner, GlobalAlignerMode + +import matplotlib.pyplot as pl + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser_url = parser.add_mutually_exclusive_group() + parser_url.add_argument("--local_network", action='store_true', default=False, + help="make app accessible on local network: address will be set to 0.0.0.0") + parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1") + parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size") + parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). " + "If None, will search for an available port starting at 7860."), + default=None) + parser_weights = parser.add_mutually_exclusive_group(required=True) + parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None) + parser_weights.add_argument("--model_name", type=str, help="name of the model weights", + choices=["DUSt3R_ViTLarge_BaseDecoder_512_dpt", + "DUSt3R_ViTLarge_BaseDecoder_512_linear", + "DUSt3R_ViTLarge_BaseDecoder_224_linear"]) + parser.add_argument("--device", type=str, default='cuda', help="pytorch device") + parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir") + parser.add_argument("--silent", action='store_true', default=False, + help="silence logs") + return parser + + +def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"): + builtin_print = builtins.print + + def print_with_timestamp(*args, **kwargs): + now = datetime.datetime.now() + formatted_date_time = now.strftime(time_format) + + builtin_print(f'[{formatted_date_time}] ', end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print_with_timestamp + + +def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05, + cam_color=None, as_pointcloud=False, + transparent_cams=False, silent=False): + assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals) + pts3d = to_numpy(pts3d) + imgs = to_numpy(imgs) + focals = to_numpy(focals) + cams2world = to_numpy(cams2world) + + scene = trimesh.Scene() + + # full pointcloud + if as_pointcloud: + pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]) + col = np.concatenate([p[m] for p, m in zip(imgs, mask)]) + pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3)) + scene.add_geometry(pct) + else: + meshes = [] + for i in range(len(imgs)): + meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i])) + mesh = trimesh.Trimesh(**cat_meshes(meshes)) + scene.add_geometry(mesh) + + # add each camera + for i, pose_c2w in enumerate(cams2world): + if isinstance(cam_color, list): + camera_edge_color = cam_color[i] + else: + camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)] + add_scene_cam(scene, 
pose_c2w, camera_edge_color, + None if transparent_cams else imgs[i], focals[i], + imsize=imgs[i].shape[1::-1], screen_width=cam_size) + + rot = np.eye(4) + rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix() + scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot)) + outfile = os.path.join(outdir, 'scene.glb') + if not silent: + print('(exporting 3D scene to', outfile, ')') + scene.export(file_obj=outfile) + return outfile + + +def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False, + clean_depth=False, transparent_cams=False, cam_size=0.05): + """ + extract 3D_model (glb file) from a reconstructed scene + """ + if scene is None: + return None + # post processes + if clean_depth: + scene = scene.clean_pointcloud() + if mask_sky: + scene = scene.mask_sky() + + # get optimized values from scene + rgbimg = scene.imgs + focals = scene.get_focals().cpu() + cams2world = scene.get_im_poses().cpu() + # 3D pointcloud from depthmap, poses and intrinsics + pts3d = to_numpy(scene.get_pts3d()) + scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr))) + msk = to_numpy(scene.get_masks()) + return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud, + transparent_cams=transparent_cams, cam_size=cam_size, silent=silent) + + +def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr, + as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, + scenegraph_type, winsize, refid): + """ + from a list of images, run dust3r inference, global aligner. + then run get_3D_model_from_scene + """ + imgs = load_images(filelist, size=image_size, verbose=not silent) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + if scenegraph_type == "swin": + scenegraph_type = scenegraph_type + "-" + str(winsize) + elif scenegraph_type == "oneref": + scenegraph_type = scenegraph_type + "-" + str(refid) + + pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True) + output = inference(pairs, model, device, batch_size=1, verbose=not silent) + + mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer + scene = global_aligner(output, device=device, mode=mode, verbose=not silent) + lr = 0.01 + + if mode == GlobalAlignerMode.PointCloudOptimizer: + loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr) + + outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size) + + # also return rgb, depth and confidence imgs + # depth is normalized with the max value for all images + # we apply the jet colormap on the confidence maps + rgbimg = scene.imgs + depths = to_numpy(scene.get_depthmaps()) + confs = to_numpy([c for c in scene.im_conf]) + cmap = pl.get_cmap('jet') + depths_max = max([d.max() for d in depths]) + depths = [d / depths_max for d in depths] + confs_max = max([d.max() for d in confs]) + confs = [cmap(d / confs_max) for d in confs] + + imgs = [] + for i in range(len(rgbimg)): + imgs.append(rgbimg[i]) + imgs.append(rgb(depths[i])) + imgs.append(rgb(confs[i])) + + return scene, outfile, imgs + + +def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type): + num_files = len(inputfiles) if inputfiles is not None else 1 + max_winsize = max(1, math.ceil((num_files - 1) / 2)) + if scenegraph_type == "swin": + 
winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=True) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files - 1, step=1, visible=False) + elif scenegraph_type == "oneref": + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files - 1, step=1, visible=True) + else: + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=1, maximum=max_winsize, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files - 1, step=1, visible=False) + return winsize, refid + + +def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False): + recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size) + model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent) + with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="DUSt3R Demo") as demo: + # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference + scene = gradio.State(None) + gradio.HTML('
DUSt3R Demo
') + with gradio.Column(): + inputfiles = gradio.File(file_count="multiple") + with gradio.Row(): + schedule = gradio.Dropdown(["linear", "cosine"], + value='linear', label="schedule", info="For global alignment!") + niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000, + label="num_iterations", info="For global alignment!") + scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"), + ("swin: sliding window", "swin"), + ("oneref: match one image with all", "oneref")], + value='complete', label="Scenegraph", + info="Define how to make pairs", + interactive=True) + winsize = gradio.Slider(label="Scene Graph: Window Size", value=1, + minimum=1, maximum=1, step=1, visible=False) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False) + + run_btn = gradio.Button("Run") + + with gradio.Row(): + # adjust the confidence threshold + min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1) + # adjust the camera size in the output pointcloud + cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001) + with gradio.Row(): + as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud") + # two post process implemented + mask_sky = gradio.Checkbox(value=False, label="Mask sky") + clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps") + transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras") + + outmodel = gradio.Model3D() + outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%") + + # events + scenegraph_type.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + inputfiles.change(set_scenegraph_options, + inputs=[inputfiles, winsize, refid, scenegraph_type], + outputs=[winsize, refid]) + run_btn.click(fn=recon_fun, + inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud, + mask_sky, clean_depth, transparent_cams, cam_size, + scenegraph_type, winsize, refid], + outputs=[scene, outmodel, outgallery]) + min_conf_thr.release(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + cam_size.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + as_pointcloud.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + mask_sky.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + clean_depth.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + transparent_cams.change(model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size], + outputs=outmodel) + demo.launch(share=False, server_name=server_name, server_port=server_port) diff --git a/imcui/third_party/mast3r/dust3r/dust3r/heads/__init__.py b/imcui/third_party/mast3r/dust3r/dust3r/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53d0aa5610cae95f34f96bdb3ff9e835a2d6208e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/heads/__init__.py @@ -0,0 +1,19 @@ +# Copyright (C) 
2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# head factory +# -------------------------------------------------------- +from .linear_head import LinearPts3d +from .dpt_head import create_dpt_head + + +def head_factory(head_type, output_mode, net, has_conf=False): + """" build a prediction head for the decoder + """ + if head_type == 'linear' and output_mode == 'pts3d': + return LinearPts3d(net, has_conf) + elif head_type == 'dpt' and output_mode == 'pts3d': + return create_dpt_head(net, has_conf=has_conf) + else: + raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") diff --git a/imcui/third_party/mast3r/dust3r/dust3r/heads/dpt_head.py b/imcui/third_party/mast3r/dust3r/dust3r/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b7bdc9ff587eef3ec8978a22f63659fbf3c277d6 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/heads/dpt_head.py @@ -0,0 +1,115 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# dpt head implementation for DUST3R +# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; +# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True +# the forward function also takes as input a dictionnary img_info with key "height" and "width" +# for PixelwiseTask, the output will be of dimension B x num_channels x H x W +# -------------------------------------------------------- +from einops import rearrange +from typing import List +import torch +import torch.nn as nn +from dust3r.heads.postprocess import postprocess +import dust3r.utils.path_to_croco # noqa: F401 +from models.dpt_block import DPTOutputAdapter # noqa + + +class DPTOutputAdapter_fix(DPTOutputAdapter): + """ + Adapt croco's DPTOutputAdapter implementation for dust3r: + remove duplicated weigths, and fix forward for dust3r + """ + + def init(self, dim_tokens_enc=768): + super().init(dim_tokens_enc) + # these are duplicated weights + del self.act_1_postprocess + del self.act_2_postprocess + del self.act_3_postprocess + del self.act_4_postprocess + + def forward(self, encoder_tokens: List[torch.Tensor], image_size=None): + assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' + # H, W = input_info['image_size'] + image_size = self.image_size if image_size is None else image_size + H, W = image_size + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. 
+ layers = [self.adapt_tokens(l) for l in layers] + + # Reshape tokens to spatial representation + layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]] + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Output head + out = self.head(path_1) + + return out + + +class PixelwiseTaskWithDPT(nn.Module): + """ DPT module for dust3r, can return 3D points + confidence for all pixels""" + + def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None, + output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_layers = True # backbone needs to return all layers + self.postprocess = postprocess + self.depth_mode = depth_mode + self.conf_mode = conf_mode + + assert n_cls_token == 0, "Not implemented" + dpt_args = dict(output_width_ratio=output_width_ratio, + num_channels=num_channels, + **kwargs) + if hooks_idx is not None: + dpt_args.update(hooks=hooks_idx) + self.dpt = DPTOutputAdapter_fix(**dpt_args) + dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens} + self.dpt.init(**dpt_init_args) + + def forward(self, x, img_info): + out = self.dpt(x, image_size=(img_info[0], img_info[1])) + if self.postprocess: + out = self.postprocess(out, self.depth_mode, self.conf_mode) + return out + + +def create_dpt_head(net, has_conf=False): + """ + return PixelwiseTaskWithDPT for given net params + """ + assert net.dec_depth > 9 + l2 = net.dec_depth + feature_dim = 256 + last_dim = feature_dim//2 + out_nchan = 3 + ed = net.enc_embed_dim + dd = net.dec_embed_dim + return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf, + feature_dim=feature_dim, + last_dim=last_dim, + hooks_idx=[0, l2*2//4, l2*3//4, l2], + dim_tokens=[ed, dd, dd, dd], + postprocess=postprocess, + depth_mode=net.depth_mode, + conf_mode=net.conf_mode, + head_type='regression') diff --git a/imcui/third_party/mast3r/dust3r/dust3r/heads/linear_head.py b/imcui/third_party/mast3r/dust3r/dust3r/heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6b697f29eaa6f43fad0a3e27a8d9b8f1a602a833 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/heads/linear_head.py @@ -0,0 +1,41 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
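The DPT head above taps the network at four depths, `hooks_idx = [0, l2*2//4, l2*3//4, l2]` with `l2 = net.dec_depth`, and the first hook reads tokens at the encoder's embedding width (hence `dim_tokens = [ed, dd, dd, dd]`). A quick sketch of the resulting indices for a hypothetical ViT-Large-style configuration:

```python
# hypothetical sizes, roughly matching a ViT-L encoder with a 12-block decoder
dec_depth = 12
enc_embed_dim = 1024
dec_embed_dim = 768

l2 = dec_depth
hooks_idx = [0, l2 * 2 // 4, l2 * 3 // 4, l2]       # -> [0, 6, 9, 12]
dim_tokens = [enc_embed_dim] + [dec_embed_dim] * 3  # -> [1024, 768, 768, 768]

print(hooks_idx, dim_tokens)
```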
+# +# -------------------------------------------------------- +# linear head implementation for DUST3R +# -------------------------------------------------------- +import torch.nn as nn +import torch.nn.functional as F +from dust3r.heads.postprocess import postprocess + + +class LinearPts3d (nn.Module): + """ + Linear head for dust3r + Each token outputs: - 16x16 3D points (+ confidence) + """ + + def __init__(self, net, has_conf=False): + super().__init__() + self.patch_size = net.patch_embed.patch_size[0] + self.depth_mode = net.depth_mode + self.conf_mode = net.conf_mode + self.has_conf = has_conf + + self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2) + + def setup(self, croconet): + pass + + def forward(self, decout, img_shape): + H, W = img_shape + tokens = decout[-1] + B, S, D = tokens.shape + + # extract 3D points + feat = self.proj(tokens) # B,S,D + feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) + feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W + + # permute + norm depth + return postprocess(feat, self.depth_mode, self.conf_mode) diff --git a/imcui/third_party/mast3r/dust3r/dust3r/heads/postprocess.py b/imcui/third_party/mast3r/dust3r/dust3r/heads/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..cd68a90d89b8dcd7d8a4b4ea06ef8b17eb5da093 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/heads/postprocess.py @@ -0,0 +1,58 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# post process function for all heads: extract 3D points/confidence from output +# -------------------------------------------------------- +import torch + + +def postprocess(out, depth_mode, conf_mode): + """ + extract 3D points/confidence from prediction head output + """ + fmap = out.permute(0, 2, 3, 1) # B,H,W,3 + res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) + + if conf_mode is not None: + res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) + return res + + +def reg_dense_depth(xyz, mode): + """ + extract 3D points from prediction head output + """ + mode, vmin, vmax = mode + + no_bounds = (vmin == -float('inf')) and (vmax == float('inf')) + assert no_bounds + + if mode == 'linear': + if no_bounds: + return xyz # [-inf, +inf] + return xyz.clip(min=vmin, max=vmax) + + # distance to origin + d = xyz.norm(dim=-1, keepdim=True) + xyz = xyz / d.clip(min=1e-8) + + if mode == 'square': + return xyz * d.square() + + if mode == 'exp': + return xyz * torch.expm1(d) + + raise ValueError(f'bad {mode=}') + + +def reg_dense_conf(x, mode): + """ + extract confidence from prediction head output + """ + mode, vmin, vmax = mode + if mode == 'exp': + return vmin + x.exp().clip(max=vmax-vmin) + if mode == 'sigmoid': + return (vmax - vmin) * torch.sigmoid(x) + vmin + raise ValueError(f'bad {mode=}') diff --git a/imcui/third_party/mast3r/dust3r/dust3r/image_pairs.py b/imcui/third_party/mast3r/dust3r/dust3r/image_pairs.py new file mode 100644 index 0000000000000000000000000000000000000000..ebcf902b4d07b83fe83ffceba3f45ca0d74dfcf7 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/image_pairs.py @@ -0,0 +1,104 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
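The post-processing above turns the raw head output into 3D points by keeping the per-pixel direction and re-mapping its norm according to the depth mode ('linear', 'square' or 'exp'). A small numeric sketch of the 'exp' branch, which matches the default `depth_mode=('exp', -inf, inf)` used by the model class further below (toy inputs, no batching over an image):

```python
import torch

def reg_dense_depth_exp(xyz):
    """xyz: (..., 3) raw output; returns points at distance expm1(||xyz||)."""
    d = xyz.norm(dim=-1, keepdim=True)        # norm of the raw prediction
    direction = xyz / d.clip(min=1e-8)        # unit direction, divide-by-zero safe
    return direction * torch.expm1(d)         # strictly positive distance

raw = torch.tensor([[0.0, 0.0, 1.0],
                    [0.0, 0.0, 2.0]])
pts = reg_dense_depth_exp(raw)
print(pts)   # z = e^1 - 1 ~= 1.718 and e^2 - 1 ~= 6.389
```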
+# +# -------------------------------------------------------- +# utilities needed to load image pairs +# -------------------------------------------------------- +import numpy as np +import torch + + +def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True): + pairs = [] + if scene_graph == 'complete': # complete graph + for i in range(len(imgs)): + for j in range(i): + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith('swin'): + iscyclic = not scene_graph.endswith('noncyclic') + try: + winsize = int(scene_graph.split('-')[1]) + except Exception as e: + winsize = 3 + pairsid = set() + for i in range(len(imgs)): + for j in range(1, winsize + 1): + idx = (i + j) + if iscyclic: + idx = idx % len(imgs) # explicit loop closure + if idx >= len(imgs): + continue + pairsid.add((i, idx) if i < idx else (idx, i)) + for i, j in pairsid: + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith('logwin'): + iscyclic = not scene_graph.endswith('noncyclic') + try: + winsize = int(scene_graph.split('-')[1]) + except Exception as e: + winsize = 3 + offsets = [2**i for i in range(winsize)] + pairsid = set() + for i in range(len(imgs)): + ixs_l = [i - off for off in offsets] + ixs_r = [i + off for off in offsets] + for j in ixs_l + ixs_r: + if iscyclic: + j = j % len(imgs) # Explicit loop closure + if j < 0 or j >= len(imgs) or j == i: + continue + pairsid.add((i, j) if i < j else (j, i)) + for i, j in pairsid: + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith('oneref'): + refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0 + for j in range(len(imgs)): + if j != refid: + pairs.append((imgs[refid], imgs[j])) + if symmetrize: + pairs += [(img2, img1) for img1, img2 in pairs] + + # now, remove edges + if isinstance(prefilter, str) and prefilter.startswith('seq'): + pairs = filter_pairs_seq(pairs, int(prefilter[3:])) + + if isinstance(prefilter, str) and prefilter.startswith('cyc'): + pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True) + + return pairs + + +def sel(x, kept): + if isinstance(x, dict): + return {k: sel(v, kept) for k, v in x.items()} + if isinstance(x, (torch.Tensor, np.ndarray)): + return x[kept] + if isinstance(x, (tuple, list)): + return type(x)([x[k] for k in kept]) + + +def _filter_edges_seq(edges, seq_dis_thr, cyclic=False): + # number of images + n = max(max(e) for e in edges) + 1 + + kept = [] + for e, (i, j) in enumerate(edges): + dis = abs(i - j) + if cyclic: + dis = min(dis, abs(i + n - j), abs(i - n - j)) + if dis <= seq_dis_thr: + kept.append(e) + return kept + + +def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False): + edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs] + kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) + return [pairs[i] for i in kept] + + +def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False): + edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])] + kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) + print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges') + return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept) diff --git a/imcui/third_party/mast3r/dust3r/dust3r/inference.py b/imcui/third_party/mast3r/dust3r/dust3r/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..90540486b077add90ca50f62a5072e082cb2f2d7 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/inference.py @@ -0,0 +1,150 @@ +# Copyright (C) 
2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilities needed for the inference +# -------------------------------------------------------- +import tqdm +import torch +from dust3r.utils.device import to_cpu, collate_with_cat +from dust3r.utils.misc import invalid_to_nans +from dust3r.utils.geometry import depthmap_to_pts3d, geotrf + + +def _interleave_imgs(img1, img2): + res = {} + for key, value1 in img1.items(): + value2 = img2[key] + if isinstance(value1, torch.Tensor): + value = torch.stack((value1, value2), dim=1).flatten(0, 1) + else: + value = [x for pair in zip(value1, value2) for x in pair] + res[key] = value + return res + + +def make_batch_symmetric(batch): + view1, view2 = batch + view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1)) + return view1, view2 + + +def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None): + view1, view2 = batch + ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng']) + for view in batch: + for name in view.keys(): # pseudo_focal + if name in ignore_keys: + continue + view[name] = view[name].to(device, non_blocking=True) + + if symmetrize_batch: + view1, view2 = make_batch_symmetric(batch) + + with torch.cuda.amp.autocast(enabled=bool(use_amp)): + pred1, pred2 = model(view1, view2) + + # loss is supposed to be symmetric + with torch.cuda.amp.autocast(enabled=False): + loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None + + result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss) + return result[ret] if ret else result + + +@torch.no_grad() +def inference(pairs, model, device, batch_size=8, verbose=True): + if verbose: + print(f'>> Inference with model on {len(pairs)} image pairs') + result = [] + + # first, check if all images have the same size + multiple_shapes = not (check_if_same_size(pairs)) + if multiple_shapes: # force bs=1 + batch_size = 1 + + for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose): + res = loss_of_one_batch(collate_with_cat(pairs[i:i + batch_size]), model, None, device) + result.append(to_cpu(res)) + + result = collate_with_cat(result, lists=multiple_shapes) + + return result + + +def check_if_same_size(pairs): + shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs] + shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs] + return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2) + + +def get_pred_pts3d(gt, pred, use_pose=False): + if 'depth' in pred and 'pseudo_focal' in pred: + try: + pp = gt['camera_intrinsics'][..., :2, 2] + except KeyError: + pp = None + pts3d = depthmap_to_pts3d(**pred, pp=pp) + + elif 'pts3d' in pred: + # pts3d from my camera + pts3d = pred['pts3d'] + + elif 'pts3d_in_other_view' in pred: + # pts3d from the other camera, already transformed + assert use_pose is True + return pred['pts3d_in_other_view'] # return! 
+ + if use_pose: + camera_pose = pred.get('camera_pose') + assert camera_pose is not None + pts3d = geotrf(camera_pose, pts3d) + + return pts3d + + +def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None): + assert gt_pts1.ndim == pr_pts1.ndim == 4 + assert gt_pts1.shape == pr_pts1.shape + if gt_pts2 is not None: + assert gt_pts2.ndim == pr_pts2.ndim == 4 + assert gt_pts2.shape == pr_pts2.shape + + # concat the pointcloud + nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2) + nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None + + pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2) + pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None + + all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1 + all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1 + + dot_gt_pr = (all_pr * all_gt).sum(dim=-1) + dot_gt_gt = all_gt.square().sum(dim=-1) + + if fit_mode.startswith('avg'): + # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1) + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + elif fit_mode.startswith('median'): + scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values + elif fit_mode.startswith('weiszfeld'): + # init scaling with l2 closed form + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip_(min=1e-8).reciprocal() + # update the scaling with the new weights + scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1) + else: + raise ValueError(f'bad {fit_mode=}') + + if fit_mode.endswith('stop_grad'): + scaling = scaling.detach() + + scaling = scaling.clip(min=1e-3) + # assert scaling.isfinite().all(), bb() + return scaling diff --git a/imcui/third_party/mast3r/dust3r/dust3r/losses.py b/imcui/third_party/mast3r/dust3r/dust3r/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8febff1a2dd674e759bcf83d023099a59cc934 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/losses.py @@ -0,0 +1,299 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
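`find_opt_scaling` above starts from the closed-form least-squares scale s = <gt, pr> / <gt, gt> and, in the 'weiszfeld' modes, refines it with a few reweighted iterations. A toy sketch of the same idea on flat point sets (no NaN masking or per-batch handling, random data for illustration):

```python
import torch

torch.manual_seed(0)
gt = torch.randn(1000, 3)
pr = 2.5 * gt + 0.01 * torch.randn(1000, 3)   # predictions ~2.5x larger than GT

# closed-form L2 scale: argmin_s ||pr - s * gt||^2
dot_gt_pr = (pr * gt).sum(-1)
dot_gt_gt = gt.square().sum(-1)
s = dot_gt_pr.mean() / dot_gt_gt.mean()

# Weiszfeld-style refinement: reweight by the inverse residual distance
for _ in range(10):
    w = (pr - s * gt).norm(dim=-1).clip(min=1e-8).reciprocal()
    s = (w * dot_gt_pr).mean() / (w * dot_gt_gt).mean()

print(float(s))   # ~2.5
```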
+# +# -------------------------------------------------------- +# Implementation of DUSt3R training losses +# -------------------------------------------------------- +from copy import copy, deepcopy +import torch +import torch.nn as nn + +from dust3r.inference import get_pred_pts3d, find_opt_scaling +from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud +from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale + + +def Sum(*losses_and_masks): + loss, mask = losses_and_masks[0] + if loss.ndim > 0: + # we are actually returning the loss for every pixels + return losses_and_masks + else: + # we are returning the global loss + for loss2, mask2 in losses_and_masks[1:]: + loss = loss + loss2 + return loss + + +class BaseCriterion(nn.Module): + def __init__(self, reduction='mean'): + super().__init__() + self.reduction = reduction + + +class LLoss (BaseCriterion): + """ L-norm loss + """ + + def forward(self, a, b): + assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}' + dist = self.distance(a, b) + assert dist.ndim == a.ndim - 1 # one dimension less + if self.reduction == 'none': + return dist + if self.reduction == 'sum': + return dist.sum() + if self.reduction == 'mean': + return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) + raise ValueError(f'bad {self.reduction=} mode') + + def distance(self, a, b): + raise NotImplementedError() + + +class L21Loss (LLoss): + """ Euclidean distance between 3d points """ + + def distance(self, a, b): + return torch.norm(a - b, dim=-1) # normalized L2 distance + + +L21 = L21Loss() + + +class Criterion (nn.Module): + def __init__(self, criterion=None): + super().__init__() + assert isinstance(criterion, BaseCriterion), f'{criterion} is not a proper criterion!' 
+ self.criterion = copy(criterion) + + def get_name(self): + return f'{type(self).__name__}({self.criterion})' + + def with_reduction(self, mode='none'): + res = loss = deepcopy(self) + while loss is not None: + assert isinstance(loss, Criterion) + loss.criterion.reduction = mode # make it return the loss for each sample + loss = loss._loss2 # we assume loss is a Multiloss + return res + + +class MultiLoss (nn.Module): + """ Easily combinable losses (also keep track of individual loss values): + loss = MyLoss1() + 0.1*MyLoss2() + Usage: + Inherit from this class and override get_name() and compute_loss() + """ + + def __init__(self): + super().__init__() + self._alpha = 1 + self._loss2 = None + + def compute_loss(self, *args, **kwargs): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + def __mul__(self, alpha): + assert isinstance(alpha, (int, float)) + res = copy(self) + res._alpha = alpha + return res + __rmul__ = __mul__ # same + + def __add__(self, loss2): + assert isinstance(loss2, MultiLoss) + res = cur = copy(self) + # find the end of the chain + while cur._loss2 is not None: + cur = cur._loss2 + cur._loss2 = loss2 + return res + + def __repr__(self): + name = self.get_name() + if self._alpha != 1: + name = f'{self._alpha:g}*{name}' + if self._loss2: + name = f'{name} + {self._loss2}' + return name + + def forward(self, *args, **kwargs): + loss = self.compute_loss(*args, **kwargs) + if isinstance(loss, tuple): + loss, details = loss + elif loss.ndim == 0: + details = {self.get_name(): float(loss)} + else: + details = {} + loss = loss * self._alpha + + if self._loss2: + loss2, details2 = self._loss2(*args, **kwargs) + loss = loss + loss2 + details |= details2 + + return loss, details + + +class Regr3D (Criterion, MultiLoss): + """ Ensure that all 3D points are correct. + Asymmetric loss: view1 is supposed to be the anchor. + + P1 = RT1 @ D1 + P2 = RT2 @ D2 + loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1) + loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2) + = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2) + """ + + def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False): + super().__init__(criterion) + self.norm_mode = norm_mode + self.gt_scale = gt_scale + + def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None): + # everything is normalized w.r.t. 
camera of view1 + in_camera1 = inv(gt1['camera_pose']) + gt_pts1 = geotrf(in_camera1, gt1['pts3d']) # B,H,W,3 + gt_pts2 = geotrf(in_camera1, gt2['pts3d']) # B,H,W,3 + + valid1 = gt1['valid_mask'].clone() + valid2 = gt2['valid_mask'].clone() + + if dist_clip is not None: + # points that are too far-away == invalid + dis1 = gt_pts1.norm(dim=-1) # (B, H, W) + dis2 = gt_pts2.norm(dim=-1) # (B, H, W) + valid1 = valid1 & (dis1 <= dist_clip) + valid2 = valid2 & (dis2 <= dist_clip) + + pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False) + pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True) + + # normalize 3d points + if self.norm_mode: + pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2) + if self.norm_mode and not self.gt_scale: + gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2) + + return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, {} + + def compute_loss(self, gt1, gt2, pred1, pred2, **kw): + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \ + self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw) + # loss on img1 side + l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1]) + # loss on gt2 side + l2 = self.criterion(pred_pts2[mask2], gt_pts2[mask2]) + self_name = type(self).__name__ + details = {self_name + '_pts3d_1': float(l1.mean()), self_name + '_pts3d_2': float(l2.mean())} + return Sum((l1, mask1), (l2, mask2)), (details | monitoring) + + +class ConfLoss (MultiLoss): + """ Weighted regression by learned confidence. + Assuming the input pixel_loss is a pixel-level regression loss. + + Principle: + high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10) + low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10) + + alpha: hyperparameter + """ + + def __init__(self, pixel_loss, alpha=1): + super().__init__() + assert alpha > 0 + self.alpha = alpha + self.pixel_loss = pixel_loss.with_reduction('none') + + def get_name(self): + return f'ConfLoss({self.pixel_loss})' + + def get_conf_log(self, x): + return x, torch.log(x) + + def compute_loss(self, gt1, gt2, pred1, pred2, **kw): + # compute per-pixel loss + ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw) + if loss1.numel() == 0: + print('NO VALID POINTS in img1', force=True) + if loss2.numel() == 0: + print('NO VALID POINTS in img2', force=True) + + # weight by confidence + conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1]) + conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2]) + conf_loss1 = loss1 * conf1 - self.alpha * log_conf1 + conf_loss2 = loss2 * conf2 - self.alpha * log_conf2 + + # average + nan protection (in case of no valid pixels at all) + conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0 + conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0 + + return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details) + + +class Regr3D_ShiftInv (Regr3D): + """ Same than Regr3D but invariant to depth shift. 
+ """ + + def get_all_pts3d(self, gt1, gt2, pred1, pred2): + # compute unnormalized points + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \ + super().get_all_pts3d(gt1, gt2, pred1, pred2) + + # compute median depth + gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2] + pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2] + gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None] + pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None] + + # subtract the median depth + gt_z1 -= gt_shift_z + gt_z2 -= gt_shift_z + pred_z1 -= pred_shift_z + pred_z2 -= pred_shift_z + + # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach()) + return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring + + +class Regr3D_ScaleInv (Regr3D): + """ Same than Regr3D but invariant to depth shift. + if gt_scale == True: enforce the prediction to take the same scale than GT + """ + + def get_all_pts3d(self, gt1, gt2, pred1, pred2): + # compute depth-normalized points + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = super().get_all_pts3d(gt1, gt2, pred1, pred2) + + # measure scene scale + _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2) + _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2) + + # prevent predictions to be in a ridiculous range + pred_scale = pred_scale.clip(min=1e-3, max=1e3) + + # subtract the median depth + if self.gt_scale: + pred_pts1 *= gt_scale / pred_scale + pred_pts2 *= gt_scale / pred_scale + # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean()) + else: + gt_pts1 /= gt_scale + gt_pts2 /= gt_scale + pred_pts1 /= pred_scale + pred_pts2 /= pred_scale + # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()) + + return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring + + +class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv): + # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv + pass diff --git a/imcui/third_party/mast3r/dust3r/dust3r/model.py b/imcui/third_party/mast3r/dust3r/dust3r/model.py new file mode 100644 index 0000000000000000000000000000000000000000..41c3a4f78eb5fbafdeb7ab8523468de320886c64 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/model.py @@ -0,0 +1,210 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUSt3R model class +# -------------------------------------------------------- +from copy import deepcopy +import torch +import os +from packaging import version +import huggingface_hub + +from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, interleave, transpose_to_landscape +from .heads import head_factory +from dust3r.patch_embed import get_patch_embed + +import dust3r.utils.path_to_croco # noqa: F401 +from models.croco import CroCoNet # noqa + +inf = float('inf') + +hf_version_number = huggingface_hub.__version__ +assert version.parse(hf_version_number) >= version.parse("0.22.0"), ("Outdated huggingface_hub version, " + "please reinstall requirements.txt") + + +def load_model(model_path, device, verbose=True): + if verbose: + print('... 
loading model from', model_path) + ckpt = torch.load(model_path, map_location='cpu') + args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") + if 'landscape_only' not in args: + args = args[:-1] + ', landscape_only=False)' + else: + args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False') + assert "landscape_only=False" in args + if verbose: + print(f"instantiating : {args}") + net = eval(args) + s = net.load_state_dict(ckpt['model'], strict=False) + if verbose: + print(s) + return net.to(device) + + +class AsymmetricCroCo3DStereo ( + CroCoNet, + huggingface_hub.PyTorchModelHubMixin, + library_name="dust3r", + repo_url="https://github.com/naver/dust3r", + tags=["image-to-3d"], +): + """ Two siamese encoders, followed by two decoders. + The goal is to output 3d points directly, both images in view1's frame + (hence the asymmetry). + """ + + def __init__(self, + output_mode='pts3d', + head_type='linear', + depth_mode=('exp', -inf, inf), + conf_mode=('exp', 1, inf), + freeze='none', + landscape_only=True, + patch_embed_cls='PatchEmbedDust3R', # PatchEmbedDust3R or ManyAR_PatchEmbed + **croco_kwargs): + self.patch_embed_cls = patch_embed_cls + self.croco_args = fill_default_args(croco_kwargs, super().__init__) + super().__init__(**croco_kwargs) + + # dust3r specific initialization + self.dec_blocks2 = deepcopy(self.dec_blocks) + self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs) + self.set_freeze(freeze) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kw): + if os.path.isfile(pretrained_model_name_or_path): + return load_model(pretrained_model_name_or_path, device='cpu') + else: + try: + model = super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw) + except TypeError as e: + raise Exception(f'tried to load {pretrained_model_name_or_path} from huggingface, but failed') + return model + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim) + + def load_state_dict(self, ckpt, **kw): + # duplicate all weights for the second decoder if not present + new_ckpt = dict(ckpt) + if not any(k.startswith('dec_blocks2') for k in ckpt): + for key, value in ckpt.items(): + if key.startswith('dec_blocks'): + new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value + return super().load_state_dict(new_ckpt, **kw) + + def set_freeze(self, freeze): # this is for use by downstream models + self.freeze = freeze + to_be_frozen = { + 'none': [], + 'mask': [self.mask_token], + 'encoder': [self.mask_token, self.patch_embed, self.enc_blocks], + } + freeze_all_params(to_be_frozen[freeze]) + + def _set_prediction_head(self, *args, **kwargs): + """ No prediction head """ + return + + def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, + **kw): + assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \ + f'{img_size=} must be multiple of {patch_size=}' + self.output_mode = output_mode + self.head_type = head_type + self.depth_mode = depth_mode + self.conf_mode = conf_mode + # allocate heads + self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) + self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) + # magic wrapper + self.head1 = transpose_to_landscape(self.downstream_head1, 
activate=landscape_only) + self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only) + + def _encode_image(self, image, true_shape): + # embed the image into patches (x has size B x Npatches x C) + x, pos = self.patch_embed(image, true_shape=true_shape) + + # add positional embedding without cls token + assert self.enc_pos_embed is None + + # now apply the transformer encoder and normalization + for blk in self.enc_blocks: + x = blk(x, pos) + + x = self.enc_norm(x) + return x, pos, None + + def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2): + if img1.shape[-2:] == img2.shape[-2:]: + out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0), + torch.cat((true_shape1, true_shape2), dim=0)) + out, out2 = out.chunk(2, dim=0) + pos, pos2 = pos.chunk(2, dim=0) + else: + out, pos, _ = self._encode_image(img1, true_shape1) + out2, pos2, _ = self._encode_image(img2, true_shape2) + return out, out2, pos, pos2 + + def _encode_symmetrized(self, view1, view2): + img1 = view1['img'] + img2 = view2['img'] + B = img1.shape[0] + # Recover true_shape when available, otherwise assume that the img shape is the true one + shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1)) + shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1)) + # warning! maybe the images have different portrait/landscape orientations + + if is_symmetrized(view1, view2): + # computing half of forward pass!' + feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2]) + feat1, feat2 = interleave(feat1, feat2) + pos1, pos2 = interleave(pos1, pos2) + else: + feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2) + + return (shape1, shape2), (feat1, feat2), (pos1, pos2) + + def _decoder(self, f1, pos1, f2, pos2): + final_output = [(f1, f2)] # before projection + + # project to decoder dim + f1 = self.decoder_embed(f1) + f2 = self.decoder_embed(f2) + + final_output.append((f1, f2)) + for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2): + # img1 side + f1, _ = blk1(*final_output[-1][::+1], pos1, pos2) + # img2 side + f2, _ = blk2(*final_output[-1][::-1], pos2, pos1) + # store the result + final_output.append((f1, f2)) + + # normalize last output + del final_output[1] # duplicate with final_output[0] + final_output[-1] = tuple(map(self.dec_norm, final_output[-1])) + return zip(*final_output) + + def _downstream_head(self, head_num, decout, img_shape): + B, S, D = decout[-1].shape + # img_shape = tuple(map(int, img_shape)) + head = getattr(self, f'head{head_num}') + return head(decout, img_shape) + + def forward(self, view1, view2): + # encode the two images --> B,S,D + (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2) + + # combine all ref images into object-centric representation + dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2) + + with torch.cuda.amp.autocast(enabled=False): + res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1) + res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2) + + res2['pts3d_in_other_view'] = res2.pop('pts3d') # predict view2's pts3d in view1's frame + return res1, res2 diff --git a/imcui/third_party/mast3r/dust3r/dust3r/optim_factory.py b/imcui/third_party/mast3r/dust3r/dust3r/optim_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9c16e0e0fda3fd03c3def61abc1f354f75c584 --- /dev/null +++ 
b/imcui/third_party/mast3r/dust3r/dust3r/optim_factory.py @@ -0,0 +1,14 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# optimization functions +# -------------------------------------------------------- + + +def adjust_learning_rate_by_lr(optimizer, lr): + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr diff --git a/imcui/third_party/mast3r/dust3r/dust3r/patch_embed.py b/imcui/third_party/mast3r/dust3r/dust3r/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..07bb184bccb9d16657581576779904065d2dc857 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/patch_embed.py @@ -0,0 +1,70 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# PatchEmbed implementation for DUST3R, +# in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio +# -------------------------------------------------------- +import torch +import dust3r.utils.path_to_croco # noqa: F401 +from models.blocks import PatchEmbed # noqa + + +def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim): + assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed'] + patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim) + return patch_embed + + +class PatchEmbedDust3R(PatchEmbed): + def forward(self, x, **kw): + B, C, H, W = x.shape + assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + +class ManyAR_PatchEmbed (PatchEmbed): + """ Handle images with non-square aspect ratio. + All images in the same batch have the same aspect ratio. + true_shape = [(height, width) ...] indicates the actual shape of each image. + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + self.embed_dim = embed_dim + super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten) + + def forward(self, img, true_shape): + B, C, H, W = img.shape + assert W >= H, f'img should be in landscape mode, but got {W=} {H=}' + assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 
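# Example (illustration only, not part of the patch): ManyAR_PatchEmbed above expects
# every tensor it receives to be in landscape orientation (W >= H), while `true_shape`
# keeps the real (height, width) so per-image orientation can be handled later.
# A minimal sketch of that convention with plain torch; all names below are local
# to this example.
import torch

img = torch.randn(1, 3, 512, 384)           # portrait input: H=512 > W=384
true_shape = torch.tensor([[512, 384]])     # (height, width) of the original image
height, width = true_shape.T
is_portrait = (width < height)

if is_portrait.all():
    img = img.swapaxes(-1, -2)              # transpose to landscape before the projection
assert img.shape[-1] >= img.shape[-2]       # W >= H, matching the assert in forward() above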
+ assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}" + + # size expressed in tokens + W //= self.patch_size[0] + H //= self.patch_size[1] + n_tokens = H * W + + height, width = true_shape.T + is_landscape = (width >= height) + is_portrait = ~is_landscape + + # allocate result + x = img.new_zeros((B, n_tokens, self.embed_dim)) + pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64) + + # linear projection, transposed if necessary + x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float() + x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float() + + pos[is_landscape] = self.position_getter(1, H, W, pos.device) + pos[is_portrait] = self.position_getter(1, W, H, pos.device) + + x = self.norm(x) + return x, pos diff --git a/imcui/third_party/mast3r/dust3r/dust3r/post_process.py b/imcui/third_party/mast3r/dust3r/dust3r/post_process.py new file mode 100644 index 0000000000000000000000000000000000000000..550a9b41025ad003228ef16f97d045fc238746e4 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/post_process.py @@ -0,0 +1,60 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilities for interpreting the DUST3R output +# -------------------------------------------------------- +import numpy as np +import torch +from dust3r.utils.geometry import xy_grid + + +def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf): + """ Reprojection method, for when the absolute depth is known: + 1) estimate the camera focal using a robust estimator + 2) reproject points onto true rays, minimizing a certain error + """ + B, H, W, THREE = pts3d.shape + assert THREE == 3 + + # centered pixel grid + pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2 + pts3d = pts3d.flatten(1, 2) # (B, HW, 3) + + if focal_mode == 'median': + with torch.no_grad(): + # direct estimation of focal + u, v = pixels.unbind(dim=-1) + x, y, z = pts3d.unbind(dim=-1) + fx_votes = (u * z) / x + fy_votes = (v * z) / y + + # assume square pixels, hence same focal for X and Y + f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1) + focal = torch.nanmedian(f_votes, dim=-1).values + + elif focal_mode == 'weiszfeld': + # init focal with l2 closed form + # we try to find focal = argmin Sum | pixel - focal * (x,y)/z| + xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1) + + dot_xy_px = (xy_over_z * pixels).sum(dim=-1) + dot_xy_xy = xy_over_z.square().sum(dim=-1) + + focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1) + + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip(min=1e-8).reciprocal() + # update the scaling with the new weights + focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1) + else: + raise ValueError(f'bad {focal_mode=}') + + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515 + focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base) + # print(focal) + return focal diff --git a/imcui/third_party/mast3r/dust3r/dust3r/training.py b/imcui/third_party/mast3r/dust3r/dust3r/training.py new file mode 100644 
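# Example (illustration only): the 'median' branch of estimate_focal_knowing_depth()
# above is a per-pixel vote f = u*z/x (and v*z/y) around the principal point, followed
# by a robust median. On clean synthetic pinhole data the vote recovers the focal
# exactly; plain torch, nothing below comes from the patch itself.
import torch

H, W, f_true = 48, 64, 120.0
v, u = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                      torch.arange(W, dtype=torch.float32), indexing='ij')
u = u - (W - 1) / 2                      # centered pixels (principal point at the center)
v = v - (H - 1) / 2
z = torch.full((H, W), 2.0)              # constant depth
x, y = u * z / f_true, v * z / f_true    # back-project with the known focal

votes = torch.cat(((u * z / x).flatten(), (v * z / y).flatten()))
print(float(votes.median()))             # 120.0 (the real code uses nanmedian for robustness)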
index 0000000000000000000000000000000000000000..53af9764ebb03a0083c22294298ed674e9164edc --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/training.py @@ -0,0 +1,377 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# training code for DUSt3R +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import argparse +import datetime +import json +import numpy as np +import os +import sys +import time +import math +from collections import defaultdict +from pathlib import Path +from typing import Sized + +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + +from dust3r.model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model +from dust3r.datasets import get_data_loader # noqa +from dust3r.losses import * # noqa: F401, needed when loading the model +from dust3r.inference import loss_of_one_batch # noqa + +import dust3r.utils.path_to_croco # noqa: F401 +import croco.utils.misc as misc # noqa +from croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler # noqa + + +def get_args_parser(): + parser = argparse.ArgumentParser('DUST3R training', add_help=False) + # model and criterion + parser.add_argument('--model', default="AsymmetricCroCo3DStereo(patch_embed_cls='ManyAR_PatchEmbed')", + type=str, help="string containing the model to build") + parser.add_argument('--pretrained', default=None, help='path of a starting checkpoint') + parser.add_argument('--train_criterion', default="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)", + type=str, help="train criterion") + parser.add_argument('--test_criterion', default=None, type=str, help="test criterion") + + # dataset + parser.add_argument('--train_dataset', required=True, type=str, help="training set") + parser.add_argument('--test_dataset', default='[None]', type=str, help="testing set") + + # training + parser.add_argument('--seed', default=0, type=int, help="Random seed") + parser.add_argument('--batch_size', default=64, type=int, + help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus") + parser.add_argument('--accum_iter', default=1, type=int, + help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)") + parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler") + + parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)") + parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') + parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR', + help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--min_lr', type=float, default=0., metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0') + parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR') + + parser.add_argument('--amp', type=int, default=0, + choices=[0, 1], help="Use Automatic Mixed Precision 
for pretraining") + parser.add_argument("--disable_cudnn_benchmark", action='store_true', default=False, + help="set cudnn.benchmark = False") + # others + parser.add_argument('--num_workers', default=8, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + + parser.add_argument('--eval_freq', type=int, default=1, help='Test loss evaluation frequency') + parser.add_argument('--save_freq', default=1, type=int, + help='frequence (number of epochs) to save checkpoint in checkpoint-last.pth') + parser.add_argument('--keep_freq', default=20, type=int, + help='frequence (number of epochs) to save checkpoint in checkpoint-%d.pth') + parser.add_argument('--print_freq', default=20, type=int, + help='frequence (number of iterations) to print infos while training') + + # output dir + parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output") + return parser + + +def train(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + world_size = misc.get_world_size() + + print("output_dir: " + args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + # auto resume + last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth') + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # fix the seed + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = not args.disable_cudnn_benchmark + + # training dataset and loader + print('Building train dataset {:s}'.format(args.train_dataset)) + # dataset and loader + data_loader_train = build_dataset(args.train_dataset, args.batch_size, args.num_workers, test=False) + print('Building test dataset {:s}'.format(args.train_dataset)) + data_loader_test = {dataset.split('(')[0]: build_dataset(dataset, args.batch_size, args.num_workers, test=True) + for dataset in args.test_dataset.split('+')} + + # model + print('Loading model: {:s}'.format(args.model)) + model = eval(args.model) + print(f'>> Creating train criterion = {args.train_criterion}') + train_criterion = eval(args.train_criterion).to(device) + print(f'>> Creating test criterion = {args.test_criterion or args.train_criterion}') + test_criterion = eval(args.test_criterion or args.criterion).to(device) + + model.to(device) + model_without_ddp = model + print("Model = %s" % str(model_without_ddp)) + + if args.pretrained and not args.resume: + print('Loading pretrained: ', args.pretrained) + ckpt = torch.load(args.pretrained, map_location=device) + print(model.load_state_dict(ckpt['model'], strict=False)) + del ckpt # in case it occupies memory + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( 
+ model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True) + model_without_ddp = model.module + + # following timm: set wd as 0 for bias and norm layers + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) + print(optimizer) + loss_scaler = NativeScaler() + + def write_log_stats(epoch, train_stats, test_stats): + if misc.is_main_process(): + if log_writer is not None: + log_writer.flush() + + log_stats = dict(epoch=epoch, **{f'train_{k}': v for k, v in train_stats.items()}) + for test_name in data_loader_test: + if test_name not in test_stats: + continue + log_stats.update({test_name + '_' + k: v for k, v in test_stats[test_name].items()}) + + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + def save_model(epoch, fname, best_so_far): + misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, fname=fname, best_so_far=best_so_far) + + best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp, + optimizer=optimizer, loss_scaler=loss_scaler) + if best_so_far is None: + best_so_far = float('inf') + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + train_stats = test_stats = {} + for epoch in range(args.start_epoch, args.epochs + 1): + + # Save immediately the last checkpoint + if epoch > args.start_epoch: + if args.save_freq and epoch % args.save_freq == 0 or epoch == args.epochs: + save_model(epoch - 1, 'last', best_so_far) + + # Test on multiple datasets + new_best = False + if (epoch > 0 and args.eval_freq > 0 and epoch % args.eval_freq == 0): + test_stats = {} + for test_name, testset in data_loader_test.items(): + stats = test_one_epoch(model, test_criterion, testset, + device, epoch, log_writer=log_writer, args=args, prefix=test_name) + test_stats[test_name] = stats + + # Save best of all + if stats['loss_med'] < best_so_far: + best_so_far = stats['loss_med'] + new_best = True + + # Save more stuff + write_log_stats(epoch, train_stats, test_stats) + + if epoch > args.start_epoch: + if args.keep_freq and epoch % args.keep_freq == 0: + save_model(epoch - 1, str(epoch), best_so_far) + if new_best: + save_model(epoch - 1, 'best', best_so_far) + if epoch >= args.epochs: + break # exit after writing last test to disk + + # Train + train_stats = train_one_epoch( + model, train_criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + log_writer=log_writer, + args=args) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + save_final_model(args, args.epochs, model_without_ddp, best_so_far=best_so_far) + + +def save_final_model(args, epoch, model_without_ddp, best_so_far=None): + output_dir = Path(args.output_dir) + checkpoint_path = output_dir / 'checkpoint-final.pth' + to_save = { + 'args': args, + 'model': model_without_ddp if isinstance(model_without_ddp, dict) else model_without_ddp.cpu().state_dict(), + 'epoch': epoch + } + if best_so_far is not None: + to_save['best_so_far'] = best_so_far + print(f'>> Saving model to {checkpoint_path} ...') + misc.save_on_master(to_save, checkpoint_path) + + +def 
build_dataset(dataset, batch_size, num_workers, test=False): + split = ['Train', 'Test'][test] + print(f'Building {split} Data loader for dataset: ', dataset) + loader = get_data_loader(dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_mem=True, + shuffle=not (test), + drop_last=not (test)) + + print(f"{split} dataset length: ", len(loader)) + return loader + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Sized, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + args, + log_writer=None): + assert torch.backends.cuda.matmul.allow_tf32 == True + + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + accum_iter = args.accum_iter + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'): + data_loader.dataset.set_epoch(epoch) + if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'): + data_loader.sampler.set_epoch(epoch) + + optimizer.zero_grad() + + for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + epoch_f = epoch + data_iter_step / len(data_loader) + + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate(optimizer, epoch_f, args) + + loss_tuple = loss_of_one_batch(batch, model, criterion, device, + symmetrize_batch=True, + use_amp=bool(args.amp), ret='loss') + loss, loss_details = loss_tuple # criterion returns two values + loss_value = float(loss) + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value), force=True) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + del loss + del batch + + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(epoch=epoch_f) + metric_logger.update(lr=lr) + metric_logger.update(loss=loss_value, **loss_details) + + if (data_iter_step + 1) % accum_iter == 0 and ((data_iter_step + 1) % (accum_iter * args.print_freq)) == 0: + loss_value_reduce = misc.all_reduce_mean(loss_value) # MUST BE EXECUTED BY ALL NODES + if log_writer is None: + continue + """ We use epoch_1000x as the x-axis in tensorboard. + This calibrates different curves when batch size changes. 
+ """ + epoch_1000x = int(epoch_f * 1000) + log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('train_lr', lr, epoch_1000x) + log_writer.add_scalar('train_iter', epoch_1000x, epoch_1000x) + for name, val in loss_details.items(): + log_writer.add_scalar('train_' + name, val, epoch_1000x) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Sized, device: torch.device, epoch: int, + args, log_writer=None, prefix='test'): + + model.eval() + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9)) + header = 'Test Epoch: [{}]'.format(epoch) + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'): + data_loader.dataset.set_epoch(epoch) + if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'): + data_loader.sampler.set_epoch(epoch) + + for _, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): + loss_tuple = loss_of_one_batch(batch, model, criterion, device, + symmetrize_batch=True, + use_amp=bool(args.amp), ret='loss') + loss_value, loss_details = loss_tuple # criterion returns two values + metric_logger.update(loss=float(loss_value), **loss_details) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + + aggs = [('avg', 'global_avg'), ('med', 'median')] + results = {f'{k}_{tag}': getattr(meter, attr) for k, meter in metric_logger.meters.items() for tag, attr in aggs} + + if log_writer is not None: + for name, val in results.items(): + log_writer.add_scalar(prefix + '_' + name, val, 1000 * epoch) + + return results diff --git a/imcui/third_party/mast3r/dust3r/dust3r/utils/__init__.py b/imcui/third_party/mast3r/dust3r/dust3r/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/mast3r/dust3r/dust3r/utils/device.py b/imcui/third_party/mast3r/dust3r/dust3r/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b6a74dac05a2e1ba3a2b2f0faa8cea08ece745 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/utils/device.py @@ -0,0 +1,76 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for DUSt3R +# -------------------------------------------------------- +import numpy as np +import torch + + +def todevice(batch, device, callback=None, non_blocking=False): + ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). + + batch: list, tuple, dict of tensors or other things + device: pytorch device or 'numpy' + callback: function that would be called on every sub-elements. 
+ ''' + if callback: + batch = callback(batch) + + if isinstance(batch, dict): + return {k: todevice(v, device) for k, v in batch.items()} + + if isinstance(batch, (tuple, list)): + return type(batch)(todevice(x, device) for x in batch) + + x = batch + if device == 'numpy': + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +to_device = todevice # alias + + +def to_numpy(x): return todevice(x, 'numpy') +def to_cpu(x): return todevice(x, 'cpu') +def to_cuda(x): return todevice(x, 'cuda') + + +def collate_with_cat(whatever, lists=False): + if isinstance(whatever, dict): + return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()} + + elif isinstance(whatever, (tuple, list)): + if len(whatever) == 0: + return whatever + elem = whatever[0] + T = type(whatever) + + if elem is None: + return None + if isinstance(elem, (bool, float, int, str)): + return whatever + if isinstance(elem, tuple): + return T(collate_with_cat(x, lists=lists) for x in zip(*whatever)) + if isinstance(elem, dict): + return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem} + + if isinstance(elem, torch.Tensor): + return listify(whatever) if lists else torch.cat(whatever) + if isinstance(elem, np.ndarray): + return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever]) + + # otherwise, we just chain lists + return sum(whatever, T()) + + +def listify(elems): + return [x for e in elems for x in e] diff --git a/imcui/third_party/mast3r/dust3r/dust3r/utils/geometry.py b/imcui/third_party/mast3r/dust3r/dust3r/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..ce365faf2acb97ffaafa1b80cb8ee0c28de0b6d6 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/utils/geometry.py @@ -0,0 +1,366 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# geometry utilitary functions +# -------------------------------------------------------- +import torch +import numpy as np +from scipy.spatial import cKDTree as KDTree + +from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans +from dust3r.utils.device import to_numpy + + +def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw): + """ Output a (H,W,2) array of int32 + with output[j,i,0] = i + origin[0] + output[j,i,1] = j + origin[1] + """ + if device is None: + # numpy + arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones + else: + # torch + arange = lambda *a, **kw: torch.arange(*a, device=device, **kw) + meshgrid, stack = torch.meshgrid, torch.stack + ones = lambda *a: torch.ones(*a, device=device) + + tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)] + grid = meshgrid(tw, th, indexing='xy') + if homogeneous: + grid = grid + (ones((H, W)),) + if unsqueeze is not None: + grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze)) + if cat_dim is not None: + grid = stack(grid, cat_dim) + return grid + + +def geotrf(Trf, pts, ncol=None, norm=False): + """ Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. 
Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + # optimized code + if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and + Trf.ndim == 3 and pts.ndim == 4): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d] + else: + raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}') + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match' + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def inv(mat): + """ Invert a torch or numpy matrix + """ + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f'bad matrix type = {type(mat)}') + + +def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_): + """ + Args: + - depthmap (BxHxW array): + - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W] + Returns: + pointmap of absolute coordinates (BxHxWx3 array) + """ + + if len(depth.shape) == 4: + B, H, W, n = depth.shape + else: + B, H, W = depth.shape + n = None + + if len(pseudo_focal.shape) == 3: # [B,H,W] + pseudo_focalx = pseudo_focaly = pseudo_focal + elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W] + pseudo_focalx = pseudo_focal[:, 0] + if pseudo_focal.shape[1] == 2: + pseudo_focaly = pseudo_focal[:, 1] + else: + pseudo_focaly = pseudo_focalx + else: + raise NotImplementedError("Error, unknown input focal shape format.") + + assert pseudo_focalx.shape == depth.shape[:3] + assert pseudo_focaly.shape == depth.shape[:3] + grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None] + + # set principal point + if pp is None: + grid_x = grid_x - (W - 1) / 2 + grid_y = grid_y - (H - 1) / 2 + else: + grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None] + grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None] + + if n is None: + pts3d = torch.empty((B, H, W, 3), device=depth.device) + pts3d[..., 0] = depth * grid_x / pseudo_focalx + pts3d[..., 1] = depth * grid_y / pseudo_focaly + pts3d[..., 2] = depth + else: + pts3d = torch.empty((B, H, W, 3, n), device=depth.device) + pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None] + pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None] + 
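# Example (illustration only): geotrf() above applies a batch of 4x4 poses to a
# (B,H,W,3) pointmap and inv() gives the reverse mapping, which is how world->camera
# transforms are expressed in the losses (e.g. geotrf(inv(camera_pose), pts3d)).
# A small self-contained round-trip check, assuming the dust3r package from this
# diff is importable:
import torch
from dust3r.utils.geometry import geotrf, inv

B, H, W = 2, 4, 5
pose = torch.eye(4).repeat(B, 1, 1)
pose[:, :3, 3] = torch.randn(B, 3)        # identity rotation + random translation
pts_cam = torch.randn(B, H, W, 3)

pts_world = geotrf(pose, pts_cam)         # camera -> world
back = geotrf(inv(pose), pts_world)       # world -> camera
assert torch.allclose(back, pts_cam, atol=1e-5)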
pts3d[..., 2, :] = depth + return pts3d + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + # Mask for valid coordinates + valid_mask = (depthmap > 0.0) + return X_cam, valid_mask + + +def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.""" + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + + X_world = X_cam # default + if camera_pose is not None: + # R_cam2world = np.float32(camera_params["R_cam2world"]) + # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze() + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates (invalid depth values) + X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + + return X_world, valid_mask + + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + return K + + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. 
+ Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K + + +def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None, ret_factor=False): + """ renorm pointmaps pts1, pts2 with norm_mode + """ + assert pts1.ndim >= 3 and pts1.shape[-1] == 3 + assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split('_') + + if norm_mode == 'avg': + # gather all points together (joint normalization) + nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3) + nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0) + all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + if dis_mode == 'dis': + pass # do nothing + elif dis_mode == 'log1p': + all_dis = torch.log1p(all_dis) + elif dis_mode == 'warp-log1p': + # actually warp input points before normalizing them + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + H1, W1 = pts1.shape[1:-1] + pts1 = pts1 * warp_factor[:, :W1 * H1].view(-1, H1, W1, 1) + if pts2 is not None: + H2, W2 = pts2.shape[1:-1] + pts2 = pts2 * warp_factor[:, W1 * H1:].view(-1, H2, W2, 1) + all_dis = log_dis # this is their true distance afterwards + else: + raise ValueError(f'bad {dis_mode=}') + + norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8) + else: + # gather all points together (joint normalization) + nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3) + nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None + all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + + if norm_mode == 'avg': + norm_factor = all_dis.nanmean(dim=1) + elif norm_mode == 'median': + norm_factor = all_dis.nanmedian(dim=1).values.detach() + elif norm_mode == 'sqrt': + norm_factor = all_dis.sqrt().nanmean(dim=1)**2 + else: + raise ValueError(f'bad {norm_mode=}') + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts1.ndim: + norm_factor.unsqueeze_(-1) + + res = pts1 / norm_factor + if pts2 is not None: + res = (res, pts2 / norm_factor) + if ret_factor: + res = res + (norm_factor,) + return res + + +@torch.no_grad() +def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5): + # set invalid points to NaN + _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1) + _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None + _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1 + + # compute median depth overall (ignoring nans) + if quantile == 0.5: + shift_z = torch.nanmedian(_z, dim=-1).values + else: + shift_z = torch.nanquantile(_z, quantile, dim=-1) + return shift_z # (B,) + + +@torch.no_grad() +def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True): + # set invalid points to NaN + _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3) + _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None + _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1 + + # compute median center + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + # compute 
median norm + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +def find_reciprocal_matches(P1, P2): + """ + returns 3 values: + 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match + 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1 + 3 - reciprocal_in_P2.sum(): the number of matches + """ + tree1 = KDTree(P1) + tree2 = KDTree(P2) + + _, nn1_in_P2 = tree2.query(P1, workers=8) + _, nn2_in_P1 = tree1.query(P2, workers=8) + + reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2))) + reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1))) + assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum() + return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum() + + +def get_med_dist_between_poses(poses): + from scipy.spatial.distance import pdist + return np.median(pdist([to_numpy(p[:3, 3]) for p in poses])) diff --git a/imcui/third_party/mast3r/dust3r/dust3r/utils/image.py b/imcui/third_party/mast3r/dust3r/dust3r/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..6312a346df919ae6a0424504d824ef813fea250f --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/utils/image.py @@ -0,0 +1,126 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions about images (loading/converting...) +# -------------------------------------------------------- +import os +import torch +import numpy as np +import PIL.Image +from PIL.ImageOps import exif_transpose +import torchvision.transforms as tvf +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa + +try: + from pillow_heif import register_heif_opener # noqa + register_heif_opener() + heif_support_enabled = True +except ImportError: + heif_support_enabled = False + +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + +def img_to_arr( img ): + if isinstance(img, str): + img = imread_cv2(img) + return img + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """ Open an image or a depthmap with opencv-python. 
+ """ + if path.endswith(('.exr', 'EXR')): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f'Could not load image={path} with {options=}') + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def rgb(ftensor, true_shape=None): + if isinstance(ftensor, list): + return [rgb(x, true_shape=true_shape) for x in ftensor] + if isinstance(ftensor, torch.Tensor): + ftensor = ftensor.detach().cpu().numpy() # H,W,3 + if ftensor.ndim == 3 and ftensor.shape[0] == 3: + ftensor = ftensor.transpose(1, 2, 0) + elif ftensor.ndim == 4 and ftensor.shape[1] == 3: + ftensor = ftensor.transpose(0, 2, 3, 1) + if true_shape is not None: + H, W = true_shape + ftensor = ftensor[:H, :W] + if ftensor.dtype == np.uint8: + img = np.float32(ftensor) / 255 + else: + img = (ftensor * 0.5) + 0.5 + return img.clip(min=0, max=1) + + +def _resize_pil_image(img, long_edge_size): + S = max(img.size) + if S > long_edge_size: + interp = PIL.Image.LANCZOS + elif S <= long_edge_size: + interp = PIL.Image.BICUBIC + new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size) + return img.resize(new_size, interp) + + +def load_images(folder_or_list, size, square_ok=False, verbose=True): + """ open and convert all images in a list or folder to proper input format for DUSt3R + """ + if isinstance(folder_or_list, str): + if verbose: + print(f'>> Loading images from {folder_or_list}') + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + if verbose: + print(f'>> Loading a list of {len(folder_or_list)} images') + root, folder_content = '', folder_or_list + + else: + raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})') + + supported_images_extensions = ['.jpg', '.jpeg', '.png'] + if heif_support_enabled: + supported_images_extensions += ['.heic', '.heif'] + supported_images_extensions = tuple(supported_images_extensions) + + imgs = [] + for path in folder_content: + if not path.lower().endswith(supported_images_extensions): + continue + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB') + W1, H1 = img.size + if size == 224: + # resize short side to 224 (then crop) + img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1))) + else: + # resize long side to 512 + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W//2, H//2 + if size == 224: + half = min(cx, cy) + img = img.crop((cx-half, cy-half, cx+half, cy+half)) + else: + halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8 + if not (square_ok) and W == H: + halfh = 3*halfw/4 + img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh)) + + W2, H2 = img.size + if verbose: + print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}') + imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32( + [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs)))) + + assert imgs, 'no images foud at '+root + if verbose: + print(f' (Found {len(imgs)} images)') + return imgs diff --git a/imcui/third_party/mast3r/dust3r/dust3r/utils/misc.py b/imcui/third_party/mast3r/dust3r/dust3r/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..88c4d2dab6d5c14021ed9ed6646c3159a3a4637b --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/utils/misc.py @@ -0,0 +1,121 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# utilitary functions for DUSt3R +# -------------------------------------------------------- +import torch + + +def fill_default_args(kwargs, func): + import inspect # a bit hacky but it works reliably + signature = inspect.signature(func) + + for k, v in signature.parameters.items(): + if v.default is inspect.Parameter.empty: + continue + kwargs.setdefault(k, v.default) + + return kwargs + + +def freeze_all_params(modules): + for module in modules: + try: + for n, param in module.named_parameters(): + param.requires_grad = False + except AttributeError: + # module is directly a parameter + module.requires_grad = False + + +def is_symmetrized(gt1, gt2): + x = gt1['instance'] + y = gt2['instance'] + if len(x) == len(y) and len(x) == 1: + return False # special case of batchsize 1 + ok = True + for i in range(0, len(x), 2): + ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i]) + return ok + + +def flip(tensor): + """ flip so that tensor[0::2] <=> tensor[1::2] """ + return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) + + +def interleave(tensor1, tensor2): + res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) + res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) + return res1, res2 + + +def transpose_to_landscape(head, activate=True): + """ Predict in the correct aspect-ratio, + then transpose the result in landscape + and stack everything back together. + """ + def wrapper_no(decout, true_shape): + B = len(true_shape) + assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical' + H, W = true_shape[0].cpu().tolist() + res = head(decout, (H, W)) + return res + + def wrapper_yes(decout, true_shape): + B = len(true_shape) + # by definition, the batch is in landscape mode so W >= H + H, W = int(true_shape.min()), int(true_shape.max()) + + height, width = true_shape.T + is_landscape = (width >= height) + is_portrait = ~is_landscape + + # true_shape = true_shape.cpu() + if is_landscape.all(): + return head(decout, (H, W)) + if is_portrait.all(): + return transposed(head(decout, (W, H))) + + # batch is a mix of both portraint & landscape + def selout(ar): return [d[ar] for d in decout] + l_result = head(selout(is_landscape), (H, W)) + p_result = transposed(head(selout(is_portrait), (W, H))) + + # allocate full result + result = {} + for k in l_result | p_result: + x = l_result[k].new(B, *l_result[k].shape[1:]) + x[is_landscape] = l_result[k] + x[is_portrait] = p_result[k] + result[k] = x + + return result + + return wrapper_yes if activate else wrapper_no + + +def transposed(dic): + return {k: v.swapaxes(1, 2) for k, v in dic.items()} + + +def invalid_to_nans(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = float('nan') + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr + + +def invalid_to_zeros(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = 0 + nnz = valid_mask.view(len(valid_mask), -1).sum(1) + else: + nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr, nnz diff --git a/imcui/third_party/mast3r/dust3r/dust3r/utils/parallel.py b/imcui/third_party/mast3r/dust3r/dust3r/utils/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..06ae7fefdb9d2298929f0cbc20dfbc57eb7d7f7b --- /dev/null +++ 
b/imcui/third_party/mast3r/dust3r/dust3r/utils/parallel.py @@ -0,0 +1,79 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for multiprocessing +# -------------------------------------------------------- +from tqdm import tqdm +from multiprocessing.dummy import Pool as ThreadPool +from multiprocessing import cpu_count + + +def parallel_threads(function, args, workers=0, star_args=False, kw_args=False, front_num=1, Pool=ThreadPool, **tqdm_kw): + """ tqdm but with parallel execution. + + Will essentially return + res = [ function(arg) # default + function(*arg) # if star_args is True + function(**arg) # if kw_args is True + for arg in args] + + Note: + the first elements of args will not be parallelized. + This can be useful for debugging. + """ + while workers <= 0: + workers += cpu_count() + if workers == 1: + front_num = float('inf') + + # convert into an iterable + try: + n_args_parallel = len(args) - front_num + except TypeError: + n_args_parallel = None + args = iter(args) + + # sequential execution first + front = [] + while len(front) < front_num: + try: + a = next(args) + except StopIteration: + return front # end of the iterable + front.append(function(*a) if star_args else function(**a) if kw_args else function(a)) + + # then parallel execution + out = [] + with Pool(workers) as pool: + # Pass the elements of args into function + if star_args: + futures = pool.imap(starcall, [(function, a) for a in args]) + elif kw_args: + futures = pool.imap(starstarcall, [(function, a) for a in args]) + else: + futures = pool.imap(function, args) + # Print out the progress as tasks complete + for f in tqdm(futures, total=n_args_parallel, **tqdm_kw): + out.append(f) + return front + out + + +def parallel_processes(*args, **kwargs): + """ Same as parallel_threads, with processes + """ + import multiprocessing as mp + kwargs['Pool'] = mp.Pool + return parallel_threads(*args, **kwargs) + + +def starcall(args): + """ convenient wrapper for Process.Pool """ + function, args = args + return function(*args) + + +def starstarcall(args): + """ convenient wrapper for Process.Pool """ + function, args = args + return function(**args) diff --git a/imcui/third_party/mast3r/dust3r/dust3r/utils/path_to_croco.py b/imcui/third_party/mast3r/dust3r/dust3r/utils/path_to_croco.py new file mode 100644 index 0000000000000000000000000000000000000000..39226ce6bc0e1993ba98a22096de32cb6fa916b4 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/utils/path_to_croco.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
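# Example (illustration only): typical use of parallel_threads() defined above. The
# first `front_num` items run sequentially (useful for debugging), the rest go through
# a thread pool with a tqdm progress bar; extra keyword arguments are forwarded to tqdm.
# The worker function and values below are stand-ins, assuming the dust3r package from
# this diff is importable.
from dust3r.utils.parallel import parallel_threads

def slow_square(x):
    return x * x

results = parallel_threads(slow_square, range(8), workers=4, front_num=1, desc='squares')
print(results)   # [0, 1, 4, 9, 16, 25, 36, 49]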
+# +# -------------------------------------------------------- +# CroCo submodule import +# -------------------------------------------------------- + +import sys +import os.path as path +HERE_PATH = path.normpath(path.dirname(__file__)) +CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco')) +CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models') +# check the presence of models directory in repo to be sure its cloned +if path.isdir(CROCO_MODELS_PATH): + # workaround for sibling import + sys.path.insert(0, CROCO_REPO_PATH) +else: + raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " + "Did you forget to run 'git submodule update --init --recursive' ?") diff --git a/imcui/third_party/mast3r/dust3r/dust3r/viz.py b/imcui/third_party/mast3r/dust3r/dust3r/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..9150e8b850d9f1e6bf9ddf6e865d34fc743e276a --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r/viz.py @@ -0,0 +1,381 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Visualization utilities using trimesh +# -------------------------------------------------------- +import PIL.Image +import numpy as np +from scipy.spatial.transform import Rotation +import torch + +from dust3r.utils.geometry import geotrf, get_med_dist_between_poses, depthmap_to_absolute_camera_coordinates +from dust3r.utils.device import to_numpy +from dust3r.utils.image import rgb, img_to_arr + +try: + import trimesh +except ImportError: + print('/!\\ module trimesh is not installed, cannot visualize results /!\\') + + + +def cat_3d(vecs): + if isinstance(vecs, (np.ndarray, torch.Tensor)): + vecs = [vecs] + return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)]) + + +def show_raw_pointcloud(pts3d, colors, point_size=2): + scene = trimesh.Scene() + + pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors)) + scene.add_geometry(pct) + + scene.show(line_settings={'point_size': point_size}) + + +def pts3d_to_trimesh(img, pts3d, valid=None): + H, W, THREE = img.shape + assert THREE == 3 + assert img.shape == pts3d.shape + + vertices = pts3d.reshape(-1, 3) + + # make squares: each pixel == 2 triangles + idx = np.arange(len(vertices)).reshape(H, W) + idx1 = idx[:-1, :-1].ravel() # top-left corner + idx2 = idx[:-1, +1:].ravel() # right-left corner + idx3 = idx[+1:, :-1].ravel() # bottom-left corner + idx4 = idx[+1:, +1:].ravel() # bottom-right corner + faces = np.concatenate(( + np.c_[idx1, idx2, idx3], + np.c_[idx3, idx2, idx1], # same triangle, but backward (cheap solution to cancel face culling) + np.c_[idx2, idx3, idx4], + np.c_[idx4, idx3, idx2], # same triangle, but backward (cheap solution to cancel face culling) + ), axis=0) + + # prepare triangle colors + face_colors = np.concatenate(( + img[:-1, :-1].reshape(-1, 3), + img[:-1, :-1].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3) + ), axis=0) + + # remove invalid faces + if valid is not None: + assert valid.shape == (H, W) + valid_idxs = valid.ravel() + valid_faces = valid_idxs[faces].all(axis=-1) + faces = faces[valid_faces] + face_colors = face_colors[valid_faces] + + assert len(faces) == len(face_colors) + return dict(vertices=vertices, face_colors=face_colors, faces=faces) + + +def cat_meshes(meshes): + vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes]) + 
n_vertices = np.cumsum([0]+[len(v) for v in vertices]) + for i in range(len(faces)): + faces[i][:] += n_vertices[i] + + vertices = np.concatenate(vertices) + colors = np.concatenate(colors) + faces = np.concatenate(faces) + return dict(vertices=vertices, face_colors=colors, faces=faces) + + +def show_duster_pairs(view1, view2, pred1, pred2): + import matplotlib.pyplot as pl + pl.ion() + + for e in range(len(view1['instance'])): + i = view1['idx'][e] + j = view2['idx'][e] + img1 = rgb(view1['img'][e]) + img2 = rgb(view2['img'][e]) + conf1 = pred1['conf'][e].squeeze() + conf2 = pred2['conf'][e].squeeze() + score = conf1.mean()*conf2.mean() + print(f">> Showing pair #{e} {i}-{j} {score=:g}") + pl.clf() + pl.subplot(221).imshow(img1) + pl.subplot(223).imshow(img2) + pl.subplot(222).imshow(conf1, vmin=1, vmax=30) + pl.subplot(224).imshow(conf2, vmin=1, vmax=30) + pts1 = pred1['pts3d'][e] + pts2 = pred2['pts3d_in_other_view'][e] + pl.subplots_adjust(0, 0, 1, 1, 0, 0) + if input('show pointcloud? (y/n) ') == 'y': + show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5) + + +def auto_cam_size(im_poses): + return 0.1 * get_med_dist_between_poses(im_poses) + + +class SceneViz: + def __init__(self): + self.scene = trimesh.Scene() + + def add_rgbd(self, image, depth, intrinsics=None, cam2world=None, zfar=np.inf, mask=None): + image = img_to_arr(image) + + # make up some intrinsics + if intrinsics is None: + H, W, THREE = image.shape + focal = max(H, W) + intrinsics = np.float32([[focal, 0, W/2], [0, focal, H/2], [0, 0, 1]]) + + # compute 3d points + pts3d = depthmap_to_pts3d(depth, intrinsics, cam2world=cam2world) + + return self.add_pointcloud(pts3d, image, mask=(depth 150) + mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180) + mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220) + + # Morphological operations + kernel = np.ones((5, 5), np.uint8) + mask2 = ndimage.binary_opening(mask, structure=kernel) + + # keep only largest CC + _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8) + cc_sizes = stats[1:, cv2.CC_STAT_AREA] + order = cc_sizes.argsort()[::-1] # bigger first + i = 0 + selection = [] + while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2: + selection.append(1 + order[i]) + i += 1 + mask3 = np.in1d(labels, selection).reshape(labels.shape) + + # Apply mask + return torch.from_numpy(mask3) diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/__init__.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/__init__.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..566926b1e248e4b64fc5182031af634435bb8601 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/__init__.py @@ -0,0 +1,6 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
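As a quick illustration of how the `pts3d_to_trimesh` and `cat_meshes` helpers from `dust3r/viz.py` above fit together, the hedged sketch below builds a single `trimesh.Trimesh` from two synthetic per-pixel point maps; the random arrays stand in for real predictions.

```python
import numpy as np
import trimesh

from dust3r.viz import pts3d_to_trimesh, cat_meshes

H, W = 32, 48
# fake inputs standing in for real predictions: RGB in [0, 1] and one 3D point per pixel
img = np.random.rand(H, W, 3).astype(np.float32)
pts3d = np.random.rand(H, W, 3).astype(np.float32)
valid = np.ones((H, W), dtype=bool)

meshes = [pts3d_to_trimesh(img, pts3d, valid),
          pts3d_to_trimesh(img, pts3d + 1.0, valid)]   # second copy, shifted
merged = cat_meshes(meshes)   # dict(vertices=..., face_colors=..., faces=...)

# the returned keys match trimesh.Trimesh keyword arguments
mesh = trimesh.Trimesh(**merged)
print(mesh.vertices.shape, mesh.faces.shape)
```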
+from .sevenscenes import VislocSevenScenes +from .cambridge_landmarks import VislocCambridgeLandmarks +from .aachen_day_night import VislocAachenDayNight +from .inloc import VislocInLoc diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/aachen_day_night.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/aachen_day_night.py new file mode 100644 index 0000000000000000000000000000000000000000..159548e8b51a1b5872a2392cd9107ff96e40e801 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/aachen_day_night.py @@ -0,0 +1,24 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# AachenDayNight dataloader +# -------------------------------------------------------- +import os +from dust3r_visloc.datasets.base_colmap import BaseVislocColmapDataset + + +class VislocAachenDayNight(BaseVislocColmapDataset): + def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False): + assert subscene in [None, '', 'day', 'night', 'all'] + self.subscene = subscene + image_path = os.path.join(root, 'images') + map_path = os.path.join(root, 'mapping/colmap/reconstruction') + query_path = os.path.join(root, 'kapture', 'query') + pairsfile_path = os.path.join(root, 'pairsfile/query', pairsfile + '.txt') + super().__init__(image_path=image_path, map_path=map_path, + query_path=query_path, pairsfile_path=pairsfile_path, + topk=topk, cache_sfm=cache_sfm) + self.scenes = [filename for filename in self.scenes if filename in self.pairs] + if self.subscene == 'day' or self.subscene == 'night': + self.scenes = [filename for filename in self.scenes if self.subscene in filename] diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_colmap.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc2d64f69fb0954a148f0f4170508fe2045e046 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_colmap.py @@ -0,0 +1,282 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Base class for colmap / kapture +# -------------------------------------------------------- +import os +import numpy as np +from tqdm import tqdm +import collections +import pickle +import PIL.Image +import torch +from scipy.spatial.transform import Rotation +import torchvision.transforms as tvf + +from kapture.core import CameraType +from kapture.io.csv import kapture_from_dir +from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file + +from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d +from dust3r_visloc.datasets.base_dataset import BaseVislocDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import colmap_to_opencv_intrinsics + +KaptureSensor = collections.namedtuple('Sensor', 'sensor_params camera_params') + + +def kapture_to_opencv_intrinsics(sensor): + """ + Convert from Kapture to OpenCV parameters. + Warning: we assume that the camera and pixel coordinates follow Colmap conventions here. + Args: + sensor: Kapture sensor + """ + sensor_type = sensor.sensor_params[0] + if sensor_type == "SIMPLE_PINHOLE": + # Simple pinhole model. 
+ # We still call OpenCV undistorsion however for code simplicity. + w, h, f, cx, cy = sensor.camera_params + k1 = 0 + k2 = 0 + p1 = 0 + p2 = 0 + fx = fy = f + elif sensor_type == "PINHOLE": + w, h, fx, fy, cx, cy = sensor.camera_params + k1 = 0 + k2 = 0 + p1 = 0 + p2 = 0 + elif sensor_type == "SIMPLE_RADIAL": + w, h, f, cx, cy, k1 = sensor.camera_params + k2 = 0 + p1 = 0 + p2 = 0 + fx = fy = f + elif sensor_type == "RADIAL": + w, h, f, cx, cy, k1, k2 = sensor.camera_params + p1 = 0 + p2 = 0 + fx = fy = f + elif sensor_type == "OPENCV": + w, h, fx, fy, cx, cy, k1, k2, p1, p2 = sensor.camera_params + else: + raise NotImplementedError(f"Sensor type {sensor_type} is not supported yet.") + + cameraMatrix = np.asarray([[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]], dtype=np.float32) + + # We assume that Kapture data comes from Colmap: the origin is different. + cameraMatrix = colmap_to_opencv_intrinsics(cameraMatrix) + + distCoeffs = np.asarray([k1, k2, p1, p2], dtype=np.float32) + return cameraMatrix, distCoeffs, (w, h) + + +def K_from_colmap(elems): + sensor = KaptureSensor(elems, tuple(map(float, elems[1:]))) + cameraMatrix, distCoeffs, (w, h) = kapture_to_opencv_intrinsics(sensor) + res = dict(resolution=(w, h), + intrinsics=cameraMatrix, + distortion=distCoeffs) + return res + + +def pose_from_qwxyz_txyz(elems): + qw, qx, qy, qz, tx, ty, tz = map(float, elems) + pose = np.eye(4) + pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix() + pose[:3, 3] = (tx, ty, tz) + return np.linalg.inv(pose) # returns cam2world + + +class BaseVislocColmapDataset(BaseVislocDataset): + def __init__(self, image_path, map_path, query_path, pairsfile_path, topk=1, cache_sfm=False): + super().__init__() + self.topk = topk + self.num_views = self.topk + 1 + self.image_path = image_path + self.cache_sfm = cache_sfm + + self._load_sfm(map_path) + + kdata_query = kapture_from_dir(query_path) + assert kdata_query.records_camera is not None and kdata_query.trajectories is not None + + kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_query.records_camera.key_pairs()} + self.query_data = {'kdata': kdata_query, 'searchindex': kdata_query_searchindex} + + self.pairs = get_ordered_pairs_from_file(pairsfile_path) + self.scenes = kdata_query.records_camera.data_list() + + def _load_sfm(self, sfm_dir): + sfm_cache_path = os.path.join(sfm_dir, 'dust3r_cache.pkl') + if os.path.isfile(sfm_cache_path) and self.cache_sfm: + with open(sfm_cache_path, "rb") as f: + data = pickle.load(f) + self.img_infos = data['img_infos'] + self.points3D = data['points3D'] + return + + # load cameras + with open(os.path.join(sfm_dir, 'cameras.txt'), 'r') as f: + raw = f.read().splitlines()[3:] # skip header + + intrinsics = {} + for camera in tqdm(raw): + camera = camera.split(' ') + intrinsics[int(camera[0])] = K_from_colmap(camera[1:]) + + # load images + with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + self.img_infos = {} + for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2): + image = image.split(' ') + points = points.split(' ') + + img_name = image[-1] + current_points2D = {int(i): (float(x), float(y)) + for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'} + self.img_infos[img_name] = dict(intrinsics[int(image[-2])], + path=img_name, + camera_pose=pose_from_qwxyz_txyz(image[1: -2]), + 
sparse_pts2d=current_points2D) + + # load 3D points + with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f: + raw = f.read().splitlines() + raw = [line for line in raw if not line.startswith('#')] # skip header + + self.points3D = {} + for point in tqdm(raw): + point = point.split() + self.points3D[int(point[0])] = tuple(map(float, point[1:4])) + + if self.cache_sfm: + to_save = \ + { + 'img_infos': self.img_infos, + 'points3D': self.points3D + } + with open(sfm_cache_path, "wb") as f: + pickle.dump(to_save, f) + + def __len__(self): + return len(self.scenes) + + def _get_view_query(self, imgname): + kdata, searchindex = map(self.query_data.get, ['kdata', 'searchindex']) + + timestamp, camera_id = searchindex[imgname] + + camera_params = kdata.sensors[camera_id].camera_params + if kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_PINHOLE: + W, H, f, cx, cy = camera_params + k1 = 0 + fx = fy = f + elif kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_RADIAL: + W, H, f, cx, cy, k1 = camera_params + fx = fy = f + else: + raise NotImplementedError('not implemented') + + W, H = int(W), int(H) + intrinsics = np.float32([(fx, 0, cx), + (0, fy, cy), + (0, 0, 1)]) + intrinsics = colmap_to_opencv_intrinsics(intrinsics) + distortion = [k1, 0, 0, 0] + + if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories: + cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id) + else: + cam_to_world = np.eye(4, dtype=np.float32) + + # Load RGB image + rgb_image = PIL.Image.open(os.path.join(self.image_path, imgname)).convert('RGB') + rgb_image.load() + resize_func, _, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) + rgb_tensor = resize_func(ImgNorm(rgb_image)) + + view = { + 'intrinsics': intrinsics, + 'distortion': distortion, + 'cam_to_world': cam_to_world, + 'rgb': rgb_image, + 'rgb_rescaled': rgb_tensor, + 'to_orig': to_orig, + 'idx': 0, + 'image_name': imgname + } + return view + + def _get_view_map(self, imgname, idx): + infos = self.img_infos[imgname] + + rgb_image = PIL.Image.open(os.path.join(self.image_path, infos['path'])).convert('RGB') + rgb_image.load() + W, H = rgb_image.size + intrinsics = infos['intrinsics'] + intrinsics = colmap_to_opencv_intrinsics(intrinsics) + distortion_coefs = infos['distortion'] + + pts2d = infos['sparse_pts2d'] + sparse_pos2d = np.float32(list(pts2d.values())) # pts2d from colmap + sparse_pts3d = np.float32([self.points3D[i] for i in pts2d]) + + # store full resolution 2D->3D + sparse_pos2d_cv2 = sparse_pos2d.copy() + sparse_pos2d_cv2[:, 0] -= 0.5 + sparse_pos2d_cv2[:, 1] -= 0.5 + sparse_pos2d_int = sparse_pos2d_cv2.round().astype(np.int64) + valid = (sparse_pos2d_int[:, 0] >= 0) & (sparse_pos2d_int[:, 0] < W) & ( + sparse_pos2d_int[:, 1] >= 0) & (sparse_pos2d_int[:, 1] < H) + sparse_pos2d_int = sparse_pos2d_int[valid] + # nan => invalid + pts3d = np.full((H, W, 3), np.nan, dtype=np.float32) + pts3d[sparse_pos2d_int[:, 1], sparse_pos2d_int[:, 0]] = sparse_pts3d[valid] + pts3d = torch.from_numpy(pts3d) + + cam_to_world = infos['camera_pose'] # cam2world + + # also store resized resolution 2D->3D + resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) + rgb_tensor = resize_func(ImgNorm(rgb_image)) + + HR, WR = rgb_tensor.shape[1:] + _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(sparse_pos2d_cv2, sparse_pts3d, to_resize, HR, WR) + pts3d_rescaled = torch.from_numpy(pts3d_rescaled) + valid_rescaled = torch.from_numpy(valid_rescaled) + + view = { + 
'intrinsics': intrinsics, + 'distortion': distortion_coefs, + 'cam_to_world': cam_to_world, + 'rgb': rgb_image, + "pts3d": pts3d, + "valid": pts3d.sum(dim=-1).isfinite(), + 'rgb_rescaled': rgb_tensor, + "pts3d_rescaled": pts3d_rescaled, + "valid_rescaled": valid_rescaled, + 'to_orig': to_orig, + 'idx': idx, + 'image_name': imgname + } + return view + + def __getitem__(self, idx): + assert self.maxdim is not None and self.patch_size is not None + query_image = self.scenes[idx] + map_images = [p[0] for p in self.pairs[query_image][:self.topk]] + views = [] + views.append(self._get_view_query(query_image)) + for idx, map_image in enumerate(map_images): + views.append(self._get_view_map(map_image, idx + 1)) + return views diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_dataset.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cda3774c5ab5b668be5eecf89681abc96df5fe17 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_dataset.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Base class +# -------------------------------------------------------- +class BaseVislocDataset: + def __init__(self): + pass + + def set_resolution(self, model): + self.maxdim = max(model.patch_embed.img_size) + self.patch_size = model.patch_embed.patch_size + + def __len__(self): + raise NotImplementedError() + + def __getitem__(self, idx): + raise NotImplementedError() \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/cambridge_landmarks.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/cambridge_landmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..ca3e131941bf444d86a709d23e518e7b93d3d0f6 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/cambridge_landmarks.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Cambridge Landmarks dataloader +# -------------------------------------------------------- +import os +from dust3r_visloc.datasets.base_colmap import BaseVislocColmapDataset + + +class VislocCambridgeLandmarks (BaseVislocColmapDataset): + def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False): + image_path = os.path.join(root, subscene) + map_path = os.path.join(root, 'mapping', subscene, 'colmap/reconstruction') + query_path = os.path.join(root, 'kapture', subscene, 'query') + pairsfile_path = os.path.join(root, subscene, 'pairsfile/query', pairsfile + '.txt') + super().__init__(image_path=image_path, map_path=map_path, + query_path=query_path, pairsfile_path=pairsfile_path, + topk=topk, cache_sfm=cache_sfm) \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/inloc.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/inloc.py new file mode 100644 index 0000000000000000000000000000000000000000..99ed11f554203d353d0559d0589f40ec1ffbf66e --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/inloc.py @@ -0,0 +1,167 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
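A hedged sketch of how these visloc datasets are meant to be consumed; the call pattern mirrors `BaseVislocColmapDataset` and `BaseVislocDataset.set_resolution` above (and the `visloc.py` script later in this patch). The root, subscene and pairsfile values are hypothetical and require the corresponding Cambridge Landmarks data on disk.

```python
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r_visloc.datasets import VislocCambridgeLandmarks

model = AsymmetricCroCo3DStereo.from_pretrained("naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt")

dataset = VislocCambridgeLandmarks(root="/data/Cambridge",       # hypothetical path
                                   subscene="KingsCollege",       # hypothetical subscene
                                   pairsfile="retrieval_top20",   # hypothetical pairsfile name
                                   topk=1)
dataset.set_resolution(model)   # sets maxdim / patch_size from the model's patch embedding

views = dataset[0]              # views[0] is the query, views[1:] the retrieved map images
query, best_map = views[0], views[1]
print(query['image_name'], query['rgb_rescaled'].shape, best_map['pts3d_rescaled'].shape)
```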
+# +# -------------------------------------------------------- +# InLoc dataloader +# -------------------------------------------------------- +import os +import numpy as np +import torch +import PIL.Image +import scipy.io + +import kapture +from kapture.io.csv import kapture_from_dir +from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file + +from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d +from dust3r_visloc.datasets.base_dataset import BaseVislocDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import xy_grid, geotrf + + +def read_alignments(path_to_alignment): + aligns = {} + with open(path_to_alignment, "r") as fid: + while True: + line = fid.readline() + if not line: + break + if len(line) == 4: + trans_nr = line[:-1] + while line != 'After general icp:\n': + line = fid.readline() + line = fid.readline() + p = [] + for i in range(4): + elems = line.split(' ') + line = fid.readline() + for e in elems: + if len(e) != 0: + p.append(float(e)) + P = np.array(p).reshape(4, 4) + aligns[trans_nr] = P + return aligns + + +class VislocInLoc(BaseVislocDataset): + def __init__(self, root, pairsfile, topk=1): + super().__init__() + self.root = root + self.topk = topk + self.num_views = self.topk + 1 + self.maxdim = None + self.patch_size = None + + query_path = os.path.join(self.root, 'query') + kdata_query = kapture_from_dir(query_path) + assert kdata_query.records_camera is not None + kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_query.records_camera.key_pairs()} + self.query_data = {'path': query_path, 'kdata': kdata_query, 'searchindex': kdata_query_searchindex} + + map_path = os.path.join(self.root, 'mapping') + kdata_map = kapture_from_dir(map_path) + assert kdata_map.records_camera is not None and kdata_map.trajectories is not None + kdata_map_searchindex = {kdata_map.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_map.records_camera.key_pairs()} + self.map_data = {'path': map_path, 'kdata': kdata_map, 'searchindex': kdata_map_searchindex} + + try: + self.pairs = get_ordered_pairs_from_file(os.path.join(self.root, 'pairfiles/query', pairsfile + '.txt')) + except Exception as e: + # if using pairs from hloc + self.pairs = {} + with open(os.path.join(self.root, 'pairfiles/query', pairsfile + '.txt'), 'r') as fid: + lines = fid.readlines() + for line in lines: + splits = line.rstrip("\n\r").split(" ") + self.pairs.setdefault(splits[0].replace('query/', ''), []).append( + (splits[1].replace('database/cutouts/', ''), 1.0) + ) + + self.scenes = kdata_query.records_camera.data_list() + + self.aligns_DUC1 = read_alignments(os.path.join(self.root, 'mapping/DUC1_alignment/all_transformations.txt')) + self.aligns_DUC2 = read_alignments(os.path.join(self.root, 'mapping/DUC2_alignment/all_transformations.txt')) + + def __len__(self): + return len(self.scenes) + + def __getitem__(self, idx): + assert self.maxdim is not None and self.patch_size is not None + query_image = self.scenes[idx] + map_images = [p[0] for p in self.pairs[query_image][:self.topk]] + views = [] + dataarray = [(query_image, self.query_data, False)] + [(map_image, self.map_data, True) + for map_image in map_images] + for idx, (imgname, data, should_load_depth) in enumerate(dataarray): + imgpath, kdata, searchindex = map(data.get, ['path', 'kdata', 'searchindex']) + + 
timestamp, camera_id = searchindex[imgname] + + # for InLoc, SIMPLE_PINHOLE + camera_params = kdata.sensors[camera_id].camera_params + W, H, f, cx, cy = camera_params + distortion = [0, 0, 0, 0] + intrinsics = np.float32([(f, 0, cx), + (0, f, cy), + (0, 0, 1)]) + + if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories: + cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id) + else: + cam_to_world = np.eye(4, dtype=np.float32) + + # Load RGB image + rgb_image = PIL.Image.open(os.path.join(imgpath, 'sensors/records_data', imgname)).convert('RGB') + rgb_image.load() + + W, H = rgb_image.size + resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) + + rgb_tensor = resize_func(ImgNorm(rgb_image)) + + view = { + 'intrinsics': intrinsics, + 'distortion': distortion, + 'cam_to_world': cam_to_world, + 'rgb': rgb_image, + 'rgb_rescaled': rgb_tensor, + 'to_orig': to_orig, + 'idx': idx, + 'image_name': imgname + } + + # Load depthmap + if should_load_depth: + depthmap_filename = os.path.join(imgpath, 'sensors/records_data', imgname + '.mat') + depthmap = scipy.io.loadmat(depthmap_filename) + + pt3d_cut = depthmap['XYZcut'] + scene_id = imgname.replace('\\', '/').split('/')[1] + if imgname.startswith('DUC1'): + pts3d_full = geotrf(self.aligns_DUC1[scene_id], pt3d_cut) + else: + pts3d_full = geotrf(self.aligns_DUC2[scene_id], pt3d_cut) + + pts3d_valid = np.isfinite(pts3d_full.sum(axis=-1)) + + pts3d = pts3d_full[pts3d_valid] + pts2d_int = xy_grid(W, H)[pts3d_valid] + pts2d = pts2d_int.astype(np.float64) + + # nan => invalid + pts3d_full[~pts3d_valid] = np.nan + pts3d_full = torch.from_numpy(pts3d_full) + view['pts3d'] = pts3d_full + view["valid"] = pts3d_full.sum(dim=-1).isfinite() + + HR, WR = rgb_tensor.shape[1:] + _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(pts2d, pts3d, to_resize, HR, WR) + pts3d_rescaled = torch.from_numpy(pts3d_rescaled) + valid_rescaled = torch.from_numpy(valid_rescaled) + view['pts3d_rescaled'] = pts3d_rescaled + view["valid_rescaled"] = valid_rescaled + views.append(view) + return views diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/sevenscenes.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/sevenscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..c15e851d262f0d7ba7071c933d8fe8f0a6b1c49d --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/sevenscenes.py @@ -0,0 +1,123 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# 7 Scenes dataloader +# -------------------------------------------------------- +import os +import numpy as np +import torch +import PIL.Image + +import kapture +from kapture.io.csv import kapture_from_dir +from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file +from kapture.io.records import depth_map_from_file + +from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d +from dust3r_visloc.datasets.base_dataset import BaseVislocDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, xy_grid, geotrf + + +class VislocSevenScenes(BaseVislocDataset): + def __init__(self, root, subscene, pairsfile, topk=1): + super().__init__() + self.root = root + self.subscene = subscene + self.topk = topk + self.num_views = self.topk + 1 + self.maxdim = None + self.patch_size = None + + query_path = os.path.join(self.root, subscene, 'query') + kdata_query = kapture_from_dir(query_path) + assert kdata_query.records_camera is not None and kdata_query.trajectories is not None and kdata_query.rigs is not None + kapture.rigs_remove_inplace(kdata_query.trajectories, kdata_query.rigs) + kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_query.records_camera.key_pairs()} + self.query_data = {'path': query_path, 'kdata': kdata_query, 'searchindex': kdata_query_searchindex} + + map_path = os.path.join(self.root, subscene, 'mapping') + kdata_map = kapture_from_dir(map_path) + assert kdata_map.records_camera is not None and kdata_map.trajectories is not None and kdata_map.rigs is not None + kapture.rigs_remove_inplace(kdata_map.trajectories, kdata_map.rigs) + kdata_map_searchindex = {kdata_map.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) + for timestamp, sensor_id in kdata_map.records_camera.key_pairs()} + self.map_data = {'path': map_path, 'kdata': kdata_map, 'searchindex': kdata_map_searchindex} + + self.pairs = get_ordered_pairs_from_file(os.path.join(self.root, subscene, + 'pairfiles/query', + pairsfile + '.txt')) + self.scenes = kdata_query.records_camera.data_list() + + def __len__(self): + return len(self.scenes) + + def __getitem__(self, idx): + assert self.maxdim is not None and self.patch_size is not None + query_image = self.scenes[idx] + map_images = [p[0] for p in self.pairs[query_image][:self.topk]] + views = [] + dataarray = [(query_image, self.query_data, False)] + [(map_image, self.map_data, True) + for map_image in map_images] + for idx, (imgname, data, should_load_depth) in enumerate(dataarray): + imgpath, kdata, searchindex = map(data.get, ['path', 'kdata', 'searchindex']) + + timestamp, camera_id = searchindex[imgname] + + # for 7scenes, SIMPLE_PINHOLE + camera_params = kdata.sensors[camera_id].camera_params + W, H, f, cx, cy = camera_params + distortion = [0, 0, 0, 0] + intrinsics = np.float32([(f, 0, cx), + (0, f, cy), + (0, 0, 1)]) + + cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id) + + # Load RGB image + rgb_image = PIL.Image.open(os.path.join(imgpath, 'sensors/records_data', imgname)).convert('RGB') + rgb_image.load() + + W, H = rgb_image.size + resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) + + rgb_tensor = resize_func(ImgNorm(rgb_image)) + + view = { + 'intrinsics': intrinsics, + 'distortion': distortion, + 
'cam_to_world': cam_to_world, + 'rgb': rgb_image, + 'rgb_rescaled': rgb_tensor, + 'to_orig': to_orig, + 'idx': idx, + 'image_name': imgname + } + + # Load depthmap + if should_load_depth: + depthmap_filename = os.path.join(imgpath, 'sensors/records_data', + imgname.replace('color.png', 'depth.reg')) + depthmap = depth_map_from_file(depthmap_filename, (int(W), int(H))).astype(np.float32) + pts3d_full, pts3d_valid = depthmap_to_absolute_camera_coordinates(depthmap, intrinsics, cam_to_world) + + pts3d = pts3d_full[pts3d_valid] + pts2d_int = xy_grid(W, H)[pts3d_valid] + pts2d = pts2d_int.astype(np.float64) + + # nan => invalid + pts3d_full[~pts3d_valid] = np.nan + pts3d_full = torch.from_numpy(pts3d_full) + view['pts3d'] = pts3d_full + view["valid"] = pts3d_full.sum(dim=-1).isfinite() + + HR, WR = rgb_tensor.shape[1:] + _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(pts2d, pts3d, to_resize, HR, WR) + pts3d_rescaled = torch.from_numpy(pts3d_rescaled) + valid_rescaled = torch.from_numpy(valid_rescaled) + view['pts3d_rescaled'] = pts3d_rescaled + view["valid_rescaled"] = valid_rescaled + views.append(view) + return views diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/utils.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6053ae2e5ba6c0b0f5f014161b666623d6e0f3f5 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/datasets/utils.py @@ -0,0 +1,118 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# dataset utilities +# -------------------------------------------------------- +import numpy as np +import quaternion +import torchvision.transforms as tvf +from dust3r.utils.geometry import geotrf + + +def cam_to_world_from_kapture(kdata, timestamp, camera_id): + camera_to_world = kdata.trajectories[timestamp, camera_id].inverse() + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = quaternion.as_rotation_matrix(camera_to_world.r) + camera_pose[:3, 3] = camera_to_world.t_raw + return camera_pose + + +ratios_resolutions = { + 224: {1.0: [224, 224]}, + 512: {4 / 3: [512, 384], 32 / 21: [512, 336], 16 / 9: [512, 288], 2 / 1: [512, 256], 16 / 5: [512, 160]} +} + + +def get_HW_resolution(H, W, maxdim, patchsize=16): + assert maxdim in ratios_resolutions, "Error, maxdim can only be 224 or 512 for now. Other maxdims not implemented yet." + ratios_resolutions_maxdim = ratios_resolutions[maxdim] + mindims = set([min(res) for res in ratios_resolutions_maxdim.values()]) + ratio = W / H + ref_ratios = np.array([*(ratios_resolutions_maxdim.keys())]) + islandscape = (W >= H) + if islandscape: + diff = np.abs(ratio - ref_ratios) + else: + diff = np.abs(ratio - (1 / ref_ratios)) + selkey = ref_ratios[np.argmin(diff)] + res = ratios_resolutions_maxdim[selkey] + # check patchsize and make sure output resolution is a multiple of patchsize + if isinstance(patchsize, tuple): + assert len(patchsize) == 2 and isinstance(patchsize[0], int) and isinstance( + patchsize[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints." 
+ assert patchsize[0] == patchsize[1], "Error, non square patches not managed" + patchsize = patchsize[0] + assert max(res) == maxdim + assert min(res) in mindims + return res[::-1] if islandscape else res # return HW + + +def get_resize_function(maxdim, patch_size, H, W, is_mask=False): + if [max(H, W), min(H, W)] in ratios_resolutions[maxdim].values(): + return lambda x: x, np.eye(3), np.eye(3) + else: + target_HW = get_HW_resolution(H, W, maxdim=maxdim, patchsize=patch_size) + + ratio = W / H + target_ratio = target_HW[1] / target_HW[0] + to_orig_crop = np.eye(3) + to_rescaled_crop = np.eye(3) + if abs(ratio - target_ratio) < np.finfo(np.float32).eps: + crop_W = W + crop_H = H + elif ratio - target_ratio < 0: + crop_W = W + crop_H = int(W / target_ratio) + to_orig_crop[1, 2] = (H - crop_H) / 2.0 + to_rescaled_crop[1, 2] = -(H - crop_H) / 2.0 + else: + crop_W = int(H * target_ratio) + crop_H = H + to_orig_crop[0, 2] = (W - crop_W) / 2.0 + to_rescaled_crop[0, 2] = - (W - crop_W) / 2.0 + + crop_op = tvf.CenterCrop([crop_H, crop_W]) + + if is_mask: + resize_op = tvf.Resize(size=target_HW, interpolation=tvf.InterpolationMode.NEAREST_EXACT) + else: + resize_op = tvf.Resize(size=target_HW) + to_orig_resize = np.array([[crop_W / target_HW[1], 0, 0], + [0, crop_H / target_HW[0], 0], + [0, 0, 1]]) + to_rescaled_resize = np.array([[target_HW[1] / crop_W, 0, 0], + [0, target_HW[0] / crop_H, 0], + [0, 0, 1]]) + + op = tvf.Compose([crop_op, resize_op]) + + return op, to_rescaled_resize @ to_rescaled_crop, to_orig_crop @ to_orig_resize + + +def rescale_points3d(pts2d, pts3d, to_resize, HR, WR): + # rescale pts2d as floats + # to colmap, so that the image is in [0, D] -> [0, NewD] + pts2d = pts2d.copy() + pts2d[:, 0] += 0.5 + pts2d[:, 1] += 0.5 + + pts2d_rescaled = geotrf(to_resize, pts2d, norm=True) + + pts2d_rescaled_int = pts2d_rescaled.copy() + # convert back to cv2 before round [-0.5, 0.5] -> pixel 0 + pts2d_rescaled_int[:, 0] -= 0.5 + pts2d_rescaled_int[:, 1] -= 0.5 + pts2d_rescaled_int = pts2d_rescaled_int.round().astype(np.int64) + + # update valid (remove cropped regions) + valid_rescaled = (pts2d_rescaled_int[:, 0] >= 0) & (pts2d_rescaled_int[:, 0] < WR) & ( + pts2d_rescaled_int[:, 1] >= 0) & (pts2d_rescaled_int[:, 1] < HR) + + pts2d_rescaled_int = pts2d_rescaled_int[valid_rescaled] + + # rebuild pts3d from rescaled ps2d poses + pts3d_rescaled = np.full((HR, WR, 3), np.nan, dtype=np.float32) # pts3d in 512 x something + pts3d_rescaled[pts2d_rescaled_int[:, 1], pts2d_rescaled_int[:, 0]] = pts3d[valid_rescaled] + + return pts2d_rescaled, pts2d_rescaled_int, pts3d_rescaled, np.isfinite(pts3d_rescaled.sum(axis=-1)) diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/evaluation.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..027179f2b1007db558f57d3d67f48a6d7aa1ab9d --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/evaluation.py @@ -0,0 +1,65 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
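The sketch below illustrates the contract of `get_resize_function` defined above: it returns a torchvision transform for the image plus two 3x3 matrices that map pixel coordinates (in COLMAP convention) between the original and rescaled resolutions, applied with `geotrf` exactly as `rescale_points3d` does. The image size and blank test image are placeholders.

```python
import numpy as np
import PIL.Image

from dust3r.datasets.utils.transforms import ImgNorm
from dust3r.utils.geometry import geotrf
from dust3r_visloc.datasets.utils import get_resize_function

H, W = 480, 640   # toy original resolution (4:3, so it maps to the 512x384 bucket)
resize_func, to_resize, to_orig = get_resize_function(maxdim=512, patch_size=16, H=H, W=W)

img = PIL.Image.new('RGB', (W, H))        # placeholder image
tensor = resize_func(ImgNorm(img))        # 3 x 384 x 512 tensor at the working resolution

# to_resize / to_orig are inverse 3x3 homogeneous transforms on pixel coordinates
pt_orig = np.float64([[320.5, 240.5]])
pt_resized = geotrf(to_resize, pt_orig, norm=True)
pt_back = geotrf(to_orig, pt_resized, norm=True)   # recovers pt_orig
print(pt_resized, pt_back)
```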
+# +# -------------------------------------------------------- +# evaluation utilities +# -------------------------------------------------------- +import numpy as np +import quaternion +import torch +import roma +import collections +import os + + +def aggregate_stats(info_str, pose_errors, angular_errors): + stats = collections.Counter() + median_pos_error = np.median(pose_errors) + median_angular_error = np.median(angular_errors) + out_str = f'{info_str}: {len(pose_errors)} images - {median_pos_error=}, {median_angular_error=}' + + for trl_thr, ang_thr in [(0.1, 1), (0.25, 2), (0.5, 5), (5, 10)]: + for pose_error, angular_error in zip(pose_errors, angular_errors): + correct_for_this_threshold = (pose_error < trl_thr) and (angular_error < ang_thr) + stats[trl_thr, ang_thr] += correct_for_this_threshold + stats = {f'acc@{key[0]:g}m,{key[1]}deg': 100 * val / len(pose_errors) for key, val in stats.items()} + for metric, perf in stats.items(): + out_str += f' - {metric:12s}={float(perf):.3f}' + return out_str + + +def get_pose_error(pr_camtoworld, gt_cam_to_world): + abs_transl_error = torch.linalg.norm(torch.tensor(pr_camtoworld[:3, 3]) - torch.tensor(gt_cam_to_world[:3, 3])) + abs_angular_error = roma.rotmat_geodesic_distance(torch.tensor(pr_camtoworld[:3, :3]), + torch.tensor(gt_cam_to_world[:3, :3])) * 180 / np.pi + return abs_transl_error, abs_angular_error + + +def export_results(output_dir, xp_label, query_names, poses_pred): + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) + + lines = "" + lines_ltvl = "" + for query_name, pr_querycam_to_world in zip(query_names, poses_pred): + if pr_querycam_to_world is None: + pr_world_to_querycam = np.eye(4) + else: + pr_world_to_querycam = np.linalg.inv(pr_querycam_to_world) + query_shortname = os.path.basename(query_name) + pr_world_to_querycam_q = quaternion.from_rotation_matrix(pr_world_to_querycam[:3, :3]) + pr_world_to_querycam_t = pr_world_to_querycam[:3, 3] + + line_pose = quaternion.as_float_array(pr_world_to_querycam_q).tolist() + \ + pr_world_to_querycam_t.flatten().tolist() + + line_content = [query_name] + line_pose + lines += ' '.join(str(v) for v in line_content) + '\n' + + line_content_ltvl = [query_shortname] + line_pose + lines_ltvl += ' '.join(str(v) for v in line_content_ltvl) + '\n' + + with open(os.path.join(output_dir, xp_label + '_results.txt'), 'wt') as f: + f.write(lines) + with open(os.path.join(output_dir, xp_label + '_ltvl.txt'), 'wt') as f: + f.write(lines_ltvl) diff --git a/imcui/third_party/mast3r/dust3r/dust3r_visloc/localization.py b/imcui/third_party/mast3r/dust3r/dust3r_visloc/localization.py new file mode 100644 index 0000000000000000000000000000000000000000..ac8ae198dc3479f12a976bab0bda692328880710 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/dust3r_visloc/localization.py @@ -0,0 +1,140 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
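A small, self-contained sketch of the evaluation helpers above: `get_pose_error` compares predicted and ground-truth cam2world matrices, and `aggregate_stats` turns lists of errors into the median / acc@threshold summary string. The toy poses are fabricated for illustration.

```python
import numpy as np

from dust3r_visloc.evaluation import get_pose_error, aggregate_stats

# toy ground-truth and predicted cam2world poses (identity rotation, 5 cm offset)
gt = np.eye(4)
pred = np.eye(4)
pred[:3, 3] = [0.05, 0.0, 0.0]

transl_err, ang_err = get_pose_error(pred, gt)
print(aggregate_stats('toy run', [float(transl_err)], [float(ang_err)]))
# reports the medians plus acc@0.1m,1deg ... acc@5m,10deg percentages
```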
+# +# -------------------------------------------------------- +# main pnp code +# -------------------------------------------------------- +import numpy as np +import quaternion +import cv2 +from packaging import version + +from dust3r.utils.geometry import opencv_to_colmap_intrinsics + +try: + import poselib # noqa + HAS_POSELIB = True +except Exception as e: + HAS_POSELIB = False + +try: + import pycolmap # noqa + version_number = pycolmap.__version__ + if version.parse(version_number) < version.parse("0.5.0"): + HAS_PYCOLMAP = False + else: + HAS_PYCOLMAP = True +except Exception as e: + HAS_PYCOLMAP = False + +def run_pnp(pts2D, pts3D, K, distortion = None, mode='cv2', reprojectionError=5, img_size = None): + """ + use OPENCV model for distortion (4 values) + """ + assert mode in ['cv2', 'poselib', 'pycolmap'] + try: + if len(pts2D) > 4 and mode == "cv2": + confidence = 0.9999 + iterationsCount = 10_000 + if distortion is not None: + cv2_pts2ds = np.copy(pts2D) + cv2_pts2ds = cv2.undistortPoints(cv2_pts2ds, K, np.array(distortion), R=None, P=K) + pts2D = cv2_pts2ds.reshape((-1, 2)) + + success, r_pose, t_pose, _ = cv2.solvePnPRansac(pts3D, pts2D, K, None, flags=cv2.SOLVEPNP_SQPNP, + iterationsCount=iterationsCount, + reprojectionError=reprojectionError, + confidence=confidence) + if not success: + return False, None + r_pose = cv2.Rodrigues(r_pose)[0] # world2cam == world2cam2 + RT = np.r_[np.c_[r_pose, t_pose], [(0,0,0,1)]] # world2cam2 + return True, np.linalg.inv(RT) # cam2toworld + elif len(pts2D) > 4 and mode == "poselib": + assert HAS_POSELIB + confidence = 0.9999 + iterationsCount = 10_000 + # NOTE: `Camera` struct currently contains `width`/`height` fields, + # however these are not used anywhere in the code-base and are provided simply to be consistent with COLMAP. 
+ # so we put garbage in there + colmap_intrinsics = opencv_to_colmap_intrinsics(K) + fx = colmap_intrinsics[0, 0] + fy = colmap_intrinsics[1, 1] + cx = colmap_intrinsics[0, 2] + cy = colmap_intrinsics[1, 2] + width = img_size[0] if img_size is not None else int(cx*2) + height = img_size[1] if img_size is not None else int(cy*2) + + if distortion is None: + camera = {'model': 'PINHOLE', 'width': width, 'height': height, 'params': [fx, fy, cx, cy]} + else: + camera = {'model': 'OPENCV', 'width': width, 'height': height, + 'params': [fx, fy, cx, cy] + distortion} + + pts2D = np.copy(pts2D) + pts2D[:, 0] += 0.5 + pts2D[:, 1] += 0.5 + pose, _ = poselib.estimate_absolute_pose(pts2D, pts3D, camera, + {'max_reproj_error': reprojectionError, + 'max_iterations': iterationsCount, + 'success_prob': confidence}, {}) + if pose is None: + return False, None + RT = pose.Rt # (3x4) + RT = np.r_[RT, [(0,0,0,1)]] # world2cam + return True, np.linalg.inv(RT) # cam2toworld + elif len(pts2D) > 4 and mode == "pycolmap": + assert HAS_PYCOLMAP + assert img_size is not None + + pts2D = np.copy(pts2D) + pts2D[:, 0] += 0.5 + pts2D[:, 1] += 0.5 + colmap_intrinsics = opencv_to_colmap_intrinsics(K) + fx = colmap_intrinsics[0, 0] + fy = colmap_intrinsics[1, 1] + cx = colmap_intrinsics[0, 2] + cy = colmap_intrinsics[1, 2] + width = img_size[0] + height = img_size[1] + if distortion is None: + camera_dict = {'model': 'PINHOLE', 'width': width, 'height': height, 'params': [fx, fy, cx, cy]} + else: + camera_dict = {'model': 'OPENCV', 'width': width, 'height': height, + 'params': [fx, fy, cx, cy] + distortion} + + pycolmap_camera = pycolmap.Camera( + model=camera_dict['model'], width=camera_dict['width'], height=camera_dict['height'], + params=camera_dict['params']) + + pycolmap_estimation_options = dict(ransac=dict(max_error=reprojectionError, min_inlier_ratio=0.01, + min_num_trials=1000, max_num_trials=100000, + confidence=0.9999)) + pycolmap_refinement_options=dict(refine_focal_length=False, refine_extra_params=False) + ret = pycolmap.absolute_pose_estimation(pts2D, pts3D, pycolmap_camera, + estimation_options=pycolmap_estimation_options, + refinement_options=pycolmap_refinement_options) + if ret is None: + ret = {'success': False} + else: + ret['success'] = True + if callable(ret['cam_from_world'].matrix): + retmat = ret['cam_from_world'].matrix() + else: + retmat = ret['cam_from_world'].matrix + ret['qvec'] = quaternion.from_rotation_matrix(retmat[:3, :3]) + ret['tvec'] = retmat[:3, 3] + + if not (ret['success'] and ret['num_inliers'] > 0): + success = False + pose = None + else: + success = True + pr_world_to_querycam = np.r_[ret['cam_from_world'].matrix(), [(0,0,0,1)]] + pose = np.linalg.inv(pr_world_to_querycam) + return success, pose + else: + return False, None + except Exception as e: + print(f'error during pnp: {e}') + return False, None \ No newline at end of file diff --git a/imcui/third_party/mast3r/dust3r/train.py b/imcui/third_party/mast3r/dust3r/train.py new file mode 100644 index 0000000000000000000000000000000000000000..503e63572376c259e6b259850e19c3f6036aa535 --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/train.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
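To sanity-check `run_pnp` above, here is a hedged synthetic round trip in the default `cv2` mode: 3D points are projected through a known pinhole camera placed at the world origin, and the recovered cam2world pose should come back close to identity. All numbers are made up for the test.

```python
import numpy as np

from dust3r_visloc.localization import run_pnp

rng = np.random.default_rng(0)
K = np.float32([[500, 0, 320],
                [0, 500, 240],
                [0, 0, 1]])

# synthetic scene: points 4-8 m in front of a camera sitting at the world origin
pts3d = rng.uniform([-1, -1, 4], [1, 1, 8], size=(200, 3)).astype(np.float32)
proj = (K @ pts3d.T).T
pts2d = (proj[:, :2] / proj[:, 2:3]).astype(np.float32)

success, cam2world = run_pnp(pts2d, pts3d, K, distortion=None,
                             mode='cv2', reprojectionError=5, img_size=[640, 480])
assert success
print(np.round(cam2world, 3))   # ~ identity for this synthetic setup
```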
+# +# -------------------------------------------------------- +# training executable for DUSt3R +# -------------------------------------------------------- +from dust3r.training import get_args_parser, train + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + train(args) diff --git a/imcui/third_party/mast3r/dust3r/visloc.py b/imcui/third_party/mast3r/dust3r/visloc.py new file mode 100644 index 0000000000000000000000000000000000000000..6411b3eaf96dea961f9524e887a12d92f2012c6b --- /dev/null +++ b/imcui/third_party/mast3r/dust3r/visloc.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Simple visloc script +# -------------------------------------------------------- +import numpy as np +import random +import argparse +from tqdm import tqdm +import math + +from dust3r.inference import inference +from dust3r.model import AsymmetricCroCo3DStereo +from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf + +from dust3r_visloc.datasets import * +from dust3r_visloc.localization import run_pnp +from dust3r_visloc.evaluation import get_pose_error, aggregate_stats, export_results + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True, help="visloc dataset to eval") + parser_weights = parser.add_mutually_exclusive_group(required=True) + parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None) + parser_weights.add_argument("--model_name", type=str, help="name of the model weights", + choices=["DUSt3R_ViTLarge_BaseDecoder_512_dpt", + "DUSt3R_ViTLarge_BaseDecoder_512_linear", + "DUSt3R_ViTLarge_BaseDecoder_224_linear"]) + parser.add_argument("--confidence_threshold", type=float, default=3.0, + help="confidence values higher than threshold are invalid") + parser.add_argument("--device", type=str, default='cuda', help="pytorch device") + parser.add_argument("--pnp_mode", type=str, default="cv2", choices=['cv2', 'poselib', 'pycolmap'], + help="pnp lib to use") + parser_reproj = parser.add_mutually_exclusive_group() + parser_reproj.add_argument("--reprojection_error", type=float, default=5.0, help="pnp reprojection error") + parser_reproj.add_argument("--reprojection_error_diag_ratio", type=float, default=None, + help="pnp reprojection error as a ratio of the diagonal of the image") + + parser.add_argument("--pnp_max_points", type=int, default=100_000, help="pnp maximum number of points kept") + parser.add_argument("--viz_matches", type=int, default=0, help="debug matches") + + parser.add_argument("--output_dir", type=str, default=None, help="output path") + parser.add_argument("--output_label", type=str, default='', help="prefix for results files") + return parser + + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + conf_thr = args.confidence_threshold + device = args.device + pnp_mode = args.pnp_mode + reprojection_error = args.reprojection_error + reprojection_error_diag_ratio = args.reprojection_error_diag_ratio + pnp_max_points = args.pnp_max_points + viz_matches = args.viz_matches + + if args.weights is not None: + weights_path = args.weights + else: + weights_path = "naver/" + args.model_name + model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) + + dataset = eval(args.dataset) + 
dataset.set_resolution(model) + + query_names = [] + poses_pred = [] + pose_errors = [] + angular_errors = [] + for idx in tqdm(range(len(dataset))): + views = dataset[(idx)] # 0 is the query + query_view = views[0] + map_views = views[1:] + query_names.append(query_view['image_name']) + + query_pts2d = [] + query_pts3d = [] + for map_view in map_views: + # prepare batch + imgs = [] + for idx, img in enumerate([query_view['rgb_rescaled'], map_view['rgb_rescaled']]): + imgs.append(dict(img=img.unsqueeze(0), true_shape=np.int32([img.shape[1:]]), + idx=idx, instance=str(idx))) + output = inference([tuple(imgs)], model, device, batch_size=1, verbose=False) + pred1, pred2 = output['pred1'], output['pred2'] + confidence_masks = [pred1['conf'].squeeze(0) >= conf_thr, + (pred2['conf'].squeeze(0) >= conf_thr) & map_view['valid_rescaled']] + pts3d = [pred1['pts3d'].squeeze(0), pred2['pts3d_in_other_view'].squeeze(0)] + + # find 2D-2D matches between the two images + pts2d_list, pts3d_list = [], [] + for i in range(2): + conf_i = confidence_masks[i].cpu().numpy() + true_shape_i = imgs[i]['true_shape'][0] + pts2d_list.append(xy_grid(true_shape_i[1], true_shape_i[0])[conf_i]) + pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i]) + + PQ, PM = pts3d_list[0], pts3d_list[1] + if len(PQ) == 0 or len(PM) == 0: + continue + reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(PQ, PM) + if viz_matches > 0: + print(f'found {num_matches} matches') + matches_im1 = pts2d_list[1][reciprocal_in_PM] + matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM] + valid_pts3d = map_view['pts3d_rescaled'][matches_im1[:, 1], matches_im1[:, 0]] + + # from cv2 to colmap + matches_im0 = matches_im0.astype(np.float64) + matches_im1 = matches_im1.astype(np.float64) + matches_im0[:, 0] += 0.5 + matches_im0[:, 1] += 0.5 + matches_im1[:, 0] += 0.5 + matches_im1[:, 1] += 0.5 + # rescale coordinates + matches_im0 = geotrf(query_view['to_orig'], matches_im0, norm=True) + matches_im1 = geotrf(query_view['to_orig'], matches_im1, norm=True) + # from colmap back to cv2 + matches_im0[:, 0] -= 0.5 + matches_im0[:, 1] -= 0.5 + matches_im1[:, 0] -= 0.5 + matches_im1[:, 1] -= 0.5 + + # visualize a few matches + if viz_matches > 0: + viz_imgs = [np.array(query_view['rgb']), np.array(map_view['rgb'])] + from matplotlib import pyplot as pl + n_viz = viz_matches + match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int) + viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz] + + H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2] + img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) + img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) + img = np.concatenate((img0, img1), axis=1) + pl.figure() + pl.imshow(img) + cmap = pl.get_cmap('jet') + for i in range(n_viz): + (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T + pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False) + pl.show(block=True) + + if len(valid_pts3d) == 0: + pass + else: + query_pts3d.append(valid_pts3d.cpu().numpy()) + query_pts2d.append(matches_im0) + + if len(query_pts2d) == 0: + success = False + pr_querycam_to_world = None + else: + query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32) + query_pts3d = np.concatenate(query_pts3d, axis=0) + if len(query_pts2d) > pnp_max_points: + idxs = random.sample(range(len(query_pts2d)), 
pnp_max_points) + query_pts3d = query_pts3d[idxs] + query_pts2d = query_pts2d[idxs] + + W, H = query_view['rgb'].size + if reprojection_error_diag_ratio is not None: + reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2) + else: + reprojection_error_img = reprojection_error + success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d, + query_view['intrinsics'], query_view['distortion'], + pnp_mode, reprojection_error_img, img_size=[W, H]) + + if not success: + abs_transl_error = float('inf') + abs_angular_error = float('inf') + else: + abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world']) + + pose_errors.append(abs_transl_error) + angular_errors.append(abs_angular_error) + poses_pred.append(pr_querycam_to_world) + + xp_label = f'tol_conf_{conf_thr}' + if args.output_label: + xp_label = args.output_label + '_' + xp_label + if reprojection_error_diag_ratio is not None: + xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}' + else: + xp_label = xp_label + f'_reproj_err_{reprojection_error}' + export_results(args.output_dir, xp_label, query_names, poses_pred) + out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors) + print(out_string) diff --git a/imcui/third_party/mast3r/mast3r/__init__.py b/imcui/third_party/mast3r/mast3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7dd877d649ce4dbd749dd7195a8b34c0f91d4f0 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). \ No newline at end of file diff --git a/imcui/third_party/mast3r/mast3r/catmlp_dpt_head.py b/imcui/third_party/mast3r/mast3r/catmlp_dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ac4457908f97e7e25c4c59cc696fb059791fbff8 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/catmlp_dpt_head.py @@ -0,0 +1,123 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
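An example of launching the `visloc.py` evaluation script above, expressed in Python for consistency with the other snippets. Note that `--dataset` is passed through `eval`, so it takes a dataset constructor expression as a string; the dataset root and pairsfile name here are hypothetical.

```python
import subprocess

cmd = [
    "python3", "visloc.py",
    # hypothetical root and pairsfile; the expression is eval'ed by the script
    "--dataset", "VislocSevenScenes('/data/7scenes', 'chess', 'pairs-query-top20', topk=1)",
    "--model_name", "DUSt3R_ViTLarge_BaseDecoder_512_dpt",
    "--pnp_mode", "cv2",
    "--reprojection_error", "5.0",
    "--output_dir", "./results",
    "--output_label", "7scenes_chess",
]
subprocess.run(cmd, check=True)
```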
+# +# -------------------------------------------------------- +# MASt3R heads +# -------------------------------------------------------- +import torch +import torch.nn.functional as F + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf # noqa +from dust3r.heads.dpt_head import PixelwiseTaskWithDPT # noqa +import dust3r.utils.path_to_croco # noqa +from models.blocks import Mlp # noqa + + +def reg_desc(desc, mode): + if 'norm' in mode: + desc = desc / desc.norm(dim=-1, keepdim=True) + else: + raise ValueError(f"Unknown desc mode {mode}") + return desc + + +def postprocess(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False, desc_conf_mode=None): + if desc_conf_mode is None: + desc_conf_mode = conf_mode + fmap = out.permute(0, 2, 3, 1) # B,H,W,D + res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode)) + if conf_mode is not None: + res['conf'] = reg_dense_conf(fmap[..., 3], mode=conf_mode) + if desc_dim is not None: + start = 3 + int(conf_mode is not None) + res['desc'] = reg_desc(fmap[..., start:start + desc_dim], mode=desc_mode) + if two_confs: + res['desc_conf'] = reg_dense_conf(fmap[..., start + desc_dim], mode=desc_conf_mode) + else: + res['desc_conf'] = res['conf'].clone() + return res + + +class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT): + """ Mixture between MLP and DPT head that outputs 3d points and local features (with MLP). + The input for both heads is a concatenation of Encoder and Decoder outputs + """ + + def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., hooks_idx=None, dim_tokens=None, + num_channels=1, postprocess=None, feature_dim=256, last_dim=32, depth_mode=None, conf_mode=None, head_type="regression", **kwargs): + super().__init__(num_channels=num_channels, feature_dim=feature_dim, last_dim=last_dim, hooks_idx=hooks_idx, + dim_tokens=dim_tokens, depth_mode=depth_mode, postprocess=postprocess, conf_mode=conf_mode, head_type=head_type) + self.local_feat_dim = local_feat_dim + + patch_size = net.patch_embed.patch_size + if isinstance(patch_size, tuple): + assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance( + patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints." 
+ assert patch_size[0] == patch_size[1], "Error, non square patches not managed" + patch_size = patch_size[0] + self.patch_size = patch_size + + self.desc_mode = net.desc_mode + self.has_conf = has_conf + self.two_confs = net.two_confs # independent confs for 3D regr and descs + self.desc_conf_mode = net.desc_conf_mode + idim = net.enc_embed_dim + net.dec_embed_dim + + self.head_local_features = Mlp(in_features=idim, + hidden_features=int(hidden_dim_factor * idim), + out_features=(self.local_feat_dim + self.two_confs) * self.patch_size**2) + + def forward(self, decout, img_shape): + # pass through the heads + pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1])) + + # recover encoder and decoder outputs + enc_output, dec_output = decout[0], decout[-1] + cat_output = torch.cat([enc_output, dec_output], dim=-1) # concatenate + H, W = img_shape + B, S, D = cat_output.shape + + # extract local_features + local_features = self.head_local_features(cat_output) # B,S,D + local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size) + local_features = F.pixel_shuffle(local_features, self.patch_size) # B,d,H,W + + # post process 3D pts, descriptors and confidences + out = torch.cat([pts3d, local_features], dim=1) + if self.postprocess: + out = self.postprocess(out, + depth_mode=self.depth_mode, + conf_mode=self.conf_mode, + desc_dim=self.local_feat_dim, + desc_mode=self.desc_mode, + two_confs=self.two_confs, + desc_conf_mode=self.desc_conf_mode) + return out + + +def mast3r_head_factory(head_type, output_mode, net, has_conf=False): + """" build a prediction head for the decoder + """ + if head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'): + local_feat_dim = int(output_mode[10:]) + assert net.dec_depth > 9 + l2 = net.dec_depth + feature_dim = 256 + last_dim = feature_dim // 2 + out_nchan = 3 + ed = net.enc_embed_dim + dd = net.dec_embed_dim + return Cat_MLP_LocalFeatures_DPT_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf, + num_channels=out_nchan + has_conf, + feature_dim=feature_dim, + last_dim=last_dim, + hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2], + dim_tokens=[ed, dd, dd, dd], + postprocess=postprocess, + depth_mode=net.depth_mode, + conf_mode=net.conf_mode, + head_type='regression') + else: + raise NotImplementedError( + f"unexpected {head_type=} and {output_mode=}") diff --git a/imcui/third_party/mast3r/mast3r/cloud_opt/__init__.py b/imcui/third_party/mast3r/mast3r/cloud_opt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7dd877d649ce4dbd749dd7195a8b34c0f91d4f0 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/cloud_opt/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). \ No newline at end of file diff --git a/imcui/third_party/mast3r/mast3r/cloud_opt/sparse_ga.py b/imcui/third_party/mast3r/mast3r/cloud_opt/sparse_ga.py new file mode 100644 index 0000000000000000000000000000000000000000..eb1eb6b4d264e458d4efdc4e50281f1d0c7c4012 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/cloud_opt/sparse_ga.py @@ -0,0 +1,1040 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
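# A toy sketch of how `forward` in the head above turns one MLP output per
# patch token into a dense per-pixel feature map with F.pixel_shuffle
# (sizes below are arbitrary examples).
import torch
import torch.nn.functional as F

B, H, W, p, d = 2, 32, 48, 16, 24                       # batch, image size, patch size, feat dim
tokens = torch.randn(B, (H // p) * (W // p), d * p**2)  # one vector per patch
feat = tokens.transpose(-1, -2).view(B, d * p**2, H // p, W // p)
dense = F.pixel_shuffle(feat, p)                        # -> B, d, H, W
assert dense.shape == (B, d, H, W)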
+# +# -------------------------------------------------------- +# MASt3R Sparse Global Alignement +# -------------------------------------------------------- +from tqdm import tqdm +import roma +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import os +from collections import namedtuple +from functools import lru_cache +from scipy import sparse as sp +import copy + +from mast3r.utils.misc import mkdir_for, hash_md5 +from mast3r.cloud_opt.utils.losses import gamma_loss +from mast3r.cloud_opt.utils.schedules import linear_schedule, cosine_schedule +from mast3r.fast_nn import fast_reciprocal_NNs, merge_corres + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.utils.geometry import inv, geotrf # noqa +from dust3r.utils.device import to_cpu, to_numpy, todevice # noqa +from dust3r.post_process import estimate_focal_knowing_depth # noqa +from dust3r.optim_factory import adjust_learning_rate_by_lr # noqa +from dust3r.cloud_opt.base_opt import clean_pointcloud +from dust3r.viz import SceneViz + + +class SparseGA(): + def __init__(self, img_paths, pairs_in, res_fine, anchors, canonical_paths=None): + def fetch_img(im): + def torgb(x): return (x[0].permute(1, 2, 0).numpy() * .5 + .5).clip(min=0., max=1.) + for im1, im2 in pairs_in: + if im1['instance'] == im: + return torgb(im1['img']) + if im2['instance'] == im: + return torgb(im2['img']) + self.canonical_paths = canonical_paths + self.img_paths = img_paths + self.imgs = [fetch_img(img) for img in img_paths] + self.intrinsics = res_fine['intrinsics'] + self.cam2w = res_fine['cam2w'] + self.depthmaps = res_fine['depthmaps'] + self.pts3d = res_fine['pts3d'] + self.pts3d_colors = [] + self.working_device = self.cam2w.device + for i in range(len(self.imgs)): + im = self.imgs[i] + x, y = anchors[i][0][..., :2].detach().cpu().numpy().T + self.pts3d_colors.append(im[y, x]) + assert self.pts3d_colors[-1].shape == self.pts3d[i].shape + self.n_imgs = len(self.imgs) + + def get_focals(self): + return torch.tensor([ff[0, 0] for ff in self.intrinsics]).to(self.working_device) + + def get_principal_points(self): + return torch.stack([ff[:2, -1] for ff in self.intrinsics]).to(self.working_device) + + def get_im_poses(self): + return self.cam2w + + def get_sparse_pts3d(self): + return self.pts3d + + def get_dense_pts3d(self, clean_depth=True, subsample=8): + assert self.canonical_paths, 'cache_path is required for dense 3d points' + device = self.cam2w.device + confs = [] + base_focals = [] + anchors = {} + for i, canon_path in enumerate(self.canonical_paths): + (canon, canon2, conf), focal = torch.load(canon_path, map_location=device) + confs.append(conf) + base_focals.append(focal) + + H, W = conf.shape + pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device) + idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample) + anchors[i] = (pixels, idxs[i], offsets[i]) + + # densify sparse depthmaps + pts3d, depthmaps = make_pts3d(anchors, self.intrinsics, self.cam2w, [ + d.ravel() for d in self.depthmaps], base_focals=base_focals, ret_depth=True) + + if clean_depth: + confs = clean_pointcloud(confs, self.intrinsics, inv(self.cam2w), depthmaps, pts3d) + + return pts3d, depthmaps, confs + + def get_pts3d_colors(self): + return self.pts3d_colors + + def get_depthmaps(self): + return self.depthmaps + + def get_masks(self): + return [slice(None, None) for _ in range(len(self.imgs))] + + def show(self, show_cams=True): + pts3d, _, confs = self.get_dense_pts3d() + 
show_reconstruction(self.imgs, self.intrinsics if show_cams else None, self.cam2w, + [p.clip(min=-50, max=50) for p in pts3d], + masks=[c > 1 for c in confs]) + + +def convert_dust3r_pairs_naming(imgs, pairs_in): + for pair_id in range(len(pairs_in)): + for i in range(2): + pairs_in[pair_id][i]['instance'] = imgs[pairs_in[pair_id][i]['idx']] + return pairs_in + + +def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf', + device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw): + """ Sparse alignment with MASt3R + imgs: list of image paths + cache_path: path where to dump temporary files (str) + + lr1, niter1: learning rate and #iterations for coarse global alignment (3D matching) + lr2, niter2: learning rate and #iterations for refinement (2D reproj error) + + lora_depth: smart dimensionality reduction with depthmaps + """ + # Convert pair naming convention from dust3r to mast3r + pairs_in = convert_dust3r_pairs_naming(imgs, pairs_in) + # forward pass + pairs, cache_path = forward_mast3r(pairs_in, model, + cache_path=cache_path, subsample=subsample, + desc_conf=desc_conf, device=device) + + # extract canonical pointmaps + tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = \ + prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path, mode='avg-angle', device=device) + + # compute minimal spanning tree + mst = compute_min_spanning_tree(pairwise_scores) + + # remove all edges not in the spanning tree? + # min_spanning_tree = {(imgs[i],imgs[j]) for i,j in mst[1]} + # tmp_pairs = {(a,b):v for (a,b),v in tmp_pairs.items() if {(a,b),(b,a)} & min_spanning_tree} + + # smartly combine all useful data + imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21 = \ + condense_data(imgs, tmp_pairs, canonical_views, preds_21, dtype) + + imgs, res_coarse, res_fine = sparse_scene_optimizer( + imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths, mst, + shared_intrinsics=shared_intrinsics, cache_path=cache_path, device=device, dtype=dtype, **kw) + + return SparseGA(imgs, pairs_in, res_fine or res_coarse, anchors, canonical_paths) + + +def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, + preds_21, canonical_paths, mst, cache_path, + lr1=0.2, niter1=500, loss1=gamma_loss(1.1), + lr2=0.02, niter2=500, loss2=gamma_loss(0.4), + lossd=gamma_loss(1.1), + opt_pp=True, opt_depth=True, + schedule=cosine_schedule, depth_mode='add', exp_depth=False, + lora_depth=False, # dict(k=96, gamma=15, min_norm=.5), + shared_intrinsics=False, + init={}, device='cuda', dtype=torch.float32, + matching_conf_thr=5., loss_dust3r_w=0.01, + verbose=True, dbg=()): + init = copy.deepcopy(init) + # extrinsic parameters + vec0001 = torch.tensor((0, 0, 0, 1), dtype=dtype, device=device) + quats = [nn.Parameter(vec0001.clone()) for _ in range(len(imgs))] + trans = [nn.Parameter(torch.zeros(3, device=device, dtype=dtype)) for _ in range(len(imgs))] + + # initialize + ones = torch.ones((len(imgs), 1), device=device, dtype=dtype) + median_depths = torch.ones(len(imgs), device=device, dtype=dtype) + for img in imgs: + idx = imgs.index(img) + init_values = init.setdefault(img, {}) + if verbose and init_values: + print(f' >> initializing img=...{img[-25:]} [{idx}] for {set(init_values)}') + + K = init_values.get('intrinsics') + if K is not None: + K = K.detach() + focal = K[:2, :2].diag().mean() + pp = K[:2, 2] + base_focals[idx] = 
focal + pps[idx] = pp + pps[idx] /= imsizes[idx] # default principal_point would be (0.5, 0.5) + + depth = init_values.get('depthmap') + if depth is not None: + core_depth[idx] = depth.detach() + + median_depths[idx] = med_depth = core_depth[idx].median() + core_depth[idx] /= med_depth + + cam2w = init_values.get('cam2w') + if cam2w is not None: + rot = cam2w[:3, :3].detach() + cam_center = cam2w[:3, 3].detach() + quats[idx].data[:] = roma.rotmat_to_unitquat(rot) + trans_offset = med_depth * torch.cat((imsizes[idx] / base_focals[idx] * (0.5 - pps[idx]), ones[:1, 0])) + trans[idx].data[:] = cam_center + rot @ trans_offset + del rot + assert False, 'inverse kinematic chain not yet implemented' + + # intrinsics parameters + if shared_intrinsics: + # Optimize a single set of intrinsics for all cameras. Use averages as init. + confs = torch.stack([torch.load(pth)[0][2].mean() for pth in canonical_paths]).to(pps) + weighting = confs / confs.sum() + pp = nn.Parameter((weighting @ pps).to(dtype)) + pps = [pp for _ in range(len(imgs))] + focal_m = weighting @ base_focals + log_focal = nn.Parameter(focal_m.view(1).log().to(dtype)) + log_focals = [log_focal for _ in range(len(imgs))] + else: + pps = [nn.Parameter(pp.to(dtype)) for pp in pps] + log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals] + + diags = imsizes.float().norm(dim=1) + min_focals = 0.25 * diags # diag = 1.2~1.4*max(W,H) => beta >= 1/(2*1.2*tan(fov/2)) ~= 0.26 + max_focals = 10 * diags + + assert len(mst[1]) == len(pps) - 1 + + def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth): + # make intrinsics + focals = torch.cat(log_focals).exp().clip(min=min_focals, max=max_focals) + pps = torch.stack(pps) + K = torch.eye(3, dtype=dtype, device=device)[None].expand(len(imgs), 3, 3).clone() + K[:, 0, 0] = K[:, 1, 1] = focals + K[:, 0:2, 2] = pps * imsizes + if trans is None: + return K + + # security! 
optimization is always trying to crush the scale down + sizes = torch.cat(log_sizes).exp() + global_scaling = 1 / sizes.min() + + # compute distance of camera to focal plane + # tan(fov) = W/2 / focal + z_cameras = sizes * median_depths * focals / base_focals + + # make extrinsic + rel_cam2cam = torch.eye(4, dtype=dtype, device=device)[None].expand(len(imgs), 4, 4).clone() + rel_cam2cam[:, :3, :3] = roma.unitquat_to_rotmat(F.normalize(torch.stack(quats), dim=1)) + rel_cam2cam[:, :3, 3] = torch.stack(trans) + + # camera are defined as a kinematic chain + tmp_cam2w = [None] * len(K) + tmp_cam2w[mst[0]] = rel_cam2cam[mst[0]] + for i, j in mst[1]: + # i is the cam_i_to_world reference, j is the relative pose = cam_j_to_cam_i + tmp_cam2w[j] = tmp_cam2w[i] @ rel_cam2cam[j] + tmp_cam2w = torch.stack(tmp_cam2w) + + # smart reparameterizaton of cameras + trans_offset = z_cameras.unsqueeze(1) * torch.cat((imsizes / focals.unsqueeze(1) * (0.5 - pps), ones), dim=-1) + new_trans = global_scaling * (tmp_cam2w[:, :3, 3:4] - tmp_cam2w[:, :3, :3] @ trans_offset.unsqueeze(-1)) + cam2w = torch.cat((torch.cat((tmp_cam2w[:, :3, :3], new_trans), dim=2), + vec0001.view(1, 1, 4).expand(len(K), 1, 4)), dim=1) + + depthmaps = [] + for i in range(len(imgs)): + core_depth_img = core_depth[i] + if exp_depth: + core_depth_img = core_depth_img.exp() + if lora_depth: # compute core_depth as a low-rank decomposition of 3d points + core_depth_img = lora_depth_proj[i] @ core_depth_img + if depth_mode == 'add': + core_depth_img = z_cameras[i] + (core_depth_img - 1) * (median_depths[i] * sizes[i]) + elif depth_mode == 'mul': + core_depth_img = z_cameras[i] * core_depth_img + else: + raise ValueError(f'Bad {depth_mode=}') + depthmaps.append(global_scaling * core_depth_img) + + return K, (inv(cam2w), cam2w), depthmaps + + K = make_K_cam_depth(log_focals, pps, None, None, None, None) + + if shared_intrinsics: + print('init focal (shared) = ', to_numpy(K[0, 0, 0]).round(2)) + else: + print('init focals =', to_numpy(K[:, 0, 0])) + + # spectral low-rank projection of depthmaps + if lora_depth: + core_depth, lora_depth_proj = spectral_projection_of_depthmaps( + imgs, K, core_depth, subsample, cache_path=cache_path, **lora_depth) + if exp_depth: + core_depth = [d.clip(min=1e-4).log() for d in core_depth] + core_depth = [nn.Parameter(d.ravel().to(dtype)) for d in core_depth] + log_sizes = [nn.Parameter(torch.zeros(1, dtype=dtype, device=device)) for _ in range(len(imgs))] + + # Fetch img slices + _, confs_sum, imgs_slices = corres + + # Define which pairs are fine to use with matching + def matching_check(x): return x.max() > matching_conf_thr + is_matching_ok = {} + for s in imgs_slices: + is_matching_ok[s.img1, s.img2] = matching_check(s.confs) + + # Prepare slices and corres for losses + dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]] + loss3d_slices = [s for s in imgs_slices if is_matching_ok[s.img1, s.img2]] + cleaned_corres2d = [] + for cci, (img1, pix1, confs, confsum, imgs_slices) in enumerate(corres2d): + cf_sum = 0 + pix1_filtered = [] + confs_filtered = [] + curstep = 0 + cleaned_slices = [] + for img2, slice2 in imgs_slices: + if is_matching_ok[img1, img2]: + tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step) + pix1_filtered.append(pix1[tslice]) + confs_filtered.append(confs[tslice]) + cleaned_slices.append((img2, slice2)) + curstep += slice2.stop - slice2.start + if pix1_filtered != []: + pix1_filtered = torch.cat(pix1_filtered) + confs_filtered = 
torch.cat(confs_filtered) + cf_sum = confs_filtered.sum() + cleaned_corres2d.append((img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices)) + + def loss_dust3r(cam2w, pts3d, pix_loss): + # In the case no correspondence could be established, fallback to DUSt3R GA regression loss formulation (sparsified) + loss = 0. + cf_sum = 0. + for s in dust3r_slices: + if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'): + continue + # fallback to dust3r regression + tgt_pts, tgt_confs = preds_21[imgs[s.img2]][imgs[s.img1]] + tgt_pts = geotrf(cam2w[s.img2], tgt_pts) + cf_sum += tgt_confs.sum() + loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts) + return loss / cf_sum if cf_sum != 0. else 0. + + def loss_3d(K, w2cam, pts3d, pix_loss): + # For each correspondence, we have two 3D points (one for each image of the pair). + # For each 3D point, we have 2 reproj errors + if any(v.get('freeze') for v in init.values()): + pts3d_1 = [] + pts3d_2 = [] + confs = [] + for s in loss3d_slices: + if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'): + continue + pts3d_1.append(pts3d[s.img1][s.slice1]) + pts3d_2.append(pts3d[s.img2][s.slice2]) + confs.append(s.confs) + else: + pts3d_1 = [pts3d[s.img1][s.slice1] for s in loss3d_slices] + pts3d_2 = [pts3d[s.img2][s.slice2] for s in loss3d_slices] + confs = [s.confs for s in loss3d_slices] + + if pts3d_1 != []: + confs = torch.cat(confs) + pts3d_1 = torch.cat(pts3d_1) + pts3d_2 = torch.cat(pts3d_2) + loss = confs @ pix_loss(pts3d_1, pts3d_2) + cf_sum = confs.sum() + else: + loss = 0. + cf_sum = 1. + + return loss / cf_sum + + def loss_2d(K, w2cam, pts3d, pix_loss): + # For each correspondence, we have two 3D points (one for each image of the pair). + # For each 3D point, we have 2 reproj errors + proj_matrix = K @ w2cam[:, :3] + loss = npix = 0 + for img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices in cleaned_corres2d: + if init[imgs[img1]].get('freeze', 0) >= 1: + continue # no need + pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in cleaned_slices] + if pts3d_in_img1 != []: + pts3d_in_img1 = torch.cat(pts3d_in_img1) + loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1)) + npix += confs_filtered.sum() + + return loss / npix if npix != 0 else 0. 
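# A self-contained toy sketch of the confidence-weighted reprojection term
# that `loss_2d` above accumulates per image: project 3D points with
# P = K @ [R|t], compare against the matched pixels, and normalize by the
# total confidence. All values below are made up for illustration.
import torch

K = torch.tensor([[500., 0., 320.],
                  [0., 500., 240.],
                  [0., 0., 1.]])
w2c = torch.eye(4)[:3]                                  # world-to-camera [R|t]
pts3d = torch.tensor([[0.1, -0.2, 2.0], [0.3, 0.1, 3.0]])
pix = torch.tensor([[345., 190.], [370., 256.]])
conf = torch.tensor([0.9, 0.5])

P = K @ w2c                                             # 3x4 projection matrix
h = torch.cat([pts3d, torch.ones(len(pts3d), 1)], dim=1)
proj = h @ P.T
uv = proj[:, :2] / proj[:, 2:3].clip(min=1e-3)
loss = conf @ (uv - pix).norm(dim=-1) / conf.sum()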
+ + def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0): + # create optimizer + params = pps + log_focals + quats + trans + log_sizes + core_depth + optimizer = torch.optim.Adam(params, lr=1, weight_decay=0, betas=(0.9, 0.9)) + ploss = pix_loss if 'meta' in repr(pix_loss) else (lambda a: pix_loss) + + with tqdm(total=niter) as bar: + for iter in range(niter or 1): + K, (w2cam, cam2w), depthmaps = make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth) + pts3d = make_pts3d(anchors, K, cam2w, depthmaps, base_focals=base_focals) + if niter == 0: + break + + alpha = (iter / niter) + lr = schedule(alpha, lr_base, lr_end) + adjust_learning_rate_by_lr(optimizer, lr) + pix_loss = ploss(1 - alpha) + optimizer.zero_grad() + loss = loss_func(K, w2cam, pts3d, pix_loss) + loss_dust3r_w * loss_dust3r(cam2w, pts3d, lossd) + loss.backward() + optimizer.step() + + # make sure the pose remains well optimizable + for i in range(len(imgs)): + quats[i].data[:] /= quats[i].data.norm() + + loss = float(loss) + if loss != loss: + break # NaN loss + bar.set_postfix_str(f'{lr=:.4f}, {loss=:.3f}') + bar.update(1) + + if niter: + print(f'>> final loss = {loss}') + return dict(intrinsics=K.detach(), cam2w=cam2w.detach(), + depthmaps=[d.detach() for d in depthmaps], pts3d=[p.detach() for p in pts3d]) + + # at start, don't optimize 3d points + for i, img in enumerate(imgs): + trainable = not (init[img].get('freeze')) + pps[i].requires_grad_(False) + log_focals[i].requires_grad_(False) + quats[i].requires_grad_(trainable) + trans[i].requires_grad_(trainable) + log_sizes[i].requires_grad_(trainable) + core_depth[i].requires_grad_(False) + + res_coarse = optimize_loop(loss_3d, lr_base=lr1, niter=niter1, pix_loss=loss1) + + res_fine = None + if niter2: + # now we can optimize 3d points + for i, img in enumerate(imgs): + if init[img].get('freeze', 0) >= 1: + continue + pps[i].requires_grad_(bool(opt_pp)) + log_focals[i].requires_grad_(True) + core_depth[i].requires_grad_(opt_depth) + + # refinement with 2d reproj + res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2) + + K = make_K_cam_depth(log_focals, pps, None, None, None, None) + if shared_intrinsics: + print('Final focal (shared) = ', to_numpy(K[0, 0, 0]).round(2)) + else: + print('Final focals =', to_numpy(K[:, 0, 0])) + + return imgs, res_coarse, res_fine + + +@lru_cache +def mask110(device, dtype): + return torch.tensor((1, 1, 0), device=device, dtype=dtype) + + +def proj3d(inv_K, pixels, z): + if pixels.shape[-1] == 2: + pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1) + return z.unsqueeze(-1) * (pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype)) + + +def make_pts3d(anchors, K, cam2w, depthmaps, base_focals=None, ret_depth=False): + focals = K[:, 0, 0] + invK = inv(K) + all_pts3d = [] + depth_out = [] + + for img, (pixels, idxs, offsets) in anchors.items(): + # from depthmaps to 3d points + if base_focals is None: + pass + else: + # compensate for focal + # depth + depth * (offset - 1) * base_focal / focal + # = depth * (1 + (offset - 1) * (base_focal / focal)) + offsets = 1 + (offsets - 1) * (base_focals[img] / focals[img]) + + pts3d = proj3d(invK[img], pixels, depthmaps[img][idxs] * offsets) + if ret_depth: + depth_out.append(pts3d[..., 2]) # before camera rotation + + # rotate to world coordinate + pts3d = geotrf(cam2w[img], pts3d) + all_pts3d.append(pts3d) + + if ret_depth: + return all_pts3d, depth_out + return all_pts3d + + +def make_dense_pts3d(intrinsics, cam2w, 
depthmaps, canonical_paths, subsample, device='cuda'): + base_focals = [] + anchors = {} + confs = [] + for i, canon_path in enumerate(canonical_paths): + (canon, canon2, conf), focal = torch.load(canon_path, map_location=device) + confs.append(conf) + base_focals.append(focal) + H, W = conf.shape + pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device) + idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample) + anchors[i] = (pixels, idxs[i], offsets[i]) + + # densify sparse depthmaps + pts3d, depthmaps_out = make_pts3d(anchors, intrinsics, cam2w, [ + d.ravel() for d in depthmaps], base_focals=base_focals, ret_depth=True) + + return pts3d, depthmaps_out, confs + + +@torch.no_grad() +def forward_mast3r(pairs, model, cache_path, desc_conf='desc_conf', + device='cuda', subsample=8, **matching_kw): + res_paths = {} + + for img1, img2 in tqdm(pairs): + idx1 = hash_md5(img1['instance']) + idx2 = hash_md5(img2['instance']) + + path1 = cache_path + f'/forward/{idx1}/{idx2}.pth' + path2 = cache_path + f'/forward/{idx2}/{idx1}.pth' + path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx1}-{idx2}.pth' + path_corres2 = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx2}-{idx1}.pth' + + if os.path.isfile(path_corres2) and not os.path.isfile(path_corres): + score, (xy1, xy2, confs) = torch.load(path_corres2) + torch.save((score, (xy2, xy1, confs)), path_corres) + + if not all(os.path.isfile(p) for p in (path1, path2, path_corres)): + if model is None: + continue + res = symmetric_inference(model, img1, img2, device=device) + X11, X21, X22, X12 = [r['pts3d'][0] for r in res] + C11, C21, C22, C12 = [r['conf'][0] for r in res] + descs = [r['desc'][0] for r in res] + qonfs = [r[desc_conf][0] for r in res] + + # save + torch.save(to_cpu((X11, C11, X21, C21)), mkdir_for(path1)) + torch.save(to_cpu((X22, C22, X12, C12)), mkdir_for(path2)) + + # perform reciprocal matching + corres = extract_correspondences(descs, qonfs, device=device, subsample=subsample) + + conf_score = (C11.mean() * C12.mean() * C21.mean() * C22.mean()).sqrt().sqrt() + matching_score = (float(conf_score), float(corres[2].sum()), len(corres[2])) + if cache_path is not None: + torch.save((matching_score, corres), mkdir_for(path_corres)) + + res_paths[img1['instance'], img2['instance']] = (path1, path2), path_corres + + del model + torch.cuda.empty_cache() + + return res_paths, cache_path + + +def symmetric_inference(model, img1, img2, device): + shape1 = torch.from_numpy(img1['true_shape']).to(device, non_blocking=True) + shape2 = torch.from_numpy(img2['true_shape']).to(device, non_blocking=True) + img1 = img1['img'].to(device, non_blocking=True) + img2 = img2['img'].to(device, non_blocking=True) + + # compute encoder only once + feat1, feat2, pos1, pos2 = model._encode_image_pairs(img1, img2, shape1, shape2) + + def decoder(feat1, feat2, pos1, pos2, shape1, shape2): + dec1, dec2 = model._decoder(feat1, pos1, feat2, pos2) + with torch.cuda.amp.autocast(enabled=False): + res1 = model._downstream_head(1, [tok.float() for tok in dec1], shape1) + res2 = model._downstream_head(2, [tok.float() for tok in dec2], shape2) + return res1, res2 + + # decoder 1-2 + res11, res21 = decoder(feat1, feat2, pos1, pos2, shape1, shape2) + # decoder 2-1 + res22, res12 = decoder(feat2, feat1, pos2, pos1, shape2, shape1) + + return (res11, res21, res22, res12) + + +def extract_correspondences(feats, qonfs, subsample=8, device=None, ptmap_key='pred_desc'): + feat11, feat21, feat22, 
feat12 = feats + qonf11, qonf21, qonf22, qonf12 = qonfs + assert feat11.shape[:2] == feat12.shape[:2] == qonf11.shape == qonf12.shape + assert feat21.shape[:2] == feat22.shape[:2] == qonf21.shape == qonf22.shape + + if '3d' in ptmap_key: + opt = dict(device='cpu', workers=32) + else: + opt = dict(device=device, dist='dot', block_size=2**13) + + # matching the two pairs + idx1 = [] + idx2 = [] + qonf1 = [] + qonf2 = [] + # TODO add non symmetric / pixel_tol options + for A, B, QA, QB in [(feat11, feat21, qonf11.cpu(), qonf21.cpu()), + (feat12, feat22, qonf12.cpu(), qonf22.cpu())]: + nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt) + nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt) + + idx1.append(np.r_[nn1to2[0], nn2to1[1]]) + idx2.append(np.r_[nn1to2[1], nn2to1[0]]) + qonf1.append(QA.ravel()[idx1[-1]]) + qonf2.append(QB.ravel()[idx2[-1]]) + + # merge corres from opposite pairs + H1, W1 = feat11.shape[:2] + H2, W2 = feat22.shape[:2] + cat = np.concatenate + + xy1, xy2, idx = merge_corres(cat(idx1), cat(idx2), (H1, W1), (H2, W2), ret_xy=True, ret_index=True) + corres = (xy1.copy(), xy2.copy(), np.sqrt(cat(qonf1)[idx] * cat(qonf2)[idx])) + + return todevice(corres, device) + + +@torch.no_grad() +def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_conf_thr=0, + cache_path=None, device='cuda', **kw): + canonical_views = {} + pairwise_scores = torch.zeros((len(imgs), len(imgs)), device=device) + canonical_paths = [] + preds_21 = {} + + for img in tqdm(imgs): + if cache_path: + cache = os.path.join(cache_path, 'canon_views', hash_md5(img) + f'_{subsample=}_{kw=}.pth') + canonical_paths.append(cache) + try: + (canon, canon2, cconf), focal = torch.load(cache, map_location=device) + except IOError: + # cache does not exist yet, we create it! 
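# (Each canonical view is cached on disk under an md5 of the image path plus
#  the subsample / kwargs settings, so re-running with identical inputs
#  reuses it; `canon` and `focal` are reset to None only on a cache miss and
#  recomputed from the pairwise predictions below.)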
+ canon = focal = None + + # collect all pred1 + n_pairs = sum((img in pair) for pair in tmp_pairs) + + ptmaps11 = None + pixels = {} + n = 0 + for (img1, img2), ((path1, path2), path_corres) in tmp_pairs.items(): + score = None + if img == img1: + X, C, X2, C2 = torch.load(path1, map_location=device) + score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr) + pixels[img2] = xy1, confs + if img not in preds_21: + preds_21[img] = {} + # Subsample preds_21 + preds_21[img][img2] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel() + + if img == img2: + X, C, X2, C2 = torch.load(path2, map_location=device) + score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr) + pixels[img1] = xy2, confs + if img not in preds_21: + preds_21[img] = {} + preds_21[img][img1] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel() + + if score is not None: + i, j = imgs.index(img1), imgs.index(img2) + # score = score[0] + # score = np.log1p(score[2]) + score = score[2] + pairwise_scores[i, j] = score + pairwise_scores[j, i] = score + + if canon is not None: + continue + if ptmaps11 is None: + H, W = C.shape + ptmaps11 = torch.empty((n_pairs, H, W, 3), device=device) + confs11 = torch.empty((n_pairs, H, W), device=device) + + ptmaps11[n] = X + confs11[n] = C + n += 1 + + if canon is None: + canon, canon2, cconf = canonical_view(ptmaps11, confs11, subsample, **kw) + del ptmaps11 + del confs11 + + # compute focals + H, W = canon.shape[:2] + pp = torch.tensor([W / 2, H / 2], device=device) + if focal is None: + focal = estimate_focal_knowing_depth(canon[None], pp, focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5) + if cache: + torch.save(to_cpu(((canon, canon2, cconf), focal)), mkdir_for(cache)) + + # extract depth offsets with correspondences + core_depth = canon[subsample // 2::subsample, subsample // 2::subsample, 2] + idxs, offsets = anchor_depth_offsets(canon2, pixels, subsample=subsample) + + canonical_views[img] = (pp, (H, W), focal.view(1), core_depth, pixels, idxs, offsets) + + return tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 + + +def load_corres(path_corres, device, min_conf_thr): + score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device) + valid = confs > min_conf_thr if min_conf_thr else slice(None) + # valid = (xy1 > 0).all(dim=1) & (xy2 > 0).all(dim=1) & (xy1 < 512).all(dim=1) & (xy2 < 512).all(dim=1) + # print(f'keeping {valid.sum()} / {len(valid)} correspondences') + return score, (xy1[valid], xy2[valid], confs[valid]) + + +PairOfSlices = namedtuple( + 'ImgPair', 'img1, slice1, pix1, anchor_idxs1, img2, slice2, pix2, anchor_idxs2, confs, confs_sum') + + +def condense_data(imgs, tmp_paths, canonical_views, preds_21, dtype=torch.float32): + # aggregate all data properly + set_imgs = set(imgs) + + principal_points = [] + shapes = [] + focals = [] + core_depth = [] + img_anchors = {} + tmp_pixels = {} + + for idx1, img1 in enumerate(imgs): + # load stuff + pp, shape, focal, anchors, pixels_confs, idxs, offsets = canonical_views[img1] + + principal_points.append(pp) + shapes.append(shape) + focals.append(focal) + core_depth.append(anchors) + + img_uv1 = [] + img_idxs = [] + img_offs = [] + cur_n = [0] + + for img2, (pixels, match_confs) in pixels_confs.items(): + if img2 not in set_imgs: + continue + assert len(pixels) == len(idxs[img2]) == len(offsets[img2]) + img_uv1.append(torch.cat((pixels, torch.ones_like(pixels[:, :1])), dim=-1)) + 
img_idxs.append(idxs[img2]) + img_offs.append(offsets[img2]) + cur_n.append(cur_n[-1] + len(pixels)) + # store the position of 3d points + tmp_pixels[img1, img2] = pixels.to(dtype), match_confs.to(dtype), slice(*cur_n[-2:]) + img_anchors[idx1] = (torch.cat(img_uv1), torch.cat(img_idxs), torch.cat(img_offs)) + + all_confs = [] + imgs_slices = [] + corres2d = {img: [] for img in range(len(imgs))} + + for img1, img2 in tmp_paths: + try: + pix1, confs1, slice1 = tmp_pixels[img1, img2] + pix2, confs2, slice2 = tmp_pixels[img2, img1] + except KeyError: + continue + img1 = imgs.index(img1) + img2 = imgs.index(img2) + confs = (confs1 * confs2).sqrt() + + # prepare for loss_3d + all_confs.append(confs) + anchor_idxs1 = canonical_views[imgs[img1]][5][imgs[img2]] + anchor_idxs2 = canonical_views[imgs[img2]][5][imgs[img1]] + imgs_slices.append(PairOfSlices(img1, slice1, pix1, anchor_idxs1, + img2, slice2, pix2, anchor_idxs2, + confs, float(confs.sum()))) + + # prepare for loss_2d + corres2d[img1].append((pix1, confs, img2, slice2)) + corres2d[img2].append((pix2, confs, img1, slice1)) + + all_confs = torch.cat(all_confs) + corres = (all_confs, float(all_confs.sum()), imgs_slices) + + def aggreg_matches(img1, list_matches): + pix1, confs, img2, slice2 = zip(*list_matches) + all_pix1 = torch.cat(pix1).to(dtype) + all_confs = torch.cat(confs).to(dtype) + return img1, all_pix1, all_confs, float(all_confs.sum()), [(j, sl2) for j, sl2 in zip(img2, slice2)] + corres2d = [aggreg_matches(img, m) for img, m in corres2d.items()] + + imsizes = torch.tensor([(W, H) for H, W in shapes], device=pp.device) # (W,H) + principal_points = torch.stack(principal_points) + focals = torch.cat(focals) + + # Subsample preds_21 + subsamp_preds_21 = {} + for imk, imv in preds_21.items(): + subsamp_preds_21[imk] = {} + for im2k, (pred, conf) in preds_21[imk].items(): + idxs = img_anchors[imgs.index(im2k)][1] + subsamp_preds_21[imk][im2k] = (pred[idxs], conf[idxs]) # anchors subsample + + return imsizes, principal_points, focals, core_depth, img_anchors, corres, corres2d, subsamp_preds_21 + + +def canonical_view(ptmaps11, confs11, subsample, mode='avg-angle'): + assert len(ptmaps11) == len(confs11) > 0, 'not a single view1 for img={i}' + + # canonical pointmap is just a weighted average + confs11 = confs11.unsqueeze(-1) - 0.999 + canon = (confs11 * ptmaps11).sum(0) / confs11.sum(0) + + canon_depth = ptmaps11[..., 2].unsqueeze(1) + S = slice(subsample // 2, None, subsample) + center_depth = canon_depth[:, :, S, S] + center_depth = torch.clip(center_depth, min=torch.finfo(center_depth.dtype).eps) + + stacked_depth = F.pixel_unshuffle(canon_depth, subsample) + stacked_confs = F.pixel_unshuffle(confs11[:, None, :, :, 0], subsample) + + if mode == 'avg-reldepth': + rel_depth = stacked_depth / center_depth + stacked_canon = (stacked_confs * rel_depth).sum(dim=0) / stacked_confs.sum(dim=0) + canon2 = F.pixel_shuffle(stacked_canon.unsqueeze(0), subsample).squeeze() + + elif mode == 'avg-angle': + xy = ptmaps11[..., 0:2].permute(0, 3, 1, 2) + stacked_xy = F.pixel_unshuffle(xy, subsample) + B, _, H, W = stacked_xy.shape + stacked_radius = (stacked_xy.view(B, 2, -1, H, W) - xy[:, :, None, S, S]).norm(dim=1) + stacked_radius.clip_(min=1e-8) + + stacked_angle = torch.arctan((stacked_depth - center_depth) / stacked_radius) + avg_angle = (stacked_confs * stacked_angle).sum(dim=0) / stacked_confs.sum(dim=0) + + # back to depth + stacked_depth = stacked_radius.mean(dim=0) * torch.tan(avg_angle) + + canon2 = F.pixel_shuffle((1 + stacked_depth / 
canon[S, S, 2]).unsqueeze(0), subsample).squeeze() + else: + raise ValueError(f'bad {mode=}') + + confs = (confs11.square().sum(dim=0) / confs11.sum(dim=0)).squeeze() + return canon, canon2, confs + + +def anchor_depth_offsets(canon_depth, pixels, subsample=8): + device = canon_depth.device + + # create a 2D grid of anchor 3D points + H1, W1 = canon_depth.shape + yx = np.mgrid[subsample // 2:H1:subsample, subsample // 2:W1:subsample] + H2, W2 = yx.shape[1:] + cy, cx = yx.reshape(2, -1) + core_depth = canon_depth[cy, cx] + assert (core_depth > 0).all() + + # slave 3d points (attached to core 3d points) + core_idxs = {} # core_idxs[img2] = {corr_idx:core_idx} + core_offs = {} # core_offs[img2] = {corr_idx:3d_offset} + + for img2, (xy1, _confs) in pixels.items(): + px, py = xy1.long().T + + # find nearest anchor == block quantization + core_idx = (py // subsample) * W2 + (px // subsample) + core_idxs[img2] = core_idx.to(device) + + # compute relative depth offsets w.r.t. anchors + ref_z = core_depth[core_idx] + pts_z = canon_depth[py, px] + offset = pts_z / ref_z + core_offs[img2] = offset.detach().to(device) + + return core_idxs, core_offs + + +def spectral_clustering(graph, k=None, normalized_cuts=False): + graph.fill_diagonal_(0) + + # graph laplacian + degrees = graph.sum(dim=-1) + laplacian = torch.diag(degrees) - graph + if normalized_cuts: + i_inv = torch.diag(degrees.sqrt().reciprocal()) + laplacian = i_inv @ laplacian @ i_inv + + # compute eigenvectors! + eigval, eigvec = torch.linalg.eigh(laplacian) + return eigval[:k], eigvec[:, :k] + + +def sim_func(p1, p2, gamma): + diff = (p1 - p2).norm(dim=-1) + avg_depth = (p1[:, :, 2] + p2[:, :, 2]) + rel_distance = diff / avg_depth + sim = torch.exp(-gamma * rel_distance.square()) + return sim + + +def backproj(K, depthmap, subsample): + H, W = depthmap.shape + uv = np.mgrid[subsample // 2:subsample * W:subsample, subsample // 2:subsample * H:subsample].T.reshape(H, W, 2) + xyz = depthmap.unsqueeze(-1) * geotrf(inv(K), todevice(uv, K.device), ncol=3) + return xyz + + +def spectral_projection_depth(K, depthmap, subsample, k=64, cache_path='', + normalized_cuts=True, gamma=7, min_norm=5): + try: + if cache_path: + cache_path = cache_path + f'_{k=}_norm={normalized_cuts}_{gamma=}.pth' + lora_proj = torch.load(cache_path, map_location=K.device) + + except IOError: + # reconstruct 3d points in camera coordinates + xyz = backproj(K, depthmap, subsample) + + # compute all distances + xyz = xyz.reshape(-1, 3) + graph = sim_func(xyz[:, None], xyz[None, :], gamma=gamma) + _, lora_proj = spectral_clustering(graph, k, normalized_cuts=normalized_cuts) + + if cache_path: + torch.save(lora_proj.cpu(), mkdir_for(cache_path)) + + lora_proj, coeffs = lora_encode_normed(lora_proj, depthmap.ravel(), min_norm=min_norm) + + # depthmap ~= lora_proj @ coeffs + return coeffs, lora_proj + + +def lora_encode_normed(lora_proj, x, min_norm, global_norm=False): + # encode the pointmap + coeffs = torch.linalg.pinv(lora_proj) @ x + + # rectify the norm of basis vector to be ~ equal + if coeffs.ndim == 1: + coeffs = coeffs[:, None] + if global_norm: + lora_proj *= coeffs[1:].norm() * min_norm / coeffs.shape[1] + elif min_norm: + lora_proj *= coeffs.norm(dim=1).clip(min=min_norm) + # can have rounding errors here! 
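# A toy sketch of the low-rank depth encoding this helper performs: depths
# are represented by k coefficients against a spectral basis, so that
# depth ~= lora_proj @ coeffs; the float64 pseudo-inverse just below re-fits
# the coefficients after the basis rescaling to absorb rounding error.
import torch
basis = torch.linalg.qr(torch.randn(1024, 16)).Q         # orthonormal toy basis (HW x k)
depth = torch.rand(1024) + 0.5
coeffs = torch.linalg.pinv(basis) @ depth
recon = basis @ coeffs                                    # low-rank approximation of depth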
+ coeffs = (torch.linalg.pinv(lora_proj.double()) @ x.double()).float() + + return lora_proj.detach(), coeffs.detach() + + +@torch.no_grad() +def spectral_projection_of_depthmaps(imgs, intrinsics, depthmaps, subsample, cache_path=None, **kw): + # recover 3d points + core_depth = [] + lora_proj = [] + + for i, img in enumerate(tqdm(imgs)): + cache = os.path.join(cache_path, 'lora_depth', hash_md5(img)) if cache_path else None + depth, proj = spectral_projection_depth(intrinsics[i], depthmaps[i], subsample, + cache_path=cache, **kw) + core_depth.append(depth) + lora_proj.append(proj) + + return core_depth, lora_proj + + +def reproj2d(Trf, pts3d): + res = (pts3d @ Trf[:3, :3].transpose(-1, -2)) + Trf[:3, 3] + clipped_z = res[:, 2:3].clip(min=1e-3) # make sure we don't have nans! + uv = res[:, 0:2] / clipped_z + return uv.clip(min=-1000, max=2000) + + +def bfs(tree, start_node): + order, predecessors = sp.csgraph.breadth_first_order(tree, start_node, directed=False) + ranks = np.arange(len(order)) + ranks[order] = ranks.copy() + return ranks, predecessors + + +def compute_min_spanning_tree(pws): + sparse_graph = sp.dok_array(pws.shape) + for i, j in pws.nonzero().cpu().tolist(): + sparse_graph[i, j] = -float(pws[i, j]) + msp = sp.csgraph.minimum_spanning_tree(sparse_graph) + + # now reorder the oriented edges, starting from the central point + ranks1, _ = bfs(msp, 0) + ranks2, _ = bfs(msp, ranks1.argmax()) + ranks1, _ = bfs(msp, ranks2.argmax()) + # this is the point farther from any leaf + root = np.minimum(ranks1, ranks2).argmax() + + # find the ordered list of edges that describe the tree + order, predecessors = sp.csgraph.breadth_first_order(msp, root, directed=False) + order = order[1:] # root not do not have a predecessor + edges = [(predecessors[i], i) for i in order] + + return root, edges + + +def show_reconstruction(shapes_or_imgs, K, cam2w, pts3d, gt_cam2w=None, gt_K=None, cam_size=None, masks=None, **kw): + viz = SceneViz() + + cc = cam2w[:, :3, 3] + cs = cam_size or float(torch.cdist(cc, cc).fill_diagonal_(np.inf).min(dim=0).values.median()) + colors = 64 + np.random.randint(255 - 64, size=(len(cam2w), 3)) + + if isinstance(shapes_or_imgs, np.ndarray) and shapes_or_imgs.ndim == 2: + cam_kws = dict(imsizes=shapes_or_imgs[:, ::-1], cam_size=cs) + else: + imgs = shapes_or_imgs + cam_kws = dict(images=imgs, cam_size=cs) + if K is not None: + viz.add_cameras(to_numpy(cam2w), to_numpy(K), colors=colors, **cam_kws) + + if gt_cam2w is not None: + if gt_K is None: + gt_K = K + viz.add_cameras(to_numpy(gt_cam2w), to_numpy(gt_K), colors=colors, marker='o', **cam_kws) + + if pts3d is not None: + for i, p in enumerate(pts3d): + if not len(p): + continue + if masks is None: + viz.add_pointcloud(to_numpy(p), color=tuple(colors[i].tolist())) + else: + viz.add_pointcloud(to_numpy(p), mask=masks[i], color=imgs[i]) + viz.show(**kw) diff --git a/imcui/third_party/mast3r/mast3r/cloud_opt/triangulation.py b/imcui/third_party/mast3r/mast3r/cloud_opt/triangulation.py new file mode 100644 index 0000000000000000000000000000000000000000..2af88df37bfd360161b4e96b93b0fd28a0ecf183 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/cloud_opt/triangulation.py @@ -0,0 +1,80 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
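# A toy sketch of the spanning-tree construction used by the sparse aligner
# above: pairwise matching scores are negated (scipy treats stored values as
# edge weights and zeros as missing edges), a minimum spanning tree of the
# negated graph is therefore a maximum-score tree, and a BFS from the root
# gives the ordered (parent, child) edges along which camera poses are chained.
import numpy as np
from scipy import sparse as sp

scores = np.array([[0., 5., 1.],
                   [5., 0., 3.],
                   [1., 3., 0.]])
mst = sp.csgraph.minimum_spanning_tree(sp.csr_matrix(-scores))
order, preds = sp.csgraph.breadth_first_order(mst, 0, directed=False)
edges = [(preds[i], i) for i in order[1:]]                # e.g. [(0, 1), (1, 2)]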
+# +# -------------------------------------------------------- +# Matches Triangulation Utils +# -------------------------------------------------------- + +import numpy as np +import torch + +# Batched Matches Triangulation +def batched_triangulate(pts2d, # [B, Ncams, Npts, 2] + proj_mats): # [B, Ncams, 3, 4] I@E projection matrix + B, Ncams, Npts, two = pts2d.shape + assert two==2 + assert proj_mats.shape == (B, Ncams, 3, 4) + # P - xP + x = proj_mats[...,0,:][...,None,:] - torch.einsum('bij,bik->bijk', pts2d[...,0], proj_mats[...,2,:]) # [B, Ncams, Npts, 4] + y = proj_mats[...,1,:][...,None,:] - torch.einsum('bij,bik->bijk', pts2d[...,1], proj_mats[...,2,:]) # [B, Ncams, Npts, 4] + eq = torch.cat([x, y], dim=1).transpose(1, 2) # [B, Npts, 2xNcams, 4] + return torch.linalg.lstsq(eq[...,:3], -eq[...,3]).solution + +def matches_to_depths(intrinsics, # input camera intrinsics [B, Ncams, 3, 3] + extrinsics, # input camera extrinsics [B, Ncams, 3, 4] + matches, # input correspondences [B, Ncams, Npts, 2] + batchsize=16, # bs for batched processing + min_num_valids_ratio=.3 # at least this ratio of image pairs need to predict a match for a given pixel of img1 + ): + B, Nv, H, W, five = matches.shape + min_num_valids = np.floor(Nv*min_num_valids_ratio) + out_aggregated_points, out_depths, out_confs = [], [], [] + for b in range(B//batchsize+1): # batched processing + start, stop = b*batchsize,min(B,(b+1)*batchsize) + sub_batch=slice(start,stop) + sub_batchsize = stop-start + if sub_batchsize==0:continue + points1, points2, confs = matches[sub_batch, ..., :2], matches[sub_batch, ..., 2:4], matches[sub_batch, ..., -1] + allpoints = torch.cat([points1.view([sub_batchsize*Nv,1,H*W,2]), points2.view([sub_batchsize*Nv,1,H*W,2])],dim=1) # [BxNv, 2, HxW, 2] + + allcam_Ps = intrinsics[sub_batch] @ extrinsics[sub_batch,:,:3,:] + cam_Ps1, cam_Ps2 = allcam_Ps[:,[0]].repeat([1,Nv,1,1]), allcam_Ps[:,1:] # [B, Nv, 3, 4] + formatted_camPs = torch.cat([cam_Ps1.reshape([sub_batchsize*Nv,1,3,4]), cam_Ps2.reshape([sub_batchsize*Nv,1,3,4])],dim=1) # [BxNv, 2, 3, 4] + + # Triangulate matches to 3D + points_3d_world = batched_triangulate(allpoints, formatted_camPs) # [BxNv, HxW, three] + + # Aggregate pairwise predictions + points_3d_world = points_3d_world.view([sub_batchsize,Nv,H,W,3]) + valids = points_3d_world.isfinite() + valids_sum = valids.sum(dim=-1) + validsuni=valids_sum.unique() + assert torch.all(torch.logical_or(validsuni == 0 , validsuni == 3)), "Error, can only be nan for none or all XYZ values, not a subset" + confs[valids_sum==0] = 0. + points_3d_world = points_3d_world*confs[...,None] + + # Take care of NaNs + normalization = confs.sum(dim=1)[:,None].repeat(1,Nv,1,1) + normalization[normalization <= 1e-5] = 1. + points_3d_world[valids] /= normalization[valids_sum==3][:,None].repeat(1,3).view(-1) + points_3d_world[~valids] = 0. 
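# A toy sketch of the DLT triangulation performed by `batched_triangulate`
# above: each view contributes two rows (P0 - u*P2, P1 - v*P2) and the 3D
# point is the least-squares solution of the stacked system. The cameras and
# ground-truth point below are made up for illustration.
import torch

X = torch.tensor([0.2, -0.1, 3.0, 1.0])                   # ground-truth point (homogeneous)
P1 = torch.eye(3, 4)
P2 = torch.tensor([[1., 0., 0., -0.5], [0., 1., 0., 0.], [0., 0., 1., 0.]])
rows = []
for P in (P1, P2):
    u, v = (P @ X)[:2] / (P @ X)[2]
    rows += [P[0] - u * P[2], P[1] - v * P[2]]
A = torch.stack(rows)                                      # (2 views x 2 rows) x 4
Xhat = torch.linalg.lstsq(A[:, :3], -A[:, 3:]).solution.squeeze(-1)   # ~= X[:3]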
+ aggregated_points = points_3d_world.sum(dim=1) # weighted average (by confidence value) ignoring nans + + # Reset invalid values to nans, with a min visibility threshold + aggregated_points[valids_sum.sum(dim=1)/3 <= min_num_valids] = torch.nan + + # From 3D to depths + refcamE = extrinsics[sub_batch, 0] + points_3d_camera = (refcamE[:,:3, :3] @ aggregated_points.view(sub_batchsize,-1,3).transpose(-2,-1) + refcamE[:,:3,[3]]).transpose(-2,-1) # [B,HxW,3] + depths = points_3d_camera.view(sub_batchsize,H,W,3)[..., 2] # [B,H,W] + + # Cat results + out_aggregated_points.append(aggregated_points.cpu()) + out_depths.append(depths.cpu()) + out_confs.append(confs.sum(dim=1).cpu()) + + out_aggregated_points = torch.cat(out_aggregated_points,dim=0) + out_depths = torch.cat(out_depths,dim=0) + out_confs = torch.cat(out_confs,dim=0) + + return out_aggregated_points, out_depths, out_confs diff --git a/imcui/third_party/mast3r/mast3r/cloud_opt/tsdf_optimizer.py b/imcui/third_party/mast3r/mast3r/cloud_opt/tsdf_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..69f138c0301e4ad3cd4804d265f241b923e1b2b8 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/cloud_opt/tsdf_optimizer.py @@ -0,0 +1,273 @@ +import torch +from torch import nn +import numpy as np +from tqdm import tqdm +from matplotlib import pyplot as pl + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.utils.geometry import depthmap_to_pts3d, geotrf, inv +from dust3r.cloud_opt.base_opt import clean_pointcloud + + +class TSDFPostProcess: + """ Optimizes a signed distance-function to improve depthmaps. + """ + + def __init__(self, optimizer, subsample=8, TSDF_thresh=0., TSDF_batchsize=int(1e7)): + self.TSDF_thresh = TSDF_thresh # None -> no TSDF + self.TSDF_batchsize = TSDF_batchsize + self.optimizer = optimizer + + pts3d, depthmaps, confs = optimizer.get_dense_pts3d(clean_depth=False, subsample=subsample) + pts3d, depthmaps = self._TSDF_postprocess_or_not(pts3d, depthmaps, confs) + self.pts3d = pts3d + self.depthmaps = depthmaps + self.confs = confs + + def _get_depthmaps(self, TSDF_filtering_thresh=None): + if TSDF_filtering_thresh: + self._refine_depths_with_TSDF(self.optimizer, TSDF_filtering_thresh) # compute refined depths if needed + dms = self.TSDF_im_depthmaps if TSDF_filtering_thresh else self.im_depthmaps + return [d.exp() for d in dms] + + @torch.no_grad() + def _refine_depths_with_TSDF(self, TSDF_filtering_thresh, niter=1, nsamples=1000): + """ + Leverage TSDF to post-process estimated depths + for each pixel, find zero level of TSDF along ray (or closest to 0) + """ + print("Post-Processing Depths with TSDF fusion.") + self.TSDF_im_depthmaps = [] + alldepths, allposes, allfocals, allpps, allimshapes = self._get_depthmaps(), self.optimizer.get_im_poses( + ), self.optimizer.get_focals(), self.optimizer.get_principal_points(), self.imshapes + for vi in tqdm(range(self.optimizer.n_imgs)): + dm, pose, focal, pp, imshape = alldepths[vi], allposes[vi], allfocals[vi], allpps[vi], allimshapes[vi] + minvals = torch.full(dm.shape, 1e20) + + for it in range(niter): + H, W = dm.shape + curthresh = (niter - it) * TSDF_filtering_thresh + dm_offsets = (torch.randn(H, W, nsamples).to(dm) - 1.) 
* \ + curthresh # decreasing search std along with iterations + newdm = dm[..., None] + dm_offsets # [H,W,Nsamp] + curproj = self._backproj_pts3d(in_depths=[newdm], in_im_poses=pose[None], in_focals=focal[None], in_pps=pp[None], in_imshapes=[ + imshape])[0] # [H,W,Nsamp,3] + # Batched TSDF eval + curproj = curproj.view(-1, 3) + tsdf_vals = [] + valids = [] + for batch in range(0, len(curproj), self.TSDF_batchsize): + values, valid = self._TSDF_query( + curproj[batch:min(batch + self.TSDF_batchsize, len(curproj))], curthresh) + tsdf_vals.append(values) + valids.append(valid) + tsdf_vals = torch.cat(tsdf_vals, dim=0) + valids = torch.cat(valids, dim=0) + + tsdf_vals = tsdf_vals.view([H, W, nsamples]) + valids = valids.view([H, W, nsamples]) + + # keep depth value that got us the closest to 0 + tsdf_vals[~valids] = torch.inf # ignore invalid values + tsdf_vals = tsdf_vals.abs() + mins = torch.argmin(tsdf_vals, dim=-1, keepdim=True) + # when all samples live on a very flat zone, do nothing + allbad = (tsdf_vals == curthresh).sum(dim=-1) == nsamples + dm[~allbad] = torch.gather(newdm, -1, mins)[..., 0][~allbad] + + # Save refined depth map + self.TSDF_im_depthmaps.append(dm.log()) + + def _TSDF_query(self, qpoints, TSDF_filtering_thresh, weighted=True): + """ + TSDF query call: returns the weighted TSDF value for each query point [N, 3] + """ + N, three = qpoints.shape + assert three == 3 + qpoints = qpoints[None].repeat(self.optimizer.n_imgs, 1, 1) # [B,N,3] + # get projection coordinates and depths onto images + coords_and_depth = self._proj_pts3d(pts3d=qpoints, cam2worlds=self.optimizer.get_im_poses( + ), focals=self.optimizer.get_focals(), pps=self.optimizer.get_principal_points()) + image_coords = coords_and_depth[..., :2].round().to(int) # for now, there's no interpolation... + proj_depths = coords_and_depth[..., -1] + # recover depth values after scene optim + pred_depths, pred_confs, valids = self._get_pixel_depths(image_coords) + # Gather TSDF scores + all_SDF_scores = pred_depths - proj_depths # SDF + unseen = all_SDF_scores < -TSDF_filtering_thresh # handle visibility + # all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh,TSDF_filtering_thresh) # SDF -> TSDF + all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh, 1e20) # SDF -> TSDF + # Gather TSDF confidences and ignore points that are unseen, either OOB during reproj or too far behind seen depth + all_TSDF_weights = (~unseen).float() * valids.float() + if weighted: + all_TSDF_weights = pred_confs.exp() * all_TSDF_weights + # Aggregate all votes, ignoring zeros + TSDF_weights = all_TSDF_weights.sum(dim=0) + valids = TSDF_weights != 0. 
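# (Each query point gathers one weighted TSDF vote per view; the weighted sum
#  is normalized below only where at least one view contributed a non-zero
#  weight, so fully unseen points keep a zero weight and are returned as
#  invalid.)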
+ TSDF_wsum = (all_TSDF_weights * all_TSDF_scores).sum(dim=0) + TSDF_wsum[valids] /= TSDF_weights[valids] + return TSDF_wsum, valids + + def _get_pixel_depths(self, image_coords, TSDF_filtering_thresh=None, with_normals_conf=False): + """ Recover depth value for each input pixel coordinate, along with OOB validity mask + """ + B, N, two = image_coords.shape + assert B == self.optimizer.n_imgs and two == 2 + depths = torch.zeros([B, N], device=image_coords.device) + valids = torch.zeros([B, N], dtype=bool, device=image_coords.device) + confs = torch.zeros([B, N], device=image_coords.device) + curconfs = self._get_confs_with_normals() if with_normals_conf else self.im_conf + for ni, (imc, depth, conf) in enumerate(zip(image_coords, self._get_depthmaps(TSDF_filtering_thresh), curconfs)): + H, W = depth.shape + valids[ni] = torch.logical_and(0 <= imc[:, 1], imc[:, 1] < + H) & torch.logical_and(0 <= imc[:, 0], imc[:, 0] < W) + imc[~valids[ni]] = 0 + depths[ni] = depth[imc[:, 1], imc[:, 0]] + confs[ni] = conf.cuda()[imc[:, 1], imc[:, 0]] + return depths, confs, valids + + def _get_confs_with_normals(self): + outconfs = [] + # Confidence basedf on depth gradient + + class Sobel(nn.Module): + def __init__(self): + super().__init__() + self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False) + Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]]) + Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]]) + G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0) + G = G.unsqueeze(1) + self.filter.weight = nn.Parameter(G, requires_grad=False) + + def forward(self, img): + x = self.filter(img) + x = torch.mul(x, x) + x = torch.sum(x, dim=1, keepdim=True) + x = torch.sqrt(x) + return x + + grad_op = Sobel().to(self.im_depthmaps[0].device) + for conf, depth in zip(self.im_conf, self.im_depthmaps): + grad_confs = (1. 
- grad_op(depth[None, None])[0, 0]).clip(0) + if not 'dbg show': + pl.imshow(grad_confs.cpu()) + pl.show() + outconfs.append(conf * grad_confs.to(conf)) + return outconfs + + def _proj_pts3d(self, pts3d, cam2worlds, focals, pps): + """ + Projection operation: from 3D points to 2D coordinates + depths + """ + B = pts3d.shape[0] + assert pts3d.shape[0] == cam2worlds.shape[0] + # prepare Extrinsincs + R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1] + Rinv = R.transpose(-2, -1) + tinv = -Rinv @ t[..., None] + + # prepare intrinsics + intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(focals.shape[0], 1, 1) + if len(focals.shape) == 1: + focals = torch.stack([focals, focals], dim=-1) + intrinsics[:, 0, 0] = focals[:, 0] + intrinsics[:, 1, 1] = focals[:, 1] + intrinsics[:, :2, -1] = pps + # Project + projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N] + projpts = projpts.transpose(-2, -1) # [B,N,3] + projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z) + return projpts + + def _backproj_pts3d(self, in_depths=None, in_im_poses=None, + in_focals=None, in_pps=None, in_imshapes=None): + """ + Backprojection operation: from image depths to 3D points + """ + # Get depths and projection params if not provided + focals = self.optimizer.get_focals() if in_focals is None else in_focals + im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses + depth = self._get_depthmaps() if in_depths is None else in_depths + pp = self.optimizer.get_principal_points() if in_pps is None else in_pps + imshapes = self.imshapes if in_imshapes is None else in_imshapes + def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i]) + dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[[i]]) for i in range(im_poses.shape[0])] + + def autoprocess(x): + x = x[0] + return x.transpose(-2, -1) if len(x.shape) == 4 else x + return [geotrf(pose, autoprocess(pt)) for pose, pt in zip(im_poses, dm_to_3d)] + + def _pts3d_to_depth(self, pts3d, cam2worlds, focals, pps): + """ + Projection operation: from 3D points to 2D coordinates + depths + """ + B = pts3d.shape[0] + assert pts3d.shape[0] == cam2worlds.shape[0] + # prepare Extrinsincs + R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1] + Rinv = R.transpose(-2, -1) + tinv = -Rinv @ t[..., None] + + # prepare intrinsics + intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(self.optimizer.n_imgs, 1, 1) + if len(focals.shape) == 1: + focals = torch.stack([focals, focals], dim=-1) + intrinsics[:, 0, 0] = focals[:, 0] + intrinsics[:, 1, 1] = focals[:, 1] + intrinsics[:, :2, -1] = pps + # Project + projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N] + projpts = projpts.transpose(-2, -1) # [B,N,3] + projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z) + return projpts + + def _depth_to_pts3d(self, in_depths=None, in_im_poses=None, in_focals=None, in_pps=None, in_imshapes=None): + """ + Backprojection operation: from image depths to 3D points + """ + # Get depths and projection params if not provided + focals = self.optimizer.get_focals() if in_focals is None else in_focals + im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses + depth = self._get_depthmaps() if in_depths is None else in_depths + pp = self.optimizer.get_principal_points() if in_pps is None else in_pps + imshapes = self.imshapes if in_imshapes is None else in_imshapes + + def focal_ex(i): return focals[i][..., None, None].expand(1, 
*focals[i].shape, *imshapes[i]) + + dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i + 1]) for i in range(im_poses.shape[0])] + + def autoprocess(x): + x = x[0] + H, W, three = x.shape[:3] + return x.transpose(-2, -1) if len(x.shape) == 4 else x + return [geotrf(pp, autoprocess(pt)) for pp, pt in zip(im_poses, dm_to_3d)] + + def _get_pts3d(self, TSDF_filtering_thresh=None, **kw): + """ + return 3D points (possibly filtering depths with TSDF) + """ + return self._backproj_pts3d(in_depths=self._get_depthmaps(TSDF_filtering_thresh=TSDF_filtering_thresh), **kw) + + def _TSDF_postprocess_or_not(self, pts3d, depthmaps, confs, niter=1): + # Setup inner variables + self.imshapes = [im.shape[:2] for im in self.optimizer.imgs] + self.im_depthmaps = [dd.log().view(imshape) for dd, imshape in zip(depthmaps, self.imshapes)] + self.im_conf = confs + + if self.TSDF_thresh > 0.: + # Create or update self.TSDF_im_depthmaps that contain logdepths filtered with TSDF + self._refine_depths_with_TSDF(self.TSDF_thresh, niter=niter) + depthmaps = [dd.exp() for dd in self.TSDF_im_depthmaps] + # Turn them into 3D points + pts3d = self._backproj_pts3d(in_depths=depthmaps) + depthmaps = [dd.flatten() for dd in depthmaps] + pts3d = [pp.view(-1, 3) for pp in pts3d] + return pts3d, depthmaps + + def get_dense_pts3d(self, clean_depth=True): + if clean_depth: + confs = clean_pointcloud(self.confs, self.optimizer.intrinsics, inv(self.optimizer.cam2w), + self.depthmaps, self.pts3d) + return self.pts3d, self.depthmaps, confs diff --git a/imcui/third_party/mast3r/mast3r/cloud_opt/utils/__init__.py b/imcui/third_party/mast3r/mast3r/cloud_opt/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7dd877d649ce4dbd749dd7195a8b34c0f91d4f0 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/cloud_opt/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). \ No newline at end of file diff --git a/imcui/third_party/mast3r/mast3r/cloud_opt/utils/losses.py b/imcui/third_party/mast3r/mast3r/cloud_opt/utils/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..e1dd36afd6862592b8d00c499988136a972bd6e6 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/cloud_opt/utils/losses.py @@ -0,0 +1,32 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# losses for sparse ga +# -------------------------------------------------------- +import torch +import numpy as np + + +def l05_loss(x, y): + return torch.linalg.norm(x - y, dim=-1).sqrt() + + +def l1_loss(x, y): + return torch.linalg.norm(x - y, dim=-1) + + +def gamma_loss(gamma, mul=1, offset=None, clip=np.inf): + if offset is None: + if gamma == 1: + return l1_loss + # d(x**p)/dx = 1 ==> p * x**(p-1) == 1 ==> x = (1/p)**(1/(p-1)) + offset = (1 / gamma)**(1 / (gamma - 1)) + + def loss_func(x, y): + return (mul * l1_loss(x, y).clip(max=clip) + offset) ** gamma - offset ** gamma + return loss_func + + +def meta_gamma_loss(): + return lambda alpha: gamma_loss(alpha) diff --git a/imcui/third_party/mast3r/mast3r/cloud_opt/utils/schedules.py b/imcui/third_party/mast3r/mast3r/cloud_opt/utils/schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..d96253b4348d2f089c10142c5991e5afb8a9b683 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/cloud_opt/utils/schedules.py @@ -0,0 +1,17 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# lr schedules for sparse ga +# -------------------------------------------------------- +import numpy as np + + +def linear_schedule(alpha, lr_base, lr_end=0): + lr = (1 - alpha) * lr_base + alpha * lr_end + return lr + + +def cosine_schedule(alpha, lr_base, lr_end=0): + lr = lr_end + (lr_base - lr_end) * (1 + np.cos(alpha * np.pi)) / 2 + return lr diff --git a/imcui/third_party/mast3r/mast3r/colmap/__init__.py b/imcui/third_party/mast3r/mast3r/colmap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7dd877d649ce4dbd749dd7195a8b34c0f91d4f0 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/colmap/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). \ No newline at end of file diff --git a/imcui/third_party/mast3r/mast3r/colmap/database.py b/imcui/third_party/mast3r/mast3r/colmap/database.py new file mode 100644 index 0000000000000000000000000000000000000000..5de83a35664d4038a99713de7f397e83940e5421 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/colmap/database.py @@ -0,0 +1,383 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
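The comment inside `gamma_loss` above justifies the choice of `offset`: it makes the robust loss have slope `mul` at zero residual, so it behaves like `mul * l1_loss` for small errors while growing sublinearly when `gamma < 1`. A quick numerical check of that property, assuming the `mast3r` package from this tree is importable:

```python
import torch
from mast3r.cloud_opt.utils.losses import gamma_loss, l1_loss

x = torch.zeros(1, 3)
y = torch.full((1, 3), 1e-4)                 # tiny residual
for gamma in (0.5, 0.8, 1.0):
    ratio = gamma_loss(gamma)(x, y) / l1_loss(x, y)
    print(gamma, ratio.item())               # ~1.0 in every case: unit slope at zero residual
```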
+# +# -------------------------------------------------------- +# MASt3R to colmap export functions +# -------------------------------------------------------- +import os +import torch +import copy +import numpy as np +import torchvision +import numpy as np +from tqdm import tqdm +from scipy.cluster.hierarchy import DisjointSet +from scipy.spatial.transform import Rotation as R + +from mast3r.utils.misc import hash_md5 + +from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf # noqa + + +def convert_im_matches_pairs(img0, img1, image_to_colmap, im_keypoints, matches_im0, matches_im1, viz): + if viz: + from matplotlib import pyplot as pl + + image_mean = torch.as_tensor( + [0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1) + image_std = torch.as_tensor( + [0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1) + rgb0 = img0['img'] * image_std + image_mean + rgb0 = torchvision.transforms.functional.to_pil_image(rgb0[0]) + rgb0 = np.array(rgb0) + + rgb1 = img1['img'] * image_std + image_mean + rgb1 = torchvision.transforms.functional.to_pil_image(rgb1[0]) + rgb1 = np.array(rgb1) + + imgs = [rgb0, rgb1] + # visualize a few matches + n_viz = 100 + num_matches = matches_im0.shape[0] + match_idx_to_viz = np.round(np.linspace( + 0, num_matches - 1, n_viz)).astype(int) + viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz] + + H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2] + rgb0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), + (0, 0), (0, 0)), 'constant', constant_values=0) + rgb1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), + (0, 0), (0, 0)), 'constant', constant_values=0) + img = np.concatenate((rgb0, rgb1), axis=1) + pl.figure() + pl.imshow(img) + cmap = pl.get_cmap('jet') + for ii in range(n_viz): + (x0, y0), (x1, + y1) = viz_matches_im0[ii].T, viz_matches_im1[ii].T + pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(ii / + (n_viz - 1)), scalex=False, scaley=False) + pl.show(block=True) + + matches = [matches_im0.astype(np.float64), matches_im1.astype(np.float64)] + imgs = [img0, img1] + imidx0 = img0['idx'] + imidx1 = img1['idx'] + ravel_matches = [] + for j in range(2): + H, W = imgs[j]['true_shape'][0] + with np.errstate(invalid='ignore'): + qx, qy = matches[j].round().astype(np.int32).T + ravel_matches_j = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy) + ravel_matches.append(ravel_matches_j) + imidxj = imgs[j]['idx'] + for m in ravel_matches_j: + if m not in im_keypoints[imidxj]: + im_keypoints[imidxj][m] = 0 + im_keypoints[imidxj][m] += 1 + imid0 = copy.deepcopy(image_to_colmap[imidx0]['colmap_imid']) + imid1 = copy.deepcopy(image_to_colmap[imidx1]['colmap_imid']) + if imid0 > imid1: + colmap_matches = np.stack([ravel_matches[1], ravel_matches[0]], axis=-1) + imid0, imid1 = imid1, imid0 + imidx0, imidx1 = imidx1, imidx0 + else: + colmap_matches = np.stack([ravel_matches[0], ravel_matches[1]], axis=-1) + colmap_matches = np.unique(colmap_matches, axis=0) + return imidx0, imidx1, colmap_matches + + +def get_im_matches(pred1, pred2, pairs, image_to_colmap, im_keypoints, conf_thr, + is_sparse=True, subsample=8, pixel_tol=0, viz=False, device='cuda'): + im_matches = {} + for i in range(len(pred1['pts3d'])): + imidx0 = pairs[i][0]['idx'] + imidx1 = pairs[i][1]['idx'] + if 'desc' in pred1: # mast3r + descs = [pred1['desc'][i], pred2['desc'][i]] + confidences = [pred1['desc_conf'][i], 
pred2['desc_conf'][i]] + desc_dim = descs[0].shape[-1] + + if is_sparse: + corres = extract_correspondences_nonsym(descs[0], descs[1], confidences[0], confidences[1], + device=device, subsample=subsample, pixel_tol=pixel_tol) + conf = corres[2] + mask = conf >= conf_thr + matches_im0 = corres[0][mask].cpu().numpy() + matches_im1 = corres[1][mask].cpu().numpy() + else: + confidence_masks = [confidences[0] >= + conf_thr, confidences[1] >= conf_thr] + pts2d_list, desc_list = [], [] + for j in range(2): + conf_j = confidence_masks[j].cpu().numpy().flatten() + true_shape_j = pairs[i][j]['true_shape'][0] + pts2d_j = xy_grid( + true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j] + desc_j = descs[j].detach().cpu( + ).numpy().reshape(-1, desc_dim)[conf_j] + pts2d_list.append(pts2d_j) + desc_list.append(desc_j) + if len(desc_list[0]) == 0 or len(desc_list[1]) == 0: + continue + + nn0, nn1 = bruteforce_reciprocal_nns(desc_list[0], desc_list[1], + device=device, dist='dot', block_size=2**13) + reciprocal_in_P0 = (nn1[nn0] == np.arange(len(nn0))) + + matches_im1 = pts2d_list[1][nn0][reciprocal_in_P0] + matches_im0 = pts2d_list[0][reciprocal_in_P0] + else: + pts3d = [pred1['pts3d'][i], pred2['pts3d_in_other_view'][i]] + confidences = [pred1['conf'][i], pred2['conf'][i]] + + if is_sparse: + corres = extract_correspondences_nonsym(pts3d[0], pts3d[1], confidences[0], confidences[1], + device=device, subsample=subsample, pixel_tol=pixel_tol, + ptmap_key='3d') + conf = corres[2] + mask = conf >= conf_thr + matches_im0 = corres[0][mask].cpu().numpy() + matches_im1 = corres[1][mask].cpu().numpy() + else: + confidence_masks = [confidences[0] >= + conf_thr, confidences[1] >= conf_thr] + # find 2D-2D matches between the two images + pts2d_list, pts3d_list = [], [] + for j in range(2): + conf_j = confidence_masks[j].cpu().numpy().flatten() + true_shape_j = pairs[i][j]['true_shape'][0] + pts2d_j = xy_grid(true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j] + pts3d_j = pts3d[j].detach().cpu().numpy().reshape(-1, 3)[conf_j] + pts2d_list.append(pts2d_j) + pts3d_list.append(pts3d_j) + + PQ, PM = pts3d_list[0], pts3d_list[1] + if len(PQ) == 0 or len(PM) == 0: + continue + reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches( + PQ, PM) + + matches_im1 = pts2d_list[1][reciprocal_in_PM] + matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM] + + if len(matches_im0) == 0: + continue + imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1], + image_to_colmap, im_keypoints, + matches_im0, matches_im1, viz) + im_matches[(imidx0, imidx1)] = colmap_matches + return im_matches + + +def get_im_matches_from_cache(pairs, cache_path, desc_conf, subsample, + image_to_colmap, im_keypoints, conf_thr, + viz=False, device='cuda'): + im_matches = {} + for i in range(len(pairs)): + imidx0 = pairs[i][0]['idx'] + imidx1 = pairs[i][1]['idx'] + + corres_idx1 = hash_md5(pairs[i][0]['instance']) + corres_idx2 = hash_md5(pairs[i][1]['instance']) + + path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx1}-{corres_idx2}.pth' + if os.path.isfile(path_corres): + score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device) + else: + path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx2}-{corres_idx1}.pth' + score, (xy2, xy1, confs) = torch.load(path_corres, map_location=device) + mask = confs >= conf_thr + matches_im0 = xy1[mask].cpu().numpy() + matches_im1 = xy2[mask].cpu().numpy() + + if len(matches_im0) == 0: + continue + imidx0, 
imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1], + image_to_colmap, im_keypoints, + matches_im0, matches_im1, viz) + im_matches[(imidx0, imidx1)] = colmap_matches + return im_matches + + +def export_images(db, images, image_paths, focals, ga_world_to_cam, camera_model): + # add cameras/images to the db + # with the output of ga as prior + image_to_colmap = {} + im_keypoints = {} + for idx in range(len(image_paths)): + im_keypoints[idx] = {} + H, W = images[idx]["orig_shape"] + if focals is None: + focal_x = focal_y = 1.2 * max(W, H) + prior_focal_length = False + cx = W / 2.0 + cy = H / 2.0 + elif isinstance(focals[idx], np.ndarray) and len(focals[idx].shape) == 2: + # intrinsics + focal_x = focals[idx][0, 0] + focal_y = focals[idx][1, 1] + cx = focals[idx][0, 2] * images[idx]["to_orig"][0, 0] + cy = focals[idx][1, 2] * images[idx]["to_orig"][1, 1] + prior_focal_length = True + else: + focal_x = focal_y = float(focals[idx]) + prior_focal_length = True + cx = W / 2.0 + cy = H / 2.0 + focal_x = focal_x * images[idx]["to_orig"][0, 0] + focal_y = focal_y * images[idx]["to_orig"][1, 1] + + if camera_model == "SIMPLE_PINHOLE": + model_id = 0 + focal = (focal_x + focal_y) / 2.0 + params = np.asarray([focal, cx, cy], np.float64) + elif camera_model == "PINHOLE": + model_id = 1 + params = np.asarray([focal_x, focal_y, cx, cy], np.float64) + elif camera_model == "SIMPLE_RADIAL": + model_id = 2 + focal = (focal_x + focal_y) / 2.0 + params = np.asarray([focal, cx, cy, 0.0], np.float64) + elif camera_model == "OPENCV": + model_id = 4 + params = np.asarray([focal_x, focal_y, cx, cy, 0.0, 0.0, 0.0, 0.0], np.float64) + else: + raise ValueError(f"invalid camera model {camera_model}") + + H, W = int(H), int(W) + # OPENCV camera model + camid = db.add_camera( + model_id, W, H, params, prior_focal_length=prior_focal_length) + if ga_world_to_cam is None: + prior_t = np.zeros(3) + prior_q = np.zeros(4) + else: + q = R.from_matrix(ga_world_to_cam[idx][:3, :3]).as_quat() + prior_t = ga_world_to_cam[idx][:3, 3] + prior_q = np.array([q[-1], q[0], q[1], q[2]]) + imid = db.add_image( + image_paths[idx], camid, prior_q=prior_q, prior_t=prior_t) + image_to_colmap[idx] = { + 'colmap_imid': imid, + 'colmap_camid': camid + } + return image_to_colmap, im_keypoints + + +def export_matches(db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification): + colmap_image_pairs = [] + # 2D-2D are quite dense + # we want to remove the very small tracks + # and export only kpt for which we have values + # build tracks + print("building tracks") + keypoints_to_track_id = {} + track_id_to_kpt_list = [] + to_merge = [] + for (imidx0, imidx1), colmap_matches in tqdm(im_matches.items()): + if imidx0 not in keypoints_to_track_id: + keypoints_to_track_id[imidx0] = {} + if imidx1 not in keypoints_to_track_id: + keypoints_to_track_id[imidx1] = {} + + for m in colmap_matches: + if m[0] not in keypoints_to_track_id[imidx0] and m[1] not in keypoints_to_track_id[imidx1]: + # new pair of kpts never seen before + track_idx = len(track_id_to_kpt_list) + keypoints_to_track_id[imidx0][m[0]] = track_idx + keypoints_to_track_id[imidx1][m[1]] = track_idx + track_id_to_kpt_list.append( + [(imidx0, m[0]), (imidx1, m[1])]) + elif m[1] not in keypoints_to_track_id[imidx1]: + # 0 has a track, not 1 + track_idx = keypoints_to_track_id[imidx0][m[0]] + keypoints_to_track_id[imidx1][m[1]] = track_idx + track_id_to_kpt_list[track_idx].append((imidx1, m[1])) + elif m[0] not in 
keypoints_to_track_id[imidx0]: + # 1 has a track, not 0 + track_idx = keypoints_to_track_id[imidx1][m[1]] + keypoints_to_track_id[imidx0][m[0]] = track_idx + track_id_to_kpt_list[track_idx].append((imidx0, m[0])) + else: + # both have tracks, merge them + track_idx0 = keypoints_to_track_id[imidx0][m[0]] + track_idx1 = keypoints_to_track_id[imidx1][m[1]] + if track_idx0 != track_idx1: + # let's deal with them later + to_merge.append((track_idx0, track_idx1)) + + # regroup merge targets + print("merging tracks") + unique = np.unique(to_merge) + tree = DisjointSet(unique) + for track_idx0, track_idx1 in tqdm(to_merge): + tree.merge(track_idx0, track_idx1) + + subsets = tree.subsets() + print("applying merge") + for setvals in tqdm(subsets): + new_trackid = len(track_id_to_kpt_list) + kpt_list = [] + for track_idx in setvals: + kpt_list.extend(track_id_to_kpt_list[track_idx]) + for imidx, kpid in track_id_to_kpt_list[track_idx]: + keypoints_to_track_id[imidx][kpid] = new_trackid + track_id_to_kpt_list.append(kpt_list) + + # binc = np.bincount([len(v) for v in track_id_to_kpt_list]) + # nonzero = np.nonzero(binc) + # nonzerobinc = binc[nonzero[0]] + # print(nonzero[0].tolist()) + # print(nonzerobinc) + num_valid_tracks = sum( + [1 for v in track_id_to_kpt_list if len(v) >= min_len_track]) + + keypoints_to_idx = {} + print(f"squashing keypoints - {num_valid_tracks} valid tracks") + for imidx, keypoints_imid in tqdm(im_keypoints.items()): + imid = image_to_colmap[imidx]['colmap_imid'] + keypoints_kept = [] + keypoints_to_idx[imidx] = {} + for kp in keypoints_imid.keys(): + if kp not in keypoints_to_track_id[imidx]: + continue + track_idx = keypoints_to_track_id[imidx][kp] + track_length = len(track_id_to_kpt_list[track_idx]) + if track_length < min_len_track: + continue + keypoints_to_idx[imidx][kp] = len(keypoints_kept) + keypoints_kept.append(kp) + if len(keypoints_kept) == 0: + continue + keypoints_kept = np.array(keypoints_kept) + keypoints_kept = np.unravel_index(keypoints_kept, images[imidx]['true_shape'][0])[ + 0].base[:, ::-1].copy().astype(np.float32) + # rescale coordinates + keypoints_kept[:, 0] += 0.5 + keypoints_kept[:, 1] += 0.5 + keypoints_kept = geotrf(images[imidx]['to_orig'], keypoints_kept, norm=True) + + H, W = images[imidx]['orig_shape'] + keypoints_kept[:, 0] = keypoints_kept[:, 0].clip(min=0, max=W - 0.01) + keypoints_kept[:, 1] = keypoints_kept[:, 1].clip(min=0, max=H - 0.01) + + db.add_keypoints(imid, keypoints_kept) + + print("exporting im_matches") + for (imidx0, imidx1), colmap_matches in im_matches.items(): + imid0, imid1 = image_to_colmap[imidx0]['colmap_imid'], image_to_colmap[imidx1]['colmap_imid'] + assert imid0 < imid1 + final_matches = np.array([[keypoints_to_idx[imidx0][m[0]], keypoints_to_idx[imidx1][m[1]]] + for m in colmap_matches + if m[0] in keypoints_to_idx[imidx0] and m[1] in keypoints_to_idx[imidx1]]) + if len(final_matches) > 0: + colmap_image_pairs.append( + (images[imidx0]['instance'], images[imidx1]['instance'])) + db.add_matches(imid0, imid1, final_matches) + if skip_geometric_verification: + db.add_two_view_geometry(imid0, imid1, final_matches) + return colmap_image_pairs diff --git a/imcui/third_party/mast3r/mast3r/datasets/__init__.py b/imcui/third_party/mast3r/mast3r/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c625aca0a773c105ed229ff87364721b4755bc8d --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/datasets/__init__.py @@ -0,0 +1,62 @@ +# Copyright (C) 2024-present Naver Corporation. 
All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from .base.mast3r_base_stereo_view_dataset import MASt3RBaseStereoViewDataset + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.datasets.arkitscenes import ARKitScenes as DUSt3R_ARKitScenes # noqa +from dust3r.datasets.blendedmvs import BlendedMVS as DUSt3R_BlendedMVS # noqa +from dust3r.datasets.co3d import Co3d as DUSt3R_Co3d # noqa +from dust3r.datasets.megadepth import MegaDepth as DUSt3R_MegaDepth # noqa +from dust3r.datasets.scannetpp import ScanNetpp as DUSt3R_ScanNetpp # noqa +from dust3r.datasets.staticthings3d import StaticThings3D as DUSt3R_StaticThings3D # noqa +from dust3r.datasets.waymo import Waymo as DUSt3R_Waymo # noqa +from dust3r.datasets.wildrgbd import WildRGBD as DUSt3R_WildRGBD # noqa + + +class ARKitScenes(DUSt3R_ARKitScenes, MASt3RBaseStereoViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + super().__init__(*args, split=split, ROOT=ROOT, **kwargs) + self.is_metric_scale = True + + +class BlendedMVS(DUSt3R_BlendedMVS, MASt3RBaseStereoViewDataset): + def __init__(self, *args, ROOT, split=None, **kwargs): + super().__init__(*args, ROOT=ROOT, split=split, **kwargs) + self.is_metric_scale = False + + +class Co3d(DUSt3R_Co3d, MASt3RBaseStereoViewDataset): + def __init__(self, mask_bg=True, *args, ROOT, **kwargs): + super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs) + self.is_metric_scale = False + + +class MegaDepth(DUSt3R_MegaDepth, MASt3RBaseStereoViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + super().__init__(*args, split=split, ROOT=ROOT, **kwargs) + self.is_metric_scale = False + + +class ScanNetpp(DUSt3R_ScanNetpp, MASt3RBaseStereoViewDataset): + def __init__(self, *args, ROOT, **kwargs): + super().__init__(*args, ROOT=ROOT, **kwargs) + self.is_metric_scale = True + + +class StaticThings3D(DUSt3R_StaticThings3D, MASt3RBaseStereoViewDataset): + def __init__(self, ROOT, *args, mask_bg='rand', **kwargs): + super().__init__(ROOT, *args, mask_bg=mask_bg, **kwargs) + self.is_metric_scale = False + + +class Waymo(DUSt3R_Waymo, MASt3RBaseStereoViewDataset): + def __init__(self, *args, ROOT, **kwargs): + super().__init__(*args, ROOT=ROOT, **kwargs) + self.is_metric_scale = True + + +class WildRGBD(DUSt3R_WildRGBD, MASt3RBaseStereoViewDataset): + def __init__(self, mask_bg=True, *args, ROOT, **kwargs): + super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs) + self.is_metric_scale = True diff --git a/imcui/third_party/mast3r/mast3r/datasets/base/__init__.py b/imcui/third_party/mast3r/mast3r/datasets/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7dd877d649ce4dbd749dd7195a8b34c0f91d4f0 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/datasets/base/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). \ No newline at end of file diff --git a/imcui/third_party/mast3r/mast3r/datasets/base/mast3r_base_stereo_view_dataset.py b/imcui/third_party/mast3r/mast3r/datasets/base/mast3r_base_stereo_view_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3ced0ef0dc6b1d6225781af55d3e924e133fdeaf --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/datasets/base/mast3r_base_stereo_view_dataset.py @@ -0,0 +1,355 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
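The classes above only re-expose the DUSt3R datasets with the MASt3R view pipeline and tag whether their depths are metric. A hedged construction sketch follows; the `ROOT` path is a placeholder and the keyword values are typical choices, not settings taken from this repository:

```python
from mast3r.datasets import Co3d

# hypothetical location of a preprocessed Co3D subset
dataset = Co3d(split='train', ROOT='data/co3d_processed',
               resolution=(512, 384), aug_crop='auto', n_corres=8192)
view1, view2 = dataset[0]             # one pair of views with pts3d, masks and 'corres'
print(view1['is_metric_scale'])       # False: Co3d poses are only known up to scale
```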
+# +# -------------------------------------------------------- +# base class for implementing datasets +# -------------------------------------------------------- +import PIL.Image +import PIL.Image as Image +import numpy as np +import torch +import copy + +from mast3r.datasets.utils.cropping import (extract_correspondences_from_pts3d, + gen_random_crops, in2d_rect, crop_to_homography) + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset, view_name, is_good_type # noqa +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, geotrf, depthmap_to_camera_coordinates +import dust3r.datasets.utils.cropping as cropping + + +class MASt3RBaseStereoViewDataset(BaseStereoViewDataset): + def __init__(self, *, # only keyword arguments + split=None, + resolution=None, # square_size or (width, height) or list of [(width,height), ...] + transform=ImgNorm, + aug_crop=False, + aug_swap=False, + aug_monocular=False, + aug_portrait_or_landscape=True, # automatic choice between landscape/portrait when possible + aug_rot90=False, + n_corres=0, + nneg=0, + n_tentative_crops=4, + seed=None): + super().__init__(split=split, resolution=resolution, transform=transform, aug_crop=aug_crop, seed=seed) + self.is_metric_scale = False # by default a dataset is not metric scale, subclasses can overwrite this + + self.aug_swap = aug_swap + self.aug_monocular = aug_monocular + self.aug_portrait_or_landscape = aug_portrait_or_landscape + self.aug_rot90 = aug_rot90 + + self.n_corres = n_corres + self.nneg = nneg + assert self.n_corres == 'all' or isinstance(self.n_corres, int) or (isinstance(self.n_corres, list) and len( + self.n_corres) == self.num_views), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}" + assert self.nneg == 0 or self.n_corres != 'all' + self.n_tentative_crops = n_tentative_crops + + def _swap_view_aug(self, views): + if self._rng.random() < 0.5: + views.reverse() + + def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None): + """ This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # transpose the resolution if necessary + W, H = image.size # new size + assert resolution[0] >= resolution[1] + if H > 1.1 * W: + # image is portrait mode + resolution = resolution[::-1] + elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]: + # image is square, so we chose (portrait, landscape) randomly + if rng.integers(2) and self.aug_portrait_or_landscape: + resolution = resolution[::-1] + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution) + + # actual cropping (if necessary) with bilinear interpolation + offset_factor = 0.5 + intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor) + crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution) + image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) + + return image, depthmap, intrinsics2 + + def generate_crops_from_pair(self, view1, view2, resolution, aug_crop_arg, n_crops=4, rng=np.random): + views = [view1, view2] + 
+ if aug_crop_arg is False: + # compatibility + for i in range(2): + view = views[i] + view['img'], view['depthmap'], view['camera_intrinsics'] = self._crop_resize_if_necessary(view['img'], + view['depthmap'], + view['camera_intrinsics'], + resolution, + rng=rng) + view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'], + view['camera_intrinsics'], + view['camera_pose']) + return + + # extract correspondences + corres = extract_correspondences_from_pts3d(*views, target_n_corres=None, rng=rng) + + # generate 4 random crops in each view + view_crops = [] + crops_resolution = [] + corres_msks = [] + for i in range(2): + + if aug_crop_arg == 'auto': + S = min(views[i]['img'].size) + R = min(resolution) + aug_crop = S * (S - R) // R + aug_crop = max(.1 * S, aug_crop) # for cropping: augment scale of at least 10%, and more if possible + else: + aug_crop = aug_crop_arg + + # tranpose the target resolution if necessary + assert resolution[0] >= resolution[1] + W, H = imsize = views[i]['img'].size + crop_resolution = resolution + if H > 1.1 * W: + # image is portrait mode + crop_resolution = resolution[::-1] + elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]: + # image is square, so we chose (portrait, landscape) randomly + if rng.integers(2): + crop_resolution = resolution[::-1] + + crops = gen_random_crops(imsize, n_crops, crop_resolution, aug_crop=aug_crop, rng=rng) + view_crops.append(crops) + crops_resolution.append(crop_resolution) + + # compute correspondences + corres_msks.append(in2d_rect(corres[i], crops)) + + # compute IoU for each + intersection = np.float32(corres_msks[0]).T @ np.float32(corres_msks[1]) + # select best pair of crops + best = np.unravel_index(intersection.argmax(), (n_crops, n_crops)) + crops = [view_crops[i][c] for i, c in enumerate(best)] + + # crop with the homography + for i in range(2): + view = views[i] + imsize, K_new, R, H = crop_to_homography(view['camera_intrinsics'], crops[i], crops_resolution[i]) + # imsize, K_new, H = upscale_homography(imsize, resolution, K_new, H) + + # update camera params + K_old = view['camera_intrinsics'] + view['camera_intrinsics'] = K_new + view['camera_pose'] = view['camera_pose'].copy() + view['camera_pose'][:3, :3] = view['camera_pose'][:3, :3] @ R + + # apply homography to image and depthmap + homo8 = (H / H[2, 2]).ravel().tolist()[:8] + view['img'] = view['img'].transform(imsize, Image.Transform.PERSPECTIVE, + homo8, + resample=Image.Resampling.BICUBIC) + + depthmap2 = depthmap_to_camera_coordinates(view['depthmap'], K_old)[0] @ R[:, 2] + view['depthmap'] = np.array(Image.fromarray(depthmap2).transform( + imsize, Image.Transform.PERSPECTIVE, homo8)) + + if 'track_labels' in view: + # convert from uint64 --> uint32, because PIL.Image cannot handle uint64 + mapping, track_labels = np.unique(view['track_labels'], return_inverse=True) + track_labels = track_labels.astype(np.uint32).reshape(view['track_labels'].shape) + + # homography transformation + res = np.array(Image.fromarray(track_labels).transform(imsize, Image.Transform.PERSPECTIVE, homo8)) + view['track_labels'] = mapping[res] # mapping back to uint64 + + # recompute 3d points from scratch + view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'], + view['camera_intrinsics'], + view['camera_pose']) + + def __getitem__(self, idx): + if isinstance(idx, tuple): + # the idx is specifying the aspect-ratio + idx, ar_idx = idx + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + + # set-up 
the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, '_rng'): + seed = torch.initial_seed() # this is different for each dataloader process + self._rng = np.random.default_rng(seed=seed) + + # over-loaded code + resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + views = self._get_views(idx, resolution, self._rng) + assert len(views) == self.num_views + + for v, view in enumerate(views): + assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view['idx'] = (idx, ar_idx, v) + view['is_metric_scale'] = self.is_metric_scale + + assert 'camera_intrinsics' in view + if 'camera_pose' not in view: + view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}' + assert 'pts3d' not in view + assert 'valid_mask' not in view + assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}' + + pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + view['pts3d'] = pts3d + view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1) + + self.generate_crops_from_pair(views[0], views[1], resolution=resolution, + aug_crop_arg=self.aug_crop, + n_crops=self.n_tentative_crops, + rng=self._rng) + for v, view in enumerate(views): + # encode the image + width, height = view['img'].size + view['true_shape'] = np.int32((height, width)) + view['img'] = self.transform(view['img']) + # Pixels for which depth is fundamentally undefined + view['sky_mask'] = (view['depthmap'] < 0) + + if self.aug_swap: + self._swap_view_aug(views) + + if self.aug_monocular: + if self._rng.random() < self.aug_monocular: + views = [copy.deepcopy(views[0]) for _ in range(len(views))] + + # automatic extraction of correspondences from pts3d + pose + if self.n_corres > 0 and ('corres' not in view): + corres1, corres2, valid = extract_correspondences_from_pts3d(*views, self.n_corres, + self._rng, nneg=self.nneg) + views[0]['corres'] = corres1 + views[1]['corres'] = corres2 + views[0]['valid_corres'] = valid + views[1]['valid_corres'] = valid + + if self.aug_rot90 is False: + pass + elif self.aug_rot90 == 'same': + rotate_90(views, k=self._rng.choice(4)) + elif self.aug_rot90 == 'diff': + rotate_90(views[:1], k=self._rng.choice(4)) + rotate_90(views[1:], k=self._rng.choice(4)) + else: + raise ValueError(f'Bad value for {self.aug_rot90=}') + + # check data-types metric_scale + for v, view in enumerate(views): + if 'corres' not in view: + view['corres'] = np.full((self.n_corres, 2), np.nan, dtype=np.float32) + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view['camera_intrinsics'] + + # check shapes + assert view['depthmap'].shape == view['img'].shape[1:] + assert view['depthmap'].shape == view['pts3d'].shape[:2] + assert view['depthmap'].shape == view['valid_mask'].shape + + # last thing done! 
+ for view in views: + # transpose to make sure all views are the same size + transpose_to_landscape(view) + # this allows to check whether the RNG is is the same state each time + view['rng'] = int.from_bytes(self._rng.bytes(4), 'big') + + return views + + +def transpose_to_landscape(view, revert=False): + height, width = view['true_shape'] + + if width < height: + if revert: + height, width = width, height + + # rectify portrait to landscape + assert view['img'].shape == (3, height, width) + view['img'] = view['img'].swapaxes(1, 2) + + assert view['valid_mask'].shape == (height, width) + view['valid_mask'] = view['valid_mask'].swapaxes(0, 1) + + assert view['sky_mask'].shape == (height, width) + view['sky_mask'] = view['sky_mask'].swapaxes(0, 1) + + assert view['depthmap'].shape == (height, width) + view['depthmap'] = view['depthmap'].swapaxes(0, 1) + + assert view['pts3d'].shape == (height, width, 3) + view['pts3d'] = view['pts3d'].swapaxes(0, 1) + + # transpose x and y pixels + view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]] + + # transpose correspondences x and y + view['corres'] = view['corres'][:, [1, 0]] + + +def rotate_90(views, k=1): + from scipy.spatial.transform import Rotation + # print('rotation =', k) + + RT = np.eye(4, dtype=np.float32) + RT[:3, :3] = Rotation.from_euler('z', 90 * k, degrees=True).as_matrix() + + for view in views: + view['img'] = torch.rot90(view['img'], k=k, dims=(-2, -1)) # WARNING!! dims=(-1,-2) != dims=(-2,-1) + view['depthmap'] = np.rot90(view['depthmap'], k=k).copy() + view['camera_pose'] = view['camera_pose'] @ RT + + RT2 = np.eye(3, dtype=np.float32) + RT2[:2, :2] = RT[:2, :2] * ((1, -1), (-1, 1)) + H, W = view['depthmap'].shape + if k % 4 == 0: + pass + elif k % 4 == 1: + # top-left (0,0) pixel becomes (0,H-1) + RT2[:2, 2] = (0, H - 1) + elif k % 4 == 2: + # top-left (0,0) pixel becomes (W-1,H-1) + RT2[:2, 2] = (W - 1, H - 1) + elif k % 4 == 3: + # top-left (0,0) pixel becomes (W-1,0) + RT2[:2, 2] = (W - 1, 0) + else: + raise ValueError(f'Bad value for {k=}') + + view['camera_intrinsics'][:2, 2] = geotrf(RT2, view['camera_intrinsics'][:2, 2]) + if k % 2 == 1: + K = view['camera_intrinsics'] + np.fill_diagonal(K, K.diagonal()[[1, 0, 2]]) + + pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + view['pts3d'] = pts3d + view['valid_mask'] = np.rot90(view['valid_mask'], k=k).copy() + view['sky_mask'] = np.rot90(view['sky_mask'], k=k).copy() + + view['corres'] = geotrf(RT2, view['corres']).round().astype(view['corres'].dtype) + view['true_shape'] = np.int32((H, W)) diff --git a/imcui/third_party/mast3r/mast3r/datasets/utils/__init__.py b/imcui/third_party/mast3r/mast3r/datasets/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/datasets/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/mast3r/mast3r/datasets/utils/cropping.py b/imcui/third_party/mast3r/mast3r/datasets/utils/cropping.py new file mode 100644 index 0000000000000000000000000000000000000000..57f4d84b019eaac9cf0c308a94f2cb8e2ec1a6ba --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/datasets/utils/cropping.py @@ -0,0 +1,219 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
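`rotate_90` updates the intrinsics with the 2D transform `RT2` so that the principal point follows the rotated pixels. A small worked check of the `k=1` case, derived directly from the `RT2` construction above (the image width and intrinsics are arbitrary):

```python
import numpy as np

W = 640                                  # original image width
K = np.array([[600.,   0., 320.],        # fx,  0, cx
              [  0., 500., 240.],        #  0, fy, cy
              [  0.,   0.,   1.]])

# RT2 for k=1: RT[:2, :2] * ((1, -1), (-1, 1)) with translation (0, H_new - 1) = (0, W - 1)
RT2 = np.array([[ 0., 1.,     0.],
                [-1., 0., W - 1.],
                [ 0., 0.,     1.]])
new_pp = RT2[:2, :2] @ K[:2, 2] + RT2[:2, 2]
print(new_pp)   # [240. 319.] == (cy, W-1-cx): where the principal point lands after one
                # counter-clockwise np.rot90; fx and fy are also swapped since k is odd
```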
+# +# -------------------------------------------------------- +# cropping/match extraction +# -------------------------------------------------------- +import numpy as np +import mast3r.utils.path_to_dust3r # noqa +from dust3r.utils.device import to_numpy +from dust3r.utils.geometry import inv, geotrf + + +def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False): + is_reciprocal1 = (corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2))) + pos1 = is_reciprocal1.nonzero()[0] + pos2 = corres_1_to_2[pos1] + if ret_recip: + return is_reciprocal1, pos1, pos2 + return pos1, pos2 + + +def extract_correspondences_from_pts3d(view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0): + view1, view2 = to_numpy((view1, view2)) + # project pixels from image1 --> 3d points --> image2 pixels + shape1, corres1_to_2 = reproject_view(view1['pts3d'], view2) + shape2, corres2_to_1 = reproject_view(view2['pts3d'], view1) + + # compute reciprocal correspondences: + # pos1 == valid pixels (correspondences) in image1 + is_reciprocal1, pos1, pos2 = reciprocal_1d(corres1_to_2, corres2_to_1, ret_recip=True) + is_reciprocal2 = (corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))) + + if target_n_corres is None: + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2 + + available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum()) + target_n_positives = int(target_n_corres * (1 - nneg)) + n_positives = min(len(pos1), target_n_positives) + n_negatives = min(target_n_corres - n_positives, available_negatives) + + if n_negatives + n_positives != target_n_corres: + # should be really rare => when there are not enough negatives + # in that case, break nneg and add a few more positives ? + n_positives = target_n_corres - n_negatives + assert n_positives <= len(pos1) + + assert n_positives <= len(pos1) + assert n_positives <= len(pos2) + assert n_negatives <= (~is_reciprocal1).sum() + assert n_negatives <= (~is_reciprocal2).sum() + assert n_positives + n_negatives == target_n_corres + + valid = np.ones(n_positives, dtype=bool) + if n_positives < len(pos1): + # random sub-sampling of valid correspondences + perm = rng.permutation(len(pos1))[:n_positives] + pos1 = pos1[perm] + pos2 = pos2[perm] + + if n_negatives > 0: + # add false correspondences if not enough + def norm(p): return p / p.sum() + pos1 = np.r_[pos1, rng.choice(shape1[0] * shape1[1], size=n_negatives, replace=False, p=norm(~is_reciprocal1))] + pos2 = np.r_[pos2, rng.choice(shape2[0] * shape2[1], size=n_negatives, replace=False, p=norm(~is_reciprocal2))] + valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)] + + # convert (x+W*y) back to 2d (x,y) coordinates + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2, valid + + +def reproject_view(pts3d, view2): + shape = view2['pts3d'].shape[:2] + return reproject(pts3d, view2['camera_intrinsics'], inv(view2['camera_pose']), shape) + + +def reproject(pts3d, K, world2cam, shape): + H, W, THREE = pts3d.shape + assert THREE == 3 + + # reproject in camera2 space + with np.errstate(divide='ignore', invalid='ignore'): + pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2) + + # quantize to pixel positions + return (H, W), ravel_xy(pos, shape) + + +def ravel_xy(pos, shape): + H, W = shape + with np.errstate(invalid='ignore'): + qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T + quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy) + return 
quantized_pos + + +def unravel_xy(pos, shape): + # convert (x+W*y) back to 2d (x,y) coordinates + return np.unravel_index(pos, shape)[0].base[:, ::-1].copy() + + +def _rotation_origin_to_pt(target): + """ Align the origin (0,0,1) with the target point (x,y,1) in projective space. + Method: rotate z to put target on (x'+,0,1), then rotate on Y to get (0,0,1) and un-rotate z. + """ + from scipy.spatial.transform import Rotation + x, y = target + rot_z = np.arctan2(y, x) + rot_y = np.arctan(np.linalg.norm(target)) + R = Rotation.from_euler('ZYZ', [rot_z, rot_y, -rot_z]).as_matrix() + return R + + +def _dotmv(Trf, pts, ncol=None, norm=False): + assert Trf.ndim >= 2 + ncol = ncol or pts.shape[-1] + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match' + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def crop_to_homography(K, crop, target_size=None): + """ Given an image and its intrinsics, + we want to replicate a rectangular crop with an homography, + so that the principal point of the new 'crop' is centered. 
+ """ + # build intrinsics for the crop + crop = np.round(crop) + crop_size = crop[2:] - crop[:2] + K2 = K.copy() # same focal + K2[:2, 2] = crop_size / 2 # new principal point is perfectly centered + + # find which corner is the most far-away from current principal point + # so that the final homography does not go over the image borders + corners = crop.reshape(-1, 2) + corner_idx = np.abs(corners - K[:2, 2]).argmax(0) + corner = corners[corner_idx, [0, 1]] + # align with the corresponding corner from the target view + corner2 = np.c_[[0, 0], crop_size][[0, 1], corner_idx] + + old_pt = _dotmv(np.linalg.inv(K), corner, norm=1) + new_pt = _dotmv(np.linalg.inv(K2), corner2, norm=1) + R = _rotation_origin_to_pt(old_pt) @ np.linalg.inv(_rotation_origin_to_pt(new_pt)) + + if target_size is not None: + imsize = target_size + target_size = np.asarray(target_size) + scaling = min(target_size / crop_size) + K2[:2] *= scaling + K2[:2, 2] = target_size / 2 + else: + imsize = tuple(np.int32(crop_size).tolist()) + + return imsize, K2, R, K @ R @ np.linalg.inv(K2) + + +def gen_random_crops(imsize, n_crops, resolution, aug_crop, rng=np.random): + """ Generate random crops of size=resolution, + for an input image upscaled to (imsize + randint(0 , aug_crop)) + """ + resolution_crop = np.array(resolution) * min(np.array(imsize) / resolution) + + # (virtually) upscale the input image + # scaling = rng.uniform(1, 1+(aug_crop+1)/min(imsize)) + scaling = np.exp(rng.uniform(0, np.log(1 + aug_crop / min(imsize)))) + imsize2 = np.int32(np.array(imsize) * scaling) + + # generate some random crops + topleft = rng.random((n_crops, 2)) * (imsize2 - resolution_crop) + crops = np.c_[topleft, topleft + resolution_crop] + # print(f"{scaling=}, {topleft=}") + # reduce the resolution to come back to original size + crops /= scaling + return crops + + +def in2d_rect(corres, crops): + # corres = (N,2) + # crops = (M,4) + # output = (N, M) + is_sup = (corres[:, None] >= crops[None, :, 0:2]) + is_inf = (corres[:, None] < crops[None, :, 2:4]) + return (is_sup & is_inf).all(axis=-1) diff --git a/imcui/third_party/mast3r/mast3r/demo.py b/imcui/third_party/mast3r/mast3r/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..22b6a66c24666776a7197844a0463d7821ed53ce --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/demo.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# sparse gradio demo functions +# -------------------------------------------------------- +import math +import gradio +import os +import numpy as np +import functools +import trimesh +import copy +from scipy.spatial.transform import Rotation +import tempfile +import shutil + +from mast3r.cloud_opt.sparse_ga import sparse_global_alignment +from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.image_pairs import make_pairs +from dust3r.utils.image import load_images +from dust3r.utils.device import to_numpy +from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes +from dust3r.demo import get_args_parser as dust3r_get_args_parser + +import matplotlib.pyplot as pl + + +class SparseGAState(): + def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None): + self.sparse_ga = sparse_ga + self.cache_dir = cache_dir + self.outfile_name = outfile_name + self.should_delete = should_delete + + def __del__(self): + if not self.should_delete: + return + if self.cache_dir is not None and os.path.isdir(self.cache_dir): + shutil.rmtree(self.cache_dir) + self.cache_dir = None + if self.outfile_name is not None and os.path.isfile(self.outfile_name): + os.remove(self.outfile_name) + self.outfile_name = None + + +def get_args_parser(): + parser = dust3r_get_args_parser() + parser.add_argument('--share', action='store_true') + parser.add_argument('--gradio_delete_cache', default=None, type=int, + help='age/frequency at which gradio removes the file. If >0, matching cache is purged') + + actions = parser._actions + for action in actions: + if action.dest == 'model_name': + action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"] + # change defaults + parser.prog = 'mast3r demo' + return parser + + +def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05, + cam_color=None, as_pointcloud=False, + transparent_cams=False, silent=False): + assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals) + pts3d = to_numpy(pts3d) + imgs = to_numpy(imgs) + focals = to_numpy(focals) + cams2world = to_numpy(cams2world) + + scene = trimesh.Scene() + + # full pointcloud + if as_pointcloud: + pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3) + col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3) + valid_msk = np.isfinite(pts.sum(axis=1)) + pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk]) + scene.add_geometry(pct) + else: + meshes = [] + for i in range(len(imgs)): + pts3d_i = pts3d[i].reshape(imgs[i].shape) + msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1)) + meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i)) + mesh = trimesh.Trimesh(**cat_meshes(meshes)) + scene.add_geometry(mesh) + + # add each camera + for i, pose_c2w in enumerate(cams2world): + if isinstance(cam_color, list): + camera_edge_color = cam_color[i] + else: + camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)] + add_scene_cam(scene, pose_c2w, camera_edge_color, + None if transparent_cams else imgs[i], focals[i], + imsize=imgs[i].shape[1::-1], screen_width=cam_size) + + rot = np.eye(4) + rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix() + scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot)) + if not silent: + print('(exporting 3D scene to', outfile, ')') + scene.export(file_obj=outfile) + return 
outfile + + +def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False, + clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0): + """ + extract 3D_model (glb file) from a reconstructed scene + """ + if scene_state is None: + return None + outfile = scene_state.outfile_name + if outfile is None: + return None + + # get optimized values from scene + scene = scene_state.sparse_ga + rgbimg = scene.imgs + focals = scene.get_focals().cpu() + cams2world = scene.get_im_poses().cpu() + + # 3D pointcloud from depthmap, poses and intrinsics + if TSDF_thresh > 0: + tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh) + pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth)) + else: + pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth)) + msk = to_numpy([c > min_conf_thr for c in confs]) + return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud, + transparent_cams=transparent_cams, cam_size=cam_size, silent=silent) + + +def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size, current_scene_state, + filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr, + as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize, + win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw): + """ + from a list of images, run mast3r inference, sparse global aligner. + then run get_3D_model_from_scene + """ + imgs = load_images(filelist, size=image_size, verbose=not silent) + if len(imgs) == 1: + imgs = [imgs[0], copy.deepcopy(imgs[0])] + imgs[1]['idx'] = 1 + filelist = [filelist[0], filelist[0] + '_2'] + + scene_graph_params = [scenegraph_type] + if scenegraph_type in ["swin", "logwin"]: + scene_graph_params.append(str(winsize)) + elif scenegraph_type == "oneref": + scene_graph_params.append(str(refid)) + if scenegraph_type in ["swin", "logwin"] and not win_cyclic: + scene_graph_params.append('noncyclic') + scene_graph = '-'.join(scene_graph_params) + pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True) + if optim_level == 'coarse': + niter2 = 0 + # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation) + if current_scene_state is not None and \ + not current_scene_state.should_delete and \ + current_scene_state.cache_dir is not None: + cache_dir = current_scene_state.cache_dir + elif gradio_delete_cache: + cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir) + else: + cache_dir = os.path.join(outdir, 'cache') + os.makedirs(cache_dir, exist_ok=True) + scene = sparse_global_alignment(filelist, pairs, cache_dir, + model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device, + opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics, + matching_conf_thr=matching_conf_thr, **kw) + if current_scene_state is not None and \ + not current_scene_state.should_delete and \ + current_scene_state.outfile_name is not None: + outfile_name = current_scene_state.outfile_name + else: + outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir) + + scene_state = SparseGAState(scene, gradio_delete_cache, cache_dir, outfile_name) + outfile = get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, TSDF_thresh) + return scene_state, outfile + + +def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type): + num_files = 
len(inputfiles) if inputfiles is not None else 1 + show_win_controls = scenegraph_type in ["swin", "logwin"] + show_winsize = scenegraph_type in ["swin", "logwin"] + show_cyclic = scenegraph_type in ["swin", "logwin"] + max_winsize, min_winsize = 1, 1 + if scenegraph_type == "swin": + if win_cyclic: + max_winsize = max(1, math.ceil((num_files - 1) / 2)) + else: + max_winsize = num_files - 1 + elif scenegraph_type == "logwin": + if win_cyclic: + half_size = math.ceil((num_files - 1) / 2) + max_winsize = max(1, math.ceil(math.log(half_size, 2))) + else: + max_winsize = max(1, math.ceil(math.log(num_files, 2))) + winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize, + minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize) + win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic) + win_col = gradio.Column(visible=show_win_controls) + refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, + maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref') + return win_col, winsize, win_cyclic, refid + + +def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False, + share=False, gradio_delete_cache=False): + if not silent: + print('Outputing stuff in', tmpdirname) + + recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model, device, + silent, image_size) + model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent) + + def get_context(delete_cache): + css = """.gradio-container {margin: 0 !important; min-width: 100%};""" + title = "MASt3R Demo" + if delete_cache: + return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache)) + else: + return gradio.Blocks(css=css, title="MASt3R Demo") # for compatibility with older versions + + with get_context(gradio_delete_cache) as demo: + # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference + scene = gradio.State(None) + gradio.HTML('

<h2 style="text-align: center;">MASt3R Demo</h2>

') + with gradio.Column(): + inputfiles = gradio.File(file_count="multiple") + with gradio.Row(): + with gradio.Column(): + with gradio.Row(): + lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01) + niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000, + label="num_iterations", info="For coarse alignment!") + lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001) + niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000, + label="num_iterations", info="For refinement!") + optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"], + value='refine+depth', label="OptLevel", + info="Optimization level") + with gradio.Row(): + matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5., + minimum=0., maximum=30., step=0.1, + info="Before Fallback to Regr3D!") + shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics", + info="Only optimize one set of intrinsics for all views") + scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"), + ("swin: sliding window", "swin"), + ("logwin: sliding window with long range", "logwin"), + ("oneref: match one image with all", "oneref")], + value='complete', label="Scenegraph", + info="Define how to make pairs", + interactive=True) + with gradio.Column(visible=False) as win_col: + winsize = gradio.Slider(label="Scene Graph: Window Size", value=1, + minimum=1, maximum=1, step=1) + win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence") + refid = gradio.Slider(label="Scene Graph: Id", value=0, + minimum=0, maximum=0, step=1, visible=False) + run_btn = gradio.Button("Run") + + with gradio.Row(): + # adjust the confidence threshold + min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1) + # adjust the camera size in the output pointcloud + cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001) + TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01) + with gradio.Row(): + as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud") + # two post process implemented + mask_sky = gradio.Checkbox(value=False, label="Mask sky") + clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps") + transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras") + + outmodel = gradio.Model3D() + + # events + scenegraph_type.change(set_scenegraph_options, + inputs=[inputfiles, win_cyclic, refid, scenegraph_type], + outputs=[win_col, winsize, win_cyclic, refid]) + inputfiles.change(set_scenegraph_options, + inputs=[inputfiles, win_cyclic, refid, scenegraph_type], + outputs=[win_col, winsize, win_cyclic, refid]) + win_cyclic.change(set_scenegraph_options, + inputs=[inputfiles, win_cyclic, refid, scenegraph_type], + outputs=[win_col, winsize, win_cyclic, refid]) + run_btn.click(fn=recon_fun, + inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr, + as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, + scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics], + outputs=[scene, outmodel]) + min_conf_thr.release(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, TSDF_thresh], + outputs=outmodel) + cam_size.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, 
+ clean_depth, transparent_cams, cam_size, TSDF_thresh], + outputs=outmodel) + TSDF_thresh.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, TSDF_thresh], + outputs=outmodel) + as_pointcloud.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, TSDF_thresh], + outputs=outmodel) + mask_sky.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, TSDF_thresh], + outputs=outmodel) + clean_depth.change(fn=model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, TSDF_thresh], + outputs=outmodel) + transparent_cams.change(model_from_scene_fun, + inputs=[scene, min_conf_thr, as_pointcloud, mask_sky, + clean_depth, transparent_cams, cam_size, TSDF_thresh], + outputs=outmodel) + demo.launch(share=share, server_name=server_name, server_port=server_port) diff --git a/imcui/third_party/mast3r/mast3r/fast_nn.py b/imcui/third_party/mast3r/mast3r/fast_nn.py new file mode 100644 index 0000000000000000000000000000000000000000..05537f43c1be10b3733e80def8295c2ff5b5b8c0 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/fast_nn.py @@ -0,0 +1,223 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# MASt3R Fast Nearest Neighbor +# -------------------------------------------------------- +import torch +import numpy as np +import math +from scipy.spatial import KDTree + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.utils.device import to_numpy, todevice # noqa + + +@torch.no_grad() +def bruteforce_reciprocal_nns(A, B, device='cuda', block_size=None, dist='l2'): + if isinstance(A, np.ndarray): + A = torch.from_numpy(A).to(device) + if isinstance(B, np.ndarray): + B = torch.from_numpy(B).to(device) + + A = A.to(device) + B = B.to(device) + + if dist == 'l2': + dist_func = torch.cdist + argmin = torch.min + elif dist == 'dot': + def dist_func(A, B): + return A @ B.T + + def argmin(X, dim): + sim, nn = torch.max(X, dim=dim) + return sim.neg_(), nn + else: + raise ValueError(f'Unknown {dist=}') + + if block_size is None or len(A) * len(B) <= block_size**2: + dists = dist_func(A, B) + _, nn_A = argmin(dists, dim=1) + _, nn_B = argmin(dists, dim=0) + else: + dis_A = torch.full((A.shape[0],), float('inf'), device=device, dtype=A.dtype) + dis_B = torch.full((B.shape[0],), float('inf'), device=device, dtype=B.dtype) + nn_A = torch.full((A.shape[0],), -1, device=device, dtype=torch.int64) + nn_B = torch.full((B.shape[0],), -1, device=device, dtype=torch.int64) + number_of_iteration_A = math.ceil(A.shape[0] / block_size) + number_of_iteration_B = math.ceil(B.shape[0] / block_size) + + for i in range(number_of_iteration_A): + A_i = A[i * block_size:(i + 1) * block_size] + for j in range(number_of_iteration_B): + B_j = B[j * block_size:(j + 1) * block_size] + dists_blk = dist_func(A_i, B_j) # A, B, 1 + # dists_blk = dists[i * block_size:(i+1)*block_size, j * block_size:(j+1)*block_size] + min_A_i, argmin_A_i = argmin(dists_blk, dim=1) + min_B_j, argmin_B_j = argmin(dists_blk, dim=0) + + col_mask = min_A_i < dis_A[i * block_size:(i + 1) * block_size] + line_mask = min_B_j < dis_B[j * block_size:(j + 1) * block_size] + + dis_A[i * block_size:(i + 1) * block_size][col_mask] = 
min_A_i[col_mask] + dis_B[j * block_size:(j + 1) * block_size][line_mask] = min_B_j[line_mask] + + nn_A[i * block_size:(i + 1) * block_size][col_mask] = argmin_A_i[col_mask] + (j * block_size) + nn_B[j * block_size:(j + 1) * block_size][line_mask] = argmin_B_j[line_mask] + (i * block_size) + nn_A = nn_A.cpu().numpy() + nn_B = nn_B.cpu().numpy() + return nn_A, nn_B + + +class cdistMatcher: + def __init__(self, db_pts, device='cuda'): + self.db_pts = db_pts.to(device) + self.device = device + + def query(self, queries, k=1, **kw): + assert k == 1 + if queries.numel() == 0: + return None, [] + nnA, nnB = bruteforce_reciprocal_nns(queries, self.db_pts, device=self.device, **kw) + dis = None + return dis, nnA + + +def merge_corres(idx1, idx2, shape1=None, shape2=None, ret_xy=True, ret_index=False): + assert idx1.dtype == idx2.dtype == np.int32 + + # unique and sort along idx1 + corres = np.unique(np.c_[idx2, idx1].view(np.int64), return_index=ret_index) + if ret_index: + corres, indices = corres + xy2, xy1 = corres[:, None].view(np.int32).T + + if ret_xy: + assert shape1 and shape2 + xy1 = np.unravel_index(xy1, shape1) + xy2 = np.unravel_index(xy2, shape2) + if ret_xy != 'y_x': + xy1 = xy1[0].base[:, ::-1] + xy2 = xy2[0].base[:, ::-1] + + if ret_index: + return xy1, xy2, indices + return xy1, xy2 + + +def fast_reciprocal_NNs(pts1, pts2, subsample_or_initxy1=8, ret_xy=True, pixel_tol=0, ret_basin=False, + device='cuda', **matcher_kw): + H1, W1, DIM1 = pts1.shape + H2, W2, DIM2 = pts2.shape + assert DIM1 == DIM2 + + pts1 = pts1.reshape(-1, DIM1) + pts2 = pts2.reshape(-1, DIM2) + + if isinstance(subsample_or_initxy1, int) and pixel_tol == 0: + S = subsample_or_initxy1 + y1, x1 = np.mgrid[S // 2:H1:S, S // 2:W1:S].reshape(2, -1) + max_iter = 10 + else: + x1, y1 = subsample_or_initxy1 + if isinstance(x1, torch.Tensor): + x1 = x1.cpu().numpy() + if isinstance(y1, torch.Tensor): + y1 = y1.cpu().numpy() + max_iter = 1 + + xy1 = np.int32(np.unique(x1 + W1 * y1)) # make sure there's no doublons + xy2 = np.full_like(xy1, -1) + old_xy1 = xy1.copy() + old_xy2 = xy2.copy() + + if 'dist' in matcher_kw or 'block_size' in matcher_kw \ + or (isinstance(device, str) and device.startswith('cuda')) \ + or (isinstance(device, torch.device) and device.type.startswith('cuda')): + pts1 = pts1.to(device) + pts2 = pts2.to(device) + tree1 = cdistMatcher(pts1, device=device) + tree2 = cdistMatcher(pts2, device=device) + else: + pts1, pts2 = to_numpy((pts1, pts2)) + tree1 = KDTree(pts1) + tree2 = KDTree(pts2) + + notyet = np.ones(len(xy1), dtype=bool) + basin = np.full((H1 * W1 + 1,), -1, dtype=np.int32) if ret_basin else None + + niter = 0 + # n_notyet = [len(notyet)] + while notyet.any(): + _, xy2[notyet] = to_numpy(tree2.query(pts1[xy1[notyet]], **matcher_kw)) + if not ret_basin: + notyet &= (old_xy2 != xy2) # remove points that have converged + + _, xy1[notyet] = to_numpy(tree1.query(pts2[xy2[notyet]], **matcher_kw)) + if ret_basin: + basin[old_xy1[notyet]] = xy1[notyet] + notyet &= (old_xy1 != xy1) # remove points that have converged + + # n_notyet.append(notyet.sum()) + niter += 1 + if niter >= max_iter: + break + + old_xy2[:] = xy2 + old_xy1[:] = xy1 + + # print('notyet_stats:', ' '.join(map(str, (n_notyet+[0]*10)[:max_iter]))) + + if pixel_tol > 0: + # in case we only want to match some specific points + # and still have some way of checking reciprocity + old_yx1 = np.unravel_index(old_xy1, (H1, W1))[0].base + new_yx1 = np.unravel_index(xy1, (H1, W1))[0].base + dis = np.linalg.norm(old_yx1 - new_yx1, axis=-1) + 
converged = dis < pixel_tol + if not isinstance(subsample_or_initxy1, int): + xy1 = old_xy1 # replace new points by old ones + else: + converged = ~notyet # converged correspondences + + # keep only unique correspondences, and sort on xy1 + xy1, xy2 = merge_corres(xy1[converged], xy2[converged], (H1, W1), (H2, W2), ret_xy=ret_xy) + if ret_basin: + return xy1, xy2, basin + return xy1, xy2 + + +def extract_correspondences_nonsym(A, B, confA, confB, subsample=8, device=None, ptmap_key='pred_desc', pixel_tol=0): + if '3d' in ptmap_key: + opt = dict(device='cpu', workers=32) + else: + opt = dict(device=device, dist='dot', block_size=2**13) + + # matching the two pairs + idx1 = [] + idx2 = [] + # merge corres from opposite pairs + HA, WA = A.shape[:2] + HB, WB = B.shape[:2] + if pixel_tol == 0: + nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt) + nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt) + else: + S = subsample + yA, xA = np.mgrid[S // 2:HA:S, S // 2:WA:S].reshape(2, -1) + yB, xB = np.mgrid[S // 2:HB:S, S // 2:WB:S].reshape(2, -1) + + nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=(xA, yA), ret_xy=False, pixel_tol=pixel_tol, **opt) + nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=(xB, yB), ret_xy=False, pixel_tol=pixel_tol, **opt) + + idx1 = np.r_[nn1to2[0], nn2to1[1]] + idx2 = np.r_[nn1to2[1], nn2to1[0]] + + c1 = confA.ravel()[idx1] + c2 = confB.ravel()[idx2] + + xy1, xy2, idx = merge_corres(idx1, idx2, (HA, WA), (HB, WB), ret_xy=True, ret_index=True) + conf = np.minimum(c1[idx], c2[idx]) + corres = (xy1.copy(), xy2.copy(), conf) + return todevice(corres, device) diff --git a/imcui/third_party/mast3r/mast3r/losses.py b/imcui/third_party/mast3r/mast3r/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..3a50f57481e436d7752dcbf2b414be3ea65ee76b --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/losses.py @@ -0,0 +1,508 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
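# --------------------------------------------------------
# Usage sketch for fast_reciprocal_NNs above (illustrative only): toy random
# descriptor maps stand in for real MASt3R outputs, shapes and parameter values
# are made up, and the import assumes the mast3r package added by this diff is
# on the Python path. For reference, visloc.py further down calls the same
# function with dict(device='cuda', dist='dot', block_size=2**13).
# --------------------------------------------------------
import torch
from mast3r.fast_nn import fast_reciprocal_NNs

H, W, D = 64, 96, 24
descA = torch.randn(H, W, D)   # dense per-pixel descriptors, image A
descB = torch.randn(H, W, D)   # dense per-pixel descriptors, image B

# reciprocal nearest neighbours seeded on an 8-pixel grid (CPU / KDTree path)
xyA, xyB = fast_reciprocal_NNs(descA, descB, subsample_or_initxy1=8,
                               ret_xy=True, device='cpu')
# xyA, xyB: (N, 2) integer (x, y) pixel coordinates of mutual matches
# --------------------------------------------------------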
+# +# -------------------------------------------------------- +# Implementation of MASt3R training losses +# -------------------------------------------------------- +import torch +import torch.nn as nn +import numpy as np +from sklearn.metrics import average_precision_score + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.losses import BaseCriterion, Criterion, MultiLoss, Sum, ConfLoss +from dust3r.losses import Regr3D as Regr3D_dust3r +from dust3r.utils.geometry import (geotrf, inv, normalize_pointcloud) +from dust3r.inference import get_pred_pts3d +from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale + + +def apply_log_to_norm(xyz): + d = xyz.norm(dim=-1, keepdim=True) + xyz = xyz / d.clip(min=1e-8) + xyz = xyz * torch.log1p(d) + return xyz + + +class Regr3D (Regr3D_dust3r): + def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False, opt_fit_gt=False, + sky_loss_value=2, max_metric_scale=False, loss_in_log=False): + self.loss_in_log = loss_in_log + if norm_mode.startswith('?'): + # do no norm pts from metric scale datasets + self.norm_all = False + self.norm_mode = norm_mode[1:] + else: + self.norm_all = True + self.norm_mode = norm_mode + super().__init__(criterion, self.norm_mode, gt_scale) + + self.sky_loss_value = sky_loss_value + self.max_metric_scale = max_metric_scale + + def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None): + # everything is normalized w.r.t. camera of view1 + in_camera1 = inv(gt1['camera_pose']) + gt_pts1 = geotrf(in_camera1, gt1['pts3d']) # B,H,W,3 + gt_pts2 = geotrf(in_camera1, gt2['pts3d']) # B,H,W,3 + + valid1 = gt1['valid_mask'].clone() + valid2 = gt2['valid_mask'].clone() + + if dist_clip is not None: + # points that are too far-away == invalid + dis1 = gt_pts1.norm(dim=-1) # (B, H, W) + dis2 = gt_pts2.norm(dim=-1) # (B, H, W) + valid1 = valid1 & (dis1 <= dist_clip) + valid2 = valid2 & (dis2 <= dist_clip) + + if self.loss_in_log == 'before': + # this only make sense when depth_mode == 'linear' + gt_pts1 = apply_log_to_norm(gt_pts1) + gt_pts2 = apply_log_to_norm(gt_pts2) + + pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False).clone() + pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True).clone() + + if not self.norm_all: + if self.max_metric_scale: + B = valid1.shape[0] + # valid1: B, H, W + # torch.linalg.norm(gt_pts1, dim=-1) -> B, H, W + # dist1_to_cam1 -> reshape to B, H*W + dist1_to_cam1 = torch.where(valid1, torch.linalg.norm(gt_pts1, dim=-1), 0).view(B, -1) + dist2_to_cam1 = torch.where(valid2, torch.linalg.norm(gt_pts2, dim=-1), 0).view(B, -1) + + # is_metric_scale: B + # dist1_to_cam1.max(dim=-1).values -> B + gt1['is_metric_scale'] = gt1['is_metric_scale'] \ + & (dist1_to_cam1.max(dim=-1).values < self.max_metric_scale) \ + & (dist2_to_cam1.max(dim=-1).values < self.max_metric_scale) + gt2['is_metric_scale'] = gt1['is_metric_scale'] + + mask = ~gt1['is_metric_scale'] + else: + mask = torch.ones_like(gt1['is_metric_scale']) + # normalize 3d points + if self.norm_mode and mask.any(): + pr_pts1[mask], pr_pts2[mask] = normalize_pointcloud(pr_pts1[mask], pr_pts2[mask], self.norm_mode, + valid1[mask], valid2[mask]) + + if self.norm_mode and not self.gt_scale: + gt_pts1, gt_pts2, norm_factor = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, + valid1, valid2, ret_factor=True) + # apply the same normalization to prediction + pr_pts1[~mask] = pr_pts1[~mask] / norm_factor[~mask] + pr_pts2[~mask] = pr_pts2[~mask] / norm_factor[~mask] + + # return sky segmentation, making sure 
they don't include any labelled 3d points + sky1 = gt1['sky_mask'] & (~valid1) + sky2 = gt2['sky_mask'] & (~valid2) + return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, sky1, sky2, {} + + def compute_loss(self, gt1, gt2, pred1, pred2, **kw): + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \ + self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw) + + if self.sky_loss_value > 0: + assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss' + # add the sky pixel as "valid" pixels... + mask1 = mask1 | sky1 + mask2 = mask2 | sky2 + + # loss on img1 side + pred_pts1 = pred_pts1[mask1] + gt_pts1 = gt_pts1[mask1] + if self.loss_in_log and self.loss_in_log != 'before': + # this only make sense when depth_mode == 'exp' + pred_pts1 = apply_log_to_norm(pred_pts1) + gt_pts1 = apply_log_to_norm(gt_pts1) + l1 = self.criterion(pred_pts1, gt_pts1) + + # loss on gt2 side + pred_pts2 = pred_pts2[mask2] + gt_pts2 = gt_pts2[mask2] + if self.loss_in_log and self.loss_in_log != 'before': + pred_pts2 = apply_log_to_norm(pred_pts2) + gt_pts2 = apply_log_to_norm(gt_pts2) + l2 = self.criterion(pred_pts2, gt_pts2) + + if self.sky_loss_value > 0: + assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss' + # ... but force the loss to be high there + l1 = torch.where(sky1[mask1], self.sky_loss_value, l1) + l2 = torch.where(sky2[mask2], self.sky_loss_value, l2) + self_name = type(self).__name__ + details = {self_name + '_pts3d_1': float(l1.mean()), self_name + '_pts3d_2': float(l2.mean())} + return Sum((l1, mask1), (l2, mask2)), (details | monitoring) + + +class Regr3D_ShiftInv (Regr3D): + """ Same than Regr3D but invariant to depth shift. + """ + + def get_all_pts3d(self, gt1, gt2, pred1, pred2): + # compute unnormalized points + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \ + super().get_all_pts3d(gt1, gt2, pred1, pred2) + + # compute median depth + gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2] + pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2] + gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None] + pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None] + + # subtract the median depth + gt_z1 -= gt_shift_z + gt_z2 -= gt_shift_z + pred_z1 -= pred_shift_z + pred_z2 -= pred_shift_z + + # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach()) + return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring + + +class Regr3D_ScaleInv (Regr3D): + """ Same than Regr3D but invariant to depth scale. 
+ if gt_scale == True: enforce the prediction to take the same scale than GT + """ + + def get_all_pts3d(self, gt1, gt2, pred1, pred2): + # compute depth-normalized points + gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \ + super().get_all_pts3d(gt1, gt2, pred1, pred2) + + # measure scene scale + _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2) + _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2) + + # prevent predictions to be in a ridiculous range + pred_scale = pred_scale.clip(min=1e-3, max=1e3) + + # subtract the median depth + if self.gt_scale: + pred_pts1 *= gt_scale / pred_scale + pred_pts2 *= gt_scale / pred_scale + # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean()) + else: + gt_pts1 /= gt_scale + gt_pts2 /= gt_scale + pred_pts1 /= pred_scale + pred_pts2 /= pred_scale + # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()) + + return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring + + +class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv): + # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv + pass + + +def get_similarities(desc1, desc2, euc=False): + if euc: # euclidean distance in same range than similarities + dists = (desc1[:, :, None] - desc2[:, None]).norm(dim=-1) + sim = 1 / (1 + dists) + else: + # Compute similarities + sim = desc1 @ desc2.transpose(-2, -1) + return sim + + +class MatchingCriterion(BaseCriterion): + def __init__(self, reduction='mean', fp=torch.float32): + super().__init__(reduction) + self.fp = fp + + def forward(self, a, b, valid_matches=None, euc=False): + assert a.ndim >= 2 and 1 <= a.shape[-1], f'Bad shape = {a.shape}' + dist = self.loss(a.to(self.fp), b.to(self.fp), valid_matches, euc=euc) + # one dimension less or reduction to single value + assert (valid_matches is None and dist.ndim == a.ndim - + 1) or self.reduction in ['mean', 'sum', '1-mean', 'none'] + if self.reduction == 'none': + return dist + if self.reduction == 'sum': + return dist.sum() + if self.reduction == 'mean': + return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) + if self.reduction == '1-mean': + return 1. - dist.mean() if dist.numel() > 0 else dist.new_ones(()) + raise ValueError(f'bad {self.reduction=} mode') + + def loss(self, a, b, valid_matches=None): + raise NotImplementedError + + +class InfoNCE(MatchingCriterion): + def __init__(self, temperature=0.07, eps=1e-8, mode='all', **kwargs): + super().__init__(**kwargs) + self.temperature = temperature + self.eps = eps + assert mode in ['all', 'proper', 'dual'] + self.mode = mode + + def loss(self, desc1, desc2, valid_matches=None, euc=False): + # valid positives are along diagonals + B, N, D = desc1.shape + B2, N2, D2 = desc2.shape + assert B == B2 and D == D2 + if valid_matches is None: + valid_matches = torch.ones([B, N], dtype=bool) + # torch.all(valid_matches.sum(dim=-1) > 0) some pairs have no matches???? 
+ assert valid_matches.shape == torch.Size([B, N]) and valid_matches.sum() > 0 + + # Tempered similarities + sim = get_similarities(desc1, desc2, euc) / self.temperature + sim[sim.isnan()] = -torch.inf # ignore nans + # Softmax of positives with temperature + sim = sim.exp_() # save peak memory + positives = sim.diagonal(dim1=-2, dim2=-1) + + # Loss + if self.mode == 'all': # Previous InfoNCE + loss = -torch.log((positives / sim.sum(dim=-1).sum(dim=-1, keepdim=True)).clip(self.eps)) + elif self.mode == 'proper': # Proper InfoNCE + loss = -(torch.log((positives / sim.sum(dim=-2)).clip(self.eps)) + + torch.log((positives / sim.sum(dim=-1)).clip(self.eps))) + elif self.mode == 'dual': # Dual Softmax + loss = -(torch.log((positives**2 / sim.sum(dim=-1) / sim.sum(dim=-2)).clip(self.eps))) + else: + raise ValueError("This should not happen...") + return loss[valid_matches] + + +class APLoss (MatchingCriterion): + """ AP loss. + """ + + def __init__(self, nq='torch', min=0, max=1, euc=False, **kw): + super().__init__(**kw) + # Exact/True AP loss (not differentiable) + if nq == 0: + nq = 'sklearn' # special case + try: + self.compute_AP = eval('self.compute_true_AP_' + nq) + except: + raise ValueError("Unknown mode %s for AP loss" % nq) + + @staticmethod + def compute_true_AP_sklearn(scores, labels): + def compute_AP(label, score): + return average_precision_score(label, score) + + aps = scores.new_zeros((scores.shape[0], scores.shape[1])) + label_np = labels.cpu().numpy().astype(bool) + scores_np = scores.cpu().numpy() + for bi in range(scores_np.shape[0]): + for i in range(scores_np.shape[1]): + labels = label_np[bi, i, :] + if labels.sum() < 1: + continue + aps[bi, i] = compute_AP(labels, scores_np[bi, i, :]) + return aps + + @staticmethod + def compute_true_AP_torch(scores, labels): + assert scores.shape == labels.shape + B, N, M = labels.shape + dev = labels.device + with torch.no_grad(): + # sort scores + _, order = scores.sort(dim=-1, descending=True) + # sort labels accordingly + labels = labels[torch.arange(B, device=dev)[:, None, None].expand(order.shape), + torch.arange(N, device=dev)[None, :, None].expand(order.shape), + order] + # compute number of positives per query + npos = labels.sum(dim=-1) + assert torch.all(torch.isclose(npos, npos[0, 0]) + ), "only implemented for constant number of positives per query" + npos = int(npos[0, 0]) + # compute precision at each recall point + posrank = labels.nonzero()[:, -1].view(B, N, npos) + recall = torch.arange(1, 1 + npos, dtype=torch.float32, device=dev)[None, None, :].expand(B, N, npos) + precision = recall / (1 + posrank).float() + # average precision values at all recall points + aps = precision.mean(dim=-1) + + return aps + + def loss(self, desc1, desc2, valid_matches=None, euc=False): # if matches is None, positives are the diagonal + B, N1, D = desc1.shape + B2, N2, D2 = desc2.shape + assert B == B2 and D == D2 + + scores = get_similarities(desc1, desc2, euc) + + labels = torch.zeros([B, N1, N2], dtype=scores.dtype, device=scores.device) + + # allow all diagonal positives and only mask afterwards + labels.diagonal(dim1=-2, dim2=-1)[...] = 1. 
+ apscore = self.compute_AP(scores, labels) + if valid_matches is not None: + apscore = apscore[valid_matches] + return apscore + + +class MatchingLoss (Criterion, MultiLoss): + """ + Matching loss per image + only compare pixels inside an image but not in the whole batch as what would be done usually + """ + + def __init__(self, criterion, withconf=False, use_pts3d=False, negatives_padding=0, blocksize=4096): + super().__init__(criterion) + self.negatives_padding = negatives_padding + self.use_pts3d = use_pts3d + self.blocksize = blocksize + self.withconf = withconf + + def add_negatives(self, outdesc2, desc2, batchid, x2, y2): + if self.negatives_padding: + B, H, W, D = desc2.shape + negatives = torch.ones([B, H, W], device=desc2.device, dtype=bool) + negatives[batchid, y2, x2] = False + sel = negatives & (negatives.view([B, -1]).cumsum(dim=-1).view(B, H, W) + <= self.negatives_padding) # take the N-first negatives + outdesc2 = torch.cat([outdesc2, desc2[sel].view([B, -1, D])], dim=1) + return outdesc2 + + def get_confs(self, pred1, pred2, sel1, sel2): + if self.withconf: + if self.use_pts3d: + outconfs1 = pred1['conf'][sel1] + outconfs2 = pred2['conf'][sel2] + else: + outconfs1 = pred1['desc_conf'][sel1] + outconfs2 = pred2['desc_conf'][sel2] + else: + outconfs1 = outconfs2 = None + return outconfs1, outconfs2 + + def get_descs(self, pred1, pred2): + if self.use_pts3d: + desc1, desc2 = pred1['pts3d'], pred2['pts3d_in_other_view'] + else: + desc1, desc2 = pred1['desc'], pred2['desc'] + return desc1, desc2 + + def get_matching_descs(self, gt1, gt2, pred1, pred2, **kw): + outdesc1 = outdesc2 = outconfs1 = outconfs2 = None + # Recover descs, GT corres and valid mask + desc1, desc2 = self.get_descs(pred1, pred2) + + (x1, y1), (x2, y2) = gt1['corres'].unbind(-1), gt2['corres'].unbind(-1) + valid_matches = gt1['valid_corres'] + + # Select descs that have GT matches + B, N = x1.shape + batchid = torch.arange(B)[:, None].repeat(1, N) # B, N + outdesc1, outdesc2 = desc1[batchid, y1, x1], desc2[batchid, y2, x2] # B, N, D + + # Padd with unused negatives + outdesc2 = self.add_negatives(outdesc2, desc2, batchid, x2, y2) + + # Gather confs if needed + sel1 = batchid, y1, x1 + sel2 = batchid, y2, x2 + outconfs1, outconfs2 = self.get_confs(pred1, pred2, sel1, sel2) + + return outdesc1, outdesc2, outconfs1, outconfs2, valid_matches, {'use_euclidean_dist': self.use_pts3d} + + def blockwise_criterion(self, descs1, descs2, confs1, confs2, valid_matches, euc, rng=np.random, shuffle=True): + loss = None + details = {} + B, N, D = descs1.shape + + if N <= self.blocksize: # Blocks are larger than provided descs, compute regular loss + loss = self.criterion(descs1, descs2, valid_matches, euc=euc) + else: # Compute criterion on the blockdiagonal only, after shuffling + # Shuffle if necessary + matches_perm = slice(None) + if shuffle: + matches_perm = np.stack([rng.choice(range(N), size=N, replace=False) for _ in range(B)]) + batchid = torch.tile(torch.arange(B), (N, 1)).T + matches_perm = batchid, matches_perm + + descs1 = descs1[matches_perm] + descs2 = descs2[matches_perm] + valid_matches = valid_matches[matches_perm] + + assert N % self.blocksize == 0, "Error, can't chunk block-diagonal, please check blocksize" + n_chunks = N // self.blocksize + descs1 = descs1.reshape([B * n_chunks, self.blocksize, D]) # [B*(N//blocksize), blocksize, D] + descs2 = descs2.reshape([B * n_chunks, self.blocksize, D]) # [B*(N//blocksize), blocksize, D] + valid_matches = valid_matches.view([B * n_chunks, self.blocksize]) + loss = 
self.criterion(descs1, descs2, valid_matches, euc=euc) + if self.withconf: + confs1, confs2 = map(lambda x: x[matches_perm], (confs1, confs2)) # apply perm to confidences if needed + + if self.withconf: + # split confidences between positives/negatives for loss computation + details['conf_pos'] = map(lambda x: x[valid_matches.view(B, -1)], (confs1, confs2)) + details['conf_neg'] = map(lambda x: x[~valid_matches.view(B, -1)], (confs1, confs2)) + details['Conf1_std'] = confs1.std() + details['Conf2_std'] = confs2.std() + + return loss, details + + def compute_loss(self, gt1, gt2, pred1, pred2, **kw): + # Gather preds and GT + descs1, descs2, confs1, confs2, valid_matches, monitoring = self.get_matching_descs( + gt1, gt2, pred1, pred2, **kw) + + # loss on matches + loss, details = self.blockwise_criterion(descs1, descs2, confs1, confs2, + valid_matches, euc=monitoring.pop('use_euclidean_dist', False)) + + details[type(self).__name__] = float(loss.mean()) + return loss, (details | monitoring) + + +class ConfMatchingLoss(ConfLoss): + """ Weight matching by learned confidence. Same as ConfLoss but for a matching criterion + Assuming the input matching_loss is a match-level loss. + """ + + def __init__(self, pixel_loss, alpha=1., confmode='prod', neg_conf_loss_quantile=False): + super().__init__(pixel_loss, alpha) + self.pixel_loss.withconf = True + self.confmode = confmode + self.neg_conf_loss_quantile = neg_conf_loss_quantile + + def aggregate_confs(self, confs1, confs2): # get the confidences resulting from the two view predictions + if self.confmode == 'prod': + confs = confs1 * confs2 if confs1 is not None and confs2 is not None else 1. + elif self.confmode == 'mean': + confs = .5 * (confs1 + confs2) if confs1 is not None and confs2 is not None else 1. + else: + raise ValueError(f"Unknown conf mode {self.confmode}") + return confs + + def compute_loss(self, gt1, gt2, pred1, pred2, **kw): + # compute per-pixel loss + loss, details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw) + # Recover confidences for positive and negative samples + conf1_pos, conf2_pos = details.pop('conf_pos') + conf1_neg, conf2_neg = details.pop('conf_neg') + conf_pos = self.aggregate_confs(conf1_pos, conf2_pos) + + # weight Matching loss by confidence on positives + conf_pos, log_conf_pos = self.get_conf_log(conf_pos) + conf_loss = loss * conf_pos - self.alpha * log_conf_pos + # average + nan protection (in case of no valid pixels at all) + conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0 + # Add negative confs loss to give some supervision signal to confidences for pixels that are not matched in GT + if self.neg_conf_loss_quantile: + conf_neg = torch.cat([conf1_neg, conf2_neg]) + conf_neg, log_conf_neg = self.get_conf_log(conf_neg) + + # recover quantile that will be used for negatives loss value assignment + neg_loss_value = torch.quantile(loss, self.neg_conf_loss_quantile).detach() + neg_loss = neg_loss_value * conf_neg - self.alpha * log_conf_neg + + neg_loss = neg_loss.mean() if neg_loss.numel() > 0 else 0 + conf_loss = conf_loss + neg_loss + + return conf_loss, dict(matching_conf_loss=float(conf_loss), **details) diff --git a/imcui/third_party/mast3r/mast3r/model.py b/imcui/third_party/mast3r/mast3r/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f328c5e43b8e98f2ec960e4d25e6f235ac543544 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/model.py @@ -0,0 +1,68 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. 
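# --------------------------------------------------------
# Usage sketch for the matching criteria above (illustrative only): random unit
# descriptors stand in for network outputs, and the temperature / mode /
# reduction values simply echo options visible in this file, not a training
# recipe. Assumes the mast3r package (and its dust3r dependency) is importable.
# In training, these criteria are wrapped by MatchingLoss / ConfMatchingLoss.
# --------------------------------------------------------
import torch
import torch.nn.functional as F
from mast3r.losses import InfoNCE, APLoss

B, N, D = 2, 128, 24
desc1 = F.normalize(torch.randn(B, N, D), dim=-1)
desc2 = F.normalize(desc1 + 0.05 * torch.randn(B, N, D), dim=-1)  # noisy copies
valid = torch.ones(B, N, dtype=torch.bool)   # ground-truth matches = diagonal

nce = InfoNCE(temperature=0.07, mode='proper', reduction='mean')
loss_nce = nce(desc1, desc2, valid)          # scalar, small when rows agree

ap = APLoss(nq='torch', reduction='1-mean')  # exact AP (not differentiable)
loss_ap = ap(desc1, desc2, valid)            # 1 - mean average precision
# --------------------------------------------------------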
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# MASt3R model class +# -------------------------------------------------------- +import torch +import torch.nn.functional as F +import os + +from mast3r.catmlp_dpt_head import mast3r_head_factory + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.model import AsymmetricCroCo3DStereo # noqa +from dust3r.utils.misc import transpose_to_landscape # noqa + + +inf = float('inf') + + +def load_model(model_path, device, verbose=True): + if verbose: + print('... loading model from', model_path) + ckpt = torch.load(model_path, map_location='cpu') + args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") + if 'landscape_only' not in args: + args = args[:-1] + ', landscape_only=False)' + else: + args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False') + assert "landscape_only=False" in args + if verbose: + print(f"instantiating : {args}") + net = eval(args) + s = net.load_state_dict(ckpt['model'], strict=False) + if verbose: + print(s) + return net.to(device) + + +class AsymmetricMASt3R(AsymmetricCroCo3DStereo): + def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs): + self.desc_mode = desc_mode + self.two_confs = two_confs + self.desc_conf_mode = desc_conf_mode + super().__init__(**kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kw): + if os.path.isfile(pretrained_model_name_or_path): + return load_model(pretrained_model_name_or_path, device='cpu') + else: + return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw) + + def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw): + assert img_size[0] % patch_size == 0 and img_size[ + 1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}' + self.output_mode = output_mode + self.head_type = head_type + self.depth_mode = depth_mode + self.conf_mode = conf_mode + if self.desc_conf_mode is None: + self.desc_conf_mode = conf_mode + # allocate heads + self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) + self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode)) + # magic wrapper + self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only) + self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only) diff --git a/imcui/third_party/mast3r/mast3r/utils/__init__.py b/imcui/third_party/mast3r/mast3r/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7dd877d649ce4dbd749dd7195a8b34c0f91d4f0 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). \ No newline at end of file diff --git a/imcui/third_party/mast3r/mast3r/utils/coarse_to_fine.py b/imcui/third_party/mast3r/mast3r/utils/coarse_to_fine.py new file mode 100644 index 0000000000000000000000000000000000000000..c062e8608f82c608f2d605d69a95a7e0f301b3cf --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/utils/coarse_to_fine.py @@ -0,0 +1,214 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
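# --------------------------------------------------------
# Usage sketch for AsymmetricMASt3R above (illustrative only): from_pretrained
# accepts either a local checkpoint file (routed through load_model) or a hub
# identifier; the name below is the one offered by visloc.py further down.
# Assumes the mast3r package added by this diff is importable.
# --------------------------------------------------------
import torch
from mast3r.model import AsymmetricMASt3R

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AsymmetricMASt3R.from_pretrained(
    "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to(device)
model.eval()   # inference-only use, e.g. via dust3r's inference helpers
# --------------------------------------------------------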
+# +# -------------------------------------------------------- +# coarse to fine utilities +# -------------------------------------------------------- +import numpy as np + + +def crop_tag(cell): + return f'[{cell[1]}:{cell[3]},{cell[0]}:{cell[2]}]' + + +def crop_slice(cell): + return slice(cell[1], cell[3]), slice(cell[0], cell[2]) + + +def _start_pos(total_size, win_size, overlap): + # we must have AT LEAST overlap between segments + # first segment starts at 0, last segment starts at total_size-win_size + assert 0 <= overlap < 1 + assert total_size >= win_size + spacing = win_size * (1 - overlap) + last_pt = total_size - win_size + n_windows = 2 + int((last_pt - 1) // spacing) + return np.linspace(0, last_pt, n_windows).round().astype(int) + + +def multiple_of_16(x): + return (x // 16) * 16 + + +def _make_overlapping_grid(H, W, size, overlap): + H_win = multiple_of_16(H * size // max(H, W)) + W_win = multiple_of_16(W * size // max(H, W)) + x = _start_pos(W, W_win, overlap) + y = _start_pos(H, H_win, overlap) + grid = np.stack(np.meshgrid(x, y, indexing='xy'), axis=-1) + grid = np.concatenate((grid, grid + (W_win, H_win)), axis=-1) + return grid.reshape(-1, 4) + + +def _cell_size(cell2): + width, height = cell2[:, 2] - cell2[:, 0], cell2[:, 3] - cell2[:, 1] + assert width.min() >= 0 + assert height.min() >= 0 + return width, height + + +def _norm_windows(cell2, H2, W2, forced_resolution=None): + # make sure the window aspect ratio is 3/4, or the output resolution is forced_resolution if defined + outcell = cell2.copy() + width, height = _cell_size(cell2) + width2, height2 = width.clip(max=W2), height.clip(max=H2) + if forced_resolution is None: + width2[width < height] = (height2[width < height] * 3.01 / 4).clip(max=W2) + height2[width >= height] = (width2[width >= height] * 3.01 / 4).clip(max=H2) + else: + forced_H, forced_W = forced_resolution + width2[:] = forced_W + height2[:] = forced_H + + half = (width2 - width) / 2 + outcell[:, 0] -= half + outcell[:, 2] += half + half = (height2 - height) / 2 + outcell[:, 1] -= half + outcell[:, 3] += half + + # proj to integers + outcell = np.floor(outcell).astype(int) + # Take care of flooring errors + tmpw, tmph = _cell_size(outcell) + outcell[:, 0] += tmpw.astype(tmpw.dtype) - width2.astype(tmpw.dtype) + outcell[:, 1] += tmph.astype(tmpw.dtype) - height2.astype(tmpw.dtype) + + # make sure 0 <= x < W2 and 0 <= y < H2 + outcell[:, 0::2] -= outcell[:, [0]].clip(max=0) + outcell[:, 1::2] -= outcell[:, [1]].clip(max=0) + outcell[:, 0::2] -= outcell[:, [2]].clip(min=W2) - W2 + outcell[:, 1::2] -= outcell[:, [3]].clip(min=H2) - H2 + + width, height = _cell_size(outcell) + assert np.all(width == width2.astype(width.dtype)) and np.all( + height == height2.astype(height.dtype)), "Error, output is not of the expected shape." 
+ assert np.all(width <= W2) + assert np.all(height <= H2) + return outcell + + +def _weight_pixels(cell, pix, assigned, gauss_var=2): + center = cell.reshape(-1, 2, 2).mean(axis=1) + width, height = _cell_size(cell) + + # square distance between each cell center and each point + dist = (center[:, None] - pix[None]) / np.c_[width, height][:, None] + dist2 = np.square(dist).sum(axis=-1) + + assert assigned.shape == dist2.shape + res = np.where(assigned, np.exp(-gauss_var * dist2), 0) + return res + + +def pos2d_in_rect(p1, cell1): + x, y = p1.T + l, t, r, b = cell1 + assigned = (l <= x) & (x < r) & (t <= y) & (y < b) + return assigned + + +def _score_cell(cell1, H2, W2, p1, p2, min_corres=10, forced_resolution=None): + assert p1.shape == p2.shape + + # compute keypoint assignment + assigned = pos2d_in_rect(p1, cell1[None].T) + assert assigned.shape == (len(cell1), len(p1)) + + # remove cells without correspondences + valid_cells = assigned.sum(axis=1) >= min_corres + cell1 = cell1[valid_cells] + assigned = assigned[valid_cells] + if not valid_cells.any(): + return cell1, cell1, assigned + + # fill-in the assigned points in both image + assigned_p1 = np.empty((len(cell1), len(p1), 2), dtype=np.float32) + assigned_p2 = np.empty((len(cell1), len(p2), 2), dtype=np.float32) + assigned_p1[:] = p1[None] + assigned_p2[:] = p2[None] + assigned_p1[~assigned] = np.nan + assigned_p2[~assigned] = np.nan + + # find the median center and scale of assigned points in each cell + # cell_center1 = np.nanmean(assigned_p1, axis=1) + cell_center2 = np.nanmean(assigned_p2, axis=1) + im1_q25, im1_q75 = np.nanquantile(assigned_p1, (0.1, 0.9), axis=1) + im2_q25, im2_q75 = np.nanquantile(assigned_p2, (0.1, 0.9), axis=1) + + robust_std1 = (im1_q75 - im1_q25).clip(20.) + robust_std2 = (im2_q75 - im2_q25).clip(20.) + + cell_size1 = (cell1[:, 2:4] - cell1[:, 0:2]) + cell_size2 = cell_size1 * robust_std2 / robust_std1 + cell2 = np.c_[cell_center2 - cell_size2 / 2, cell_center2 + cell_size2 / 2] + + # make sure cell bounds are valid + cell2 = _norm_windows(cell2, H2, W2, forced_resolution=forced_resolution) + + # compute correspondence weights + corres_weights = _weight_pixels(cell1, p1, assigned) * _weight_pixels(cell2, p2, assigned) + + # return a list of window pairs and assigned correspondences + return cell1, cell2, corres_weights + + +def greedy_selection(corres_weights, target=0.9): + # corres_weight = (n_cell_pair, n_corres) matrix. 
+ # If corres_weight[c,p]>0, means that correspondence p is visible in cell pair p + assert 0 < target <= 1 + corres_weights = corres_weights.copy() + + total = corres_weights.max(axis=0).sum() + target *= total + + # init = empty + res = [] + cur = np.zeros(corres_weights.shape[1]) # current selection + + while cur.sum() < target: + # pick the nex best cell pair + best = corres_weights.sum(axis=1).argmax() + res.append(best) + + # update current + cur += corres_weights[best] + # print('appending', best, 'with score', corres_weights[best].sum(), '-->', cur.sum()) + + # remove from all other views + corres_weights = (corres_weights - corres_weights[best]).clip(min=0) + + return res + + +def select_pairs_of_crops(img_q, img_b, pos2d_in_query, pos2d_in_ref, maxdim=512, overlap=.5, forced_resolution=None): + # prepare the overlapping cells + grid_q = _make_overlapping_grid(*img_q.shape[:2], maxdim, overlap) + grid_b = _make_overlapping_grid(*img_b.shape[:2], maxdim, overlap) + + assert forced_resolution is None or len(forced_resolution) == 2 + if isinstance(forced_resolution[0], int) or not len(forced_resolution[0]) == 2: + forced_resolution1 = forced_resolution2 = forced_resolution + else: + assert len(forced_resolution[1]) == 2 + forced_resolution1 = forced_resolution[0] + forced_resolution2 = forced_resolution[1] + + # Make sure crops respect constraints + grid_q = _norm_windows(grid_q.astype(float), *img_q.shape[:2], forced_resolution=forced_resolution1) + grid_b = _norm_windows(grid_b.astype(float), *img_b.shape[:2], forced_resolution=forced_resolution2) + + # score cells + pairs_q = _score_cell(grid_q, *img_b.shape[:2], pos2d_in_query, pos2d_in_ref, forced_resolution=forced_resolution2) + pairs_b = _score_cell(grid_b, *img_q.shape[:2], pos2d_in_ref, pos2d_in_query, forced_resolution=forced_resolution1) + pairs_b = pairs_b[1], pairs_b[0], pairs_b[2] # cellq, cellb, corres_weights + + # greedy selection until all correspondences are generated + cell1, cell2, corres_weights = map(np.concatenate, zip(pairs_q, pairs_b)) + if len(corres_weights) == 0: + return # tolerated for empty generators + order = greedy_selection(corres_weights, target=0.9) + + for i in order: + def pair_tag(qi, bi): return (str(qi) + crop_tag(cell1[i]), str(bi) + crop_tag(cell2[i])) + yield cell1[i], cell2[i], pair_tag diff --git a/imcui/third_party/mast3r/mast3r/utils/collate.py b/imcui/third_party/mast3r/mast3r/utils/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..72ee3a437b87ef7049dcd03b93e594a8325b780c --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/utils/collate.py @@ -0,0 +1,62 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
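# --------------------------------------------------------
# Usage sketch for select_pairs_of_crops above (illustrative only): dummy
# images and random coarse matches stand in for real data; sizes and the
# forced_resolution value are made up. Note that, as written, the generator
# expects an explicit forced_resolution (visloc.py below passes
# [map_resolution, query_resolution]); a single (H, W) is applied to both.
# --------------------------------------------------------
import numpy as np
from mast3r.utils.coarse_to_fine import select_pairs_of_crops

H, W = 768, 1024
img_q = np.zeros((H, W, 3), dtype=np.uint8)   # query image (content unused here)
img_b = np.zeros((H, W, 3), dtype=np.uint8)   # reference image

rng = np.random.default_rng(0)                # 500 fake coarse correspondences
xy_q = np.c_[rng.uniform(0, W, 500), rng.uniform(0, H, 500)]
xy_b = np.c_[rng.uniform(0, W, 500), rng.uniform(0, H, 500)]

for crop_q, crop_b, _tag in select_pairs_of_crops(img_q, img_b, xy_q, xy_b,
                                                  maxdim=512, overlap=0.5,
                                                  forced_resolution=(384, 512)):
    print(crop_q, crop_b)   # (l, t, r, b) crop windows covering the matches
# --------------------------------------------------------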
+# +# -------------------------------------------------------- +# Collate extensions +# -------------------------------------------------------- + +import torch +import collections +from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format +from typing import Callable, Dict, Optional, Tuple, Type, Union, List + + +def cat_collate_tensor_fn(batch, *, collate_fn_map): + return torch.cat(batch, dim=0) + + +def cat_collate_list_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): + return [item for bb in batch for item in bb] # concatenate all lists + + +cat_collate_fn_map = default_collate_fn_map.copy() +cat_collate_fn_map[torch.Tensor] = cat_collate_tensor_fn +cat_collate_fn_map[List] = cat_collate_list_fn +cat_collate_fn_map[type(None)] = lambda _, **kw: None # When some Nones, simply return a single None + + +def cat_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): + r"""Custom collate function that concatenates stuff instead of stacking them, and handles NoneTypes """ + elem = batch[0] + elem_type = type(elem) + + if collate_fn_map is not None: + if elem_type in collate_fn_map: + return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map) + + for collate_type in collate_fn_map: + if isinstance(elem, collate_type): + return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map) + + if isinstance(elem, collections.abc.Mapping): + try: + return elem_type({key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}) + except TypeError: + # The mapping type may not support `__init__(iterable)`. + return {key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch))) + elif isinstance(elem, collections.abc.Sequence): + transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. + + if isinstance(elem, tuple): + # Backwards compatibility. + return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] + else: + try: + return elem_type([cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]) + except TypeError: + # The sequence type may not support `__init__(iterable)` (e.g., `range`). + return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) diff --git a/imcui/third_party/mast3r/mast3r/utils/misc.py b/imcui/third_party/mast3r/mast3r/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5403c67116f5156e47537df8fbcfcacb7bb474 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/utils/misc.py @@ -0,0 +1,17 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
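# --------------------------------------------------------
# Usage sketch for cat_collate above (illustrative only, toy tensors): two
# already-batched chunks are merged by concatenation along dim 0 instead of
# stacking, which is how crops_inference in visloc.py below merges per-crop
# prediction chunks.
# --------------------------------------------------------
import torch
from mast3r.utils.collate import cat_collate, cat_collate_fn_map

chunk0 = dict(img=torch.zeros(2, 3, 8, 8), instance=['a', 'b'], extra=None)
chunk1 = dict(img=torch.ones(3, 3, 8, 8), instance=['c', 'd', 'e'], extra=None)

merged = cat_collate([chunk0, chunk1], collate_fn_map=cat_collate_fn_map)
# merged['img'].shape == (5, 3, 8, 8); merged['instance'] has 5 items;
# merged['extra'] stays None thanks to the NoneType entry in the fn map
# --------------------------------------------------------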
+# +# -------------------------------------------------------- +# utilitary functions for MASt3R +# -------------------------------------------------------- +import os +import hashlib + + +def mkdir_for(f): + os.makedirs(os.path.dirname(f), exist_ok=True) + return f + + +def hash_md5(s): + return hashlib.md5(s.encode('utf-8')).hexdigest() diff --git a/imcui/third_party/mast3r/mast3r/utils/path_to_dust3r.py b/imcui/third_party/mast3r/mast3r/utils/path_to_dust3r.py new file mode 100644 index 0000000000000000000000000000000000000000..ebfd78cb432ef45ee30bb893f7f90fff08474b93 --- /dev/null +++ b/imcui/third_party/mast3r/mast3r/utils/path_to_dust3r.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# dust3r submodule import +# -------------------------------------------------------- + +import sys +import os.path as path +HERE_PATH = path.normpath(path.dirname(__file__)) +DUSt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../dust3r')) +DUSt3R_LIB_PATH = path.join(DUSt3R_REPO_PATH, 'dust3r') +# check the presence of models directory in repo to be sure its cloned +if path.isdir(DUSt3R_LIB_PATH): + # workaround for sibling import + sys.path.insert(0, DUSt3R_REPO_PATH) +else: + raise ImportError(f"dust3r is not initialized, could not find: {DUSt3R_LIB_PATH}.\n " + "Did you forget to run 'git submodule update --init --recursive' ?") diff --git a/imcui/third_party/mast3r/train.py b/imcui/third_party/mast3r/train.py new file mode 100644 index 0000000000000000000000000000000000000000..57a7689c408f63701cf4287ada50802f1a5cef6a --- /dev/null +++ b/imcui/third_party/mast3r/train.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
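# --------------------------------------------------------
# Usage sketch for mast3r/utils/misc.py above (illustrative only; the path and
# the string are made up): mkdir_for creates the parent directory and returns
# the path unchanged (visloc.py below calls it before np.savez when caching
# matches), and hash_md5 returns the hex MD5 digest of a string.
# --------------------------------------------------------
from mast3r.utils.misc import mkdir_for, hash_md5

cache_file = mkdir_for('outputs/matches/query_000/map_012.npz')  # dirs created
cache_key = hash_md5('query_000.jpg/map_012.jpg')                # 32-char hex digest
# --------------------------------------------------------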
+# +# -------------------------------------------------------- +# training executable for MASt3R +# -------------------------------------------------------- +from mast3r.model import AsymmetricMASt3R +from mast3r.losses import ConfMatchingLoss, MatchingLoss, APLoss, Regr3D, InfoNCE, Regr3D_ScaleShiftInv +from mast3r.datasets import ARKitScenes, BlendedMVS, Co3d, MegaDepth, ScanNetpp, StaticThings3D, Waymo, WildRGBD + +import mast3r.utils.path_to_dust3r # noqa +# add mast3r classes to dust3r imports +import dust3r.training +dust3r.training.AsymmetricMASt3R = AsymmetricMASt3R +dust3r.training.Regr3D = Regr3D +dust3r.training.Regr3D_ScaleShiftInv = Regr3D_ScaleShiftInv +dust3r.training.MatchingLoss = MatchingLoss +dust3r.training.ConfMatchingLoss = ConfMatchingLoss +dust3r.training.InfoNCE = InfoNCE +dust3r.training.APLoss = APLoss + +import dust3r.datasets +dust3r.datasets.ARKitScenes = ARKitScenes +dust3r.datasets.BlendedMVS = BlendedMVS +dust3r.datasets.Co3d = Co3d +dust3r.datasets.MegaDepth = MegaDepth +dust3r.datasets.ScanNetpp = ScanNetpp +dust3r.datasets.StaticThings3D = StaticThings3D +dust3r.datasets.Waymo = Waymo +dust3r.datasets.WildRGBD = WildRGBD + +from dust3r.training import get_args_parser as dust3r_get_args_parser # noqa +from dust3r.training import train # noqa + + +def get_args_parser(): + parser = dust3r_get_args_parser() + # change defaults + parser.prog = 'MASt3R training' + parser.set_defaults(model="AsymmetricMASt3R(patch_embed_cls='ManyAR_PatchEmbed')") + return parser + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + train(args) diff --git a/imcui/third_party/mast3r/visloc.py b/imcui/third_party/mast3r/visloc.py new file mode 100644 index 0000000000000000000000000000000000000000..8e460169068d3d82308538aef6c14c756e9848ce --- /dev/null +++ b/imcui/third_party/mast3r/visloc.py @@ -0,0 +1,538 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# visloc script with support for coarse to fine +# -------------------------------------------------------- +import os +import numpy as np +import random +import torch +import torchvision.transforms as tvf +import argparse +from tqdm import tqdm +from PIL import Image +import math + +from mast3r.model import AsymmetricMASt3R +from mast3r.fast_nn import fast_reciprocal_NNs +from mast3r.utils.coarse_to_fine import select_pairs_of_crops, crop_slice +from mast3r.utils.collate import cat_collate, cat_collate_fn_map +from mast3r.utils.misc import mkdir_for +from mast3r.datasets.utils.cropping import crop_to_homography + +import mast3r.utils.path_to_dust3r # noqa +from dust3r.inference import inference, loss_of_one_batch +from dust3r.utils.geometry import geotrf, colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r_visloc.datasets import * +from dust3r_visloc.localization import run_pnp +from dust3r_visloc.evaluation import get_pose_error, aggregate_stats, export_results +from dust3r_visloc.datasets.utils import get_HW_resolution, rescale_points3d + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True, help="visloc dataset to eval") + parser_weights = parser.add_mutually_exclusive_group(required=True) + parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None) + parser_weights.add_argument("--model_name", type=str, help="name of the model weights", + choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]) + + parser.add_argument("--confidence_threshold", type=float, default=1.001, + help="confidence values higher than threshold are invalid") + parser.add_argument('--pixel_tol', default=5, type=int) + + parser.add_argument("--coarse_to_fine", action='store_true', default=False, + help="do the matching from coarse to fine") + parser.add_argument("--max_image_size", type=int, default=None, + help="max image size for the fine resolution") + parser.add_argument("--c2f_crop_with_homography", action='store_true', default=False, + help="when using coarse to fine, crop with homographies to keep cx, cy centered") + + parser.add_argument("--device", type=str, default='cuda', help="pytorch device") + parser.add_argument("--pnp_mode", type=str, default="cv2", choices=['cv2', 'poselib', 'pycolmap'], + help="pnp lib to use") + parser_reproj = parser.add_mutually_exclusive_group() + parser_reproj.add_argument("--reprojection_error", type=float, default=5.0, help="pnp reprojection error") + parser_reproj.add_argument("--reprojection_error_diag_ratio", type=float, default=None, + help="pnp reprojection error as a ratio of the diagonal of the image") + + parser.add_argument("--max_batch_size", type=int, default=48, + help="max batch size for inference on crops when using coarse to fine") + parser.add_argument("--pnp_max_points", type=int, default=100_000, help="pnp maximum number of points kept") + parser.add_argument("--viz_matches", type=int, default=0, help="debug matches") + + parser.add_argument("--output_dir", type=str, default=None, help="output path") + parser.add_argument("--output_label", type=str, default='', help="prefix for results files") + return parser + + +@torch.no_grad() +def coarse_matching(query_view, map_view, model, device, pixel_tol, fast_nn_params): + # prepare batch + imgs = [] + for idx, img in enumerate([query_view['rgb_rescaled'], 
map_view['rgb_rescaled']]): + imgs.append(dict(img=img.unsqueeze(0), true_shape=np.int32([img.shape[1:]]), + idx=idx, instance=str(idx))) + output = inference([tuple(imgs)], model, device, batch_size=1, verbose=False) + pred1, pred2 = output['pred1'], output['pred2'] + conf_list = [pred1['desc_conf'].squeeze(0).cpu().numpy(), pred2['desc_conf'].squeeze(0).cpu().numpy()] + desc_list = [pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()] + + # find 2D-2D matches between the two images + PQ, PM = desc_list[0], desc_list[1] + if len(PQ) == 0 or len(PM) == 0: + return [], [], [], [] + + if pixel_tol == 0: + matches_im_map, matches_im_query = fast_reciprocal_NNs(PM, PQ, subsample_or_initxy1=8, **fast_nn_params) + HM, WM = map_view['rgb_rescaled'].shape[1:] + HQ, WQ = query_view['rgb_rescaled'].shape[1:] + # ignore small border around the edge + valid_matches_map = (matches_im_map[:, 0] >= 3) & (matches_im_map[:, 0] < WM - 3) & ( + matches_im_map[:, 1] >= 3) & (matches_im_map[:, 1] < HM - 3) + valid_matches_query = (matches_im_query[:, 0] >= 3) & (matches_im_query[:, 0] < WQ - 3) & ( + matches_im_query[:, 1] >= 3) & (matches_im_query[:, 1] < HQ - 3) + valid_matches = valid_matches_map & valid_matches_query + matches_im_map = matches_im_map[valid_matches] + matches_im_query = matches_im_query[valid_matches] + valid_pts3d = [] + matches_confs = [] + else: + yM, xM = torch.where(map_view['valid_rescaled']) + matches_im_map, matches_im_query = fast_reciprocal_NNs(PM, PQ, (xM, yM), pixel_tol=pixel_tol, **fast_nn_params) + valid_pts3d = map_view['pts3d_rescaled'].cpu().numpy()[matches_im_map[:, 1], matches_im_map[:, 0]] + matches_confs = np.minimum( + conf_list[1][matches_im_map[:, 1], matches_im_map[:, 0]], + conf_list[0][matches_im_query[:, 1], matches_im_query[:, 0]] + ) + # from cv2 to colmap + matches_im_query = matches_im_query.astype(np.float64) + matches_im_map = matches_im_map.astype(np.float64) + matches_im_query[:, 0] += 0.5 + matches_im_query[:, 1] += 0.5 + matches_im_map[:, 0] += 0.5 + matches_im_map[:, 1] += 0.5 + # rescale coordinates + matches_im_query = geotrf(query_view['to_orig'], matches_im_query, norm=True) + matches_im_map = geotrf(map_view['to_orig'], matches_im_map, norm=True) + # from colmap back to cv2 + matches_im_query[:, 0] -= 0.5 + matches_im_query[:, 1] -= 0.5 + matches_im_map[:, 0] -= 0.5 + matches_im_map[:, 1] -= 0.5 + return valid_pts3d, matches_im_query, matches_im_map, matches_confs + + +@torch.no_grad() +def crops_inference(pairs, model, device, batch_size=48, verbose=True): + assert len(pairs) == 2, "Error, data should be a tuple of dicts containing the batch of image pairs" + # Forward a possibly big bunch of data, by blocks of batch_size + B = pairs[0]['img'].shape[0] + if B < batch_size: + return loss_of_one_batch(pairs, model, None, device=device, symmetrize_batch=False) + preds = [] + for ii in range(0, B, batch_size): + sel = slice(ii, ii + min(B - ii, batch_size)) + temp_data = [{}, {}] + for di in [0, 1]: + temp_data[di] = {kk: pairs[di][kk][sel] + for kk in pairs[di].keys() if pairs[di][kk] is not None} # copy chunk for forward + preds.append(loss_of_one_batch(temp_data, model, + None, device=device, symmetrize_batch=False)) # sequential forward + # Merge all preds + return cat_collate(preds, collate_fn_map=cat_collate_fn_map) + + +@torch.no_grad() +def fine_matching(query_views, map_views, model, device, max_batch_size, pixel_tol, fast_nn_params): + assert pixel_tol > 0 + output = crops_inference([query_views, map_views], + model, 
device, batch_size=max_batch_size, verbose=False) + pred1, pred2 = output['pred1'], output['pred2'] + descs1 = pred1['desc'].clone() + descs2 = pred2['desc'].clone() + confs1 = pred1['desc_conf'].clone() + confs2 = pred2['desc_conf'].clone() + + # Compute matches + valid_pts3d, matches_im_map, matches_im_query, matches_confs = [], [], [], [] + for ppi, (pp1, pp2, cc11, cc21) in enumerate(zip(descs1, descs2, confs1, confs2)): + valid_ppi = map_views['valid'][ppi] + pts3d_ppi = map_views['pts3d'][ppi].cpu().numpy() + conf_list_ppi = [cc11.cpu().numpy(), cc21.cpu().numpy()] + + y_ppi, x_ppi = torch.where(valid_ppi) + matches_im_map_ppi, matches_im_query_ppi = fast_reciprocal_NNs(pp2, pp1, (x_ppi, y_ppi), + pixel_tol=pixel_tol, **fast_nn_params) + + valid_pts3d_ppi = pts3d_ppi[matches_im_map_ppi[:, 1], matches_im_map_ppi[:, 0]] + matches_confs_ppi = np.minimum( + conf_list_ppi[1][matches_im_map_ppi[:, 1], matches_im_map_ppi[:, 0]], + conf_list_ppi[0][matches_im_query_ppi[:, 1], matches_im_query_ppi[:, 0]] + ) + # inverse operation where we uncrop pixel coordinates + matches_im_map_ppi = geotrf(map_views['to_orig'][ppi].cpu().numpy(), matches_im_map_ppi.copy(), norm=True) + matches_im_query_ppi = geotrf(query_views['to_orig'][ppi].cpu().numpy(), matches_im_query_ppi.copy(), norm=True) + + matches_im_map.append(matches_im_map_ppi) + matches_im_query.append(matches_im_query_ppi) + valid_pts3d.append(valid_pts3d_ppi) + matches_confs.append(matches_confs_ppi) + + if len(valid_pts3d) == 0: + return [], [], [], [] + + matches_im_map = np.concatenate(matches_im_map, axis=0) + matches_im_query = np.concatenate(matches_im_query, axis=0) + valid_pts3d = np.concatenate(valid_pts3d, axis=0) + matches_confs = np.concatenate(matches_confs, axis=0) + return valid_pts3d, matches_im_query, matches_im_map, matches_confs + + +def crop(img, mask, pts3d, crop, intrinsics=None): + out_cropped_img = img.clone() + if mask is not None: + out_cropped_mask = mask.clone() + else: + out_cropped_mask = None + if pts3d is not None: + out_cropped_pts3d = pts3d.clone() + else: + out_cropped_pts3d = None + to_orig = torch.eye(3, device=img.device) + + # If intrinsics available, crop and apply rectifying homography. Otherwise, just crop + if intrinsics is not None: + K_old = intrinsics + imsize, K_new, R, H = crop_to_homography(K_old, crop) + # apply homography to image + H /= H[2, 2] + homo8 = H.ravel().tolist()[:8] + # From float tensor to uint8 PIL Image + pilim = Image.fromarray((255 * (img + 1.) / 2).to(torch.uint8).numpy()) + pilout_cropped_img = pilim.transform(imsize, Image.Transform.PERSPECTIVE, + homo8, resample=Image.Resampling.BICUBIC) + + # From uint8 PIL Image to float tensor + out_cropped_img = 2. * torch.tensor(np.array(pilout_cropped_img)).to(img) / 255. - 1. 
+ if out_cropped_mask is not None: + pilmask = Image.fromarray((255 * out_cropped_mask).to(torch.uint8).numpy()) + pilout_cropped_mask = pilmask.transform( + imsize, Image.Transform.PERSPECTIVE, homo8, resample=Image.Resampling.NEAREST) + out_cropped_mask = torch.from_numpy(np.array(pilout_cropped_mask) > 0).to(out_cropped_mask.dtype) + if out_cropped_pts3d is not None: + out_cropped_pts3d = out_cropped_pts3d.numpy() + out_cropped_X = np.array(Image.fromarray(out_cropped_pts3d[:, :, 0]).transform(imsize, + Image.Transform.PERSPECTIVE, + homo8, + resample=Image.Resampling.NEAREST)) + out_cropped_Y = np.array(Image.fromarray(out_cropped_pts3d[:, :, 1]).transform(imsize, + Image.Transform.PERSPECTIVE, + homo8, + resample=Image.Resampling.NEAREST)) + out_cropped_Z = np.array(Image.fromarray(out_cropped_pts3d[:, :, 2]).transform(imsize, + Image.Transform.PERSPECTIVE, + homo8, + resample=Image.Resampling.NEAREST)) + + out_cropped_pts3d = torch.from_numpy(np.stack([out_cropped_X, out_cropped_Y, out_cropped_Z], axis=-1)) + + to_orig = torch.tensor(H, device=img.device) + else: + out_cropped_img = img[crop_slice(crop)] + if out_cropped_mask is not None: + out_cropped_mask = out_cropped_mask[crop_slice(crop)] + if out_cropped_pts3d is not None: + out_cropped_pts3d = out_cropped_pts3d[crop_slice(crop)] + to_orig[:2, -1] = torch.tensor(crop[:2]) + + return out_cropped_img, out_cropped_mask, out_cropped_pts3d, to_orig + + +def resize_image_to_max(max_image_size, rgb, K): + W, H = rgb.size + if max_image_size and max(W, H) > max_image_size: + islandscape = (W >= H) + if islandscape: + WMax = max_image_size + HMax = int(H * (WMax / W)) + else: + HMax = max_image_size + WMax = int(W * (HMax / H)) + resize_op = tvf.Compose([ImgNorm, tvf.Resize(size=[HMax, WMax])]) + rgb_tensor = resize_op(rgb).permute(1, 2, 0) + to_orig_max = np.array([[W / WMax, 0, 0], + [0, H / HMax, 0], + [0, 0, 1]]) + to_resize_max = np.array([[WMax / W, 0, 0], + [0, HMax / H, 0], + [0, 0, 1]]) + + # Generate new camera parameters + new_K = opencv_to_colmap_intrinsics(K) + new_K[0, :] *= WMax / W + new_K[1, :] *= HMax / H + new_K = colmap_to_opencv_intrinsics(new_K) + else: + rgb_tensor = ImgNorm(rgb).permute(1, 2, 0) + to_orig_max = np.eye(3) + to_resize_max = np.eye(3) + HMax, WMax = H, W + new_K = K + return rgb_tensor, new_K, to_orig_max, to_resize_max, (HMax, WMax) + + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + conf_thr = args.confidence_threshold + device = args.device + pnp_mode = args.pnp_mode + assert args.pixel_tol > 0 + reprojection_error = args.reprojection_error + reprojection_error_diag_ratio = args.reprojection_error_diag_ratio + pnp_max_points = args.pnp_max_points + viz_matches = args.viz_matches + + if args.weights is not None: + weights_path = args.weights + else: + weights_path = "naver/" + args.model_name + model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device) + fast_nn_params = dict(device=device, dist='dot', block_size=2**13) + dataset = eval(args.dataset) + dataset.set_resolution(model) + + query_names = [] + poses_pred = [] + pose_errors = [] + angular_errors = [] + params_str = f'tol_{args.pixel_tol}' + ("_c2f" if args.coarse_to_fine else '') + if args.max_image_size is not None: + params_str = params_str + f'_{args.max_image_size}' + if args.coarse_to_fine and args.c2f_crop_with_homography: + params_str = params_str + '_with_homography' + for idx in tqdm(range(len(dataset))): + views = dataset[(idx)] # 0 is the query + query_view = views[0] + 
map_views = views[1:] + query_names.append(query_view['image_name']) + + query_pts2d = [] + query_pts3d = [] + maxdim = max(model.patch_embed.img_size) + query_rgb_tensor, query_K, query_to_orig_max, query_to_resize_max, (HQ, WQ) = resize_image_to_max( + args.max_image_size, query_view['rgb'], query_view['intrinsics']) + + # pairs of crops have the same resolution + query_resolution = get_HW_resolution(HQ, WQ, maxdim=maxdim, patchsize=model.patch_embed.patch_size) + for map_view in map_views: + if args.output_dir is not None: + cache_file = os.path.join(args.output_dir, 'matches', params_str, + query_view['image_name'], map_view['image_name'] + '.npz') + else: + cache_file = None + + if cache_file is not None and os.path.isfile(cache_file): + matches = np.load(cache_file) + valid_pts3d = matches['valid_pts3d'] + matches_im_query = matches['matches_im_query'] + matches_im_map = matches['matches_im_map'] + matches_conf = matches['matches_conf'] + else: + # coarse matching + if args.coarse_to_fine and (maxdim < max(WQ, HQ)): + # use all points + _, coarse_matches_im0, coarse_matches_im1, _ = coarse_matching(query_view, map_view, model, device, + 0, fast_nn_params) + + # visualize a few matches + if viz_matches > 0: + num_matches = coarse_matches_im1.shape[0] + print(f'found {num_matches} matches') + + viz_imgs = [np.array(query_view['rgb']), np.array(map_view['rgb'])] + from matplotlib import pyplot as pl + n_viz = viz_matches + match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int) + viz_matches_im_query = coarse_matches_im0[match_idx_to_viz] + viz_matches_im_map = coarse_matches_im1[match_idx_to_viz] + + H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2] + img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), + 'constant', constant_values=0) + img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), + 'constant', constant_values=0) + img = np.concatenate((img0, img1), axis=1) + pl.figure() + pl.imshow(img) + cmap = pl.get_cmap('jet') + for i in range(n_viz): + (x0, y0), (x1, y1) = viz_matches_im_query[i].T, viz_matches_im_map[i].T + pl.plot([x0, x1 + W0], [y0, y1], '-+', + color=cmap(i / (n_viz - 1)), scalex=False, scaley=False) + pl.show(block=True) + + valid_all = map_view['valid'] + pts3d = map_view['pts3d'] + + WM_full, HM_full = map_view['rgb'].size + map_rgb_tensor, map_K, map_to_orig_max, map_to_resize_max, (HM, WM) = resize_image_to_max( + args.max_image_size, map_view['rgb'], map_view['intrinsics']) + if WM_full != WM or HM_full != HM: + y_full, x_full = torch.where(valid_all) + pos2d_cv2 = torch.stack([x_full, y_full], dim=-1).cpu().numpy().astype(np.float64) + sparse_pts3d = pts3d[y_full, x_full].cpu().numpy() + _, _, pts3d_max, valid_max = rescale_points3d( + pos2d_cv2, sparse_pts3d, map_to_resize_max, HM, WM) + pts3d = torch.from_numpy(pts3d_max) + valid_all = torch.from_numpy(valid_max) + + coarse_matches_im0 = geotrf(query_to_resize_max, coarse_matches_im0, norm=True) + coarse_matches_im1 = geotrf(map_to_resize_max, coarse_matches_im1, norm=True) + + crops1, crops2 = [], [] + crops_v1, crops_p1 = [], [] + to_orig1, to_orig2 = [], [] + map_resolution = get_HW_resolution(HM, WM, maxdim=maxdim, patchsize=model.patch_embed.patch_size) + + for crop_q, crop_b, pair_tag in select_pairs_of_crops(map_rgb_tensor, + query_rgb_tensor, + coarse_matches_im1, + coarse_matches_im0, + maxdim=maxdim, + overlap=.5, + forced_resolution=[map_resolution, + query_resolution]): + # Per crop processing + if not 
args.c2f_crop_with_homography: + map_K = None + query_K = None + + c1, v1, p1, trf1 = crop(map_rgb_tensor, valid_all, pts3d, crop_q, map_K) + c2, _, _, trf2 = crop(query_rgb_tensor, None, None, crop_b, query_K) + crops1.append(c1) + crops2.append(c2) + crops_v1.append(v1) + crops_p1.append(p1) + to_orig1.append(trf1) + to_orig2.append(trf2) + + if len(crops1) == 0 or len(crops2) == 0: + valid_pts3d, matches_im_query, matches_im_map, matches_conf = [], [], [], [] + else: + crops1, crops2 = torch.stack(crops1), torch.stack(crops2) + if len(crops1.shape) == 3: + crops1, crops2 = crops1[None], crops2[None] + crops_v1 = torch.stack(crops_v1) + crops_p1 = torch.stack(crops_p1) + to_orig1, to_orig2 = torch.stack(to_orig1), torch.stack(to_orig2) + map_crop_view = dict(img=crops1.permute(0, 3, 1, 2), + instance=['1' for _ in range(crops1.shape[0])], + valid=crops_v1, pts3d=crops_p1, + to_orig=to_orig1) + query_crop_view = dict(img=crops2.permute(0, 3, 1, 2), + instance=['2' for _ in range(crops2.shape[0])], + to_orig=to_orig2) + + # Inference and Matching + valid_pts3d, matches_im_query, matches_im_map, matches_conf = fine_matching(query_crop_view, + map_crop_view, + model, device, + args.max_batch_size, + args.pixel_tol, + fast_nn_params) + matches_im_query = geotrf(query_to_orig_max, matches_im_query, norm=True) + matches_im_map = geotrf(map_to_orig_max, matches_im_map, norm=True) + else: + # use only valid 2d points + valid_pts3d, matches_im_query, matches_im_map, matches_conf = coarse_matching(query_view, map_view, + model, device, + args.pixel_tol, + fast_nn_params) + if cache_file is not None: + mkdir_for(cache_file) + np.savez(cache_file, valid_pts3d=valid_pts3d, matches_im_query=matches_im_query, + matches_im_map=matches_im_map, matches_conf=matches_conf) + + # apply conf + if len(matches_conf) > 0: + mask = matches_conf >= conf_thr + valid_pts3d = valid_pts3d[mask] + matches_im_query = matches_im_query[mask] + matches_im_map = matches_im_map[mask] + matches_conf = matches_conf[mask] + + # visualize a few matches + if viz_matches > 0: + num_matches = matches_im_map.shape[0] + print(f'found {num_matches} matches') + + viz_imgs = [np.array(query_view['rgb']), np.array(map_view['rgb'])] + from matplotlib import pyplot as pl + n_viz = viz_matches + match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int) + viz_matches_im_query = matches_im_query[match_idx_to_viz] + viz_matches_im_map = matches_im_map[match_idx_to_viz] + + H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2] + img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) + img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) + img = np.concatenate((img0, img1), axis=1) + pl.figure() + pl.imshow(img) + cmap = pl.get_cmap('jet') + for i in range(n_viz): + (x0, y0), (x1, y1) = viz_matches_im_query[i].T, viz_matches_im_map[i].T + pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False) + pl.show(block=True) + + if len(valid_pts3d) == 0: + pass + else: + query_pts3d.append(valid_pts3d) + query_pts2d.append(matches_im_query) + + if len(query_pts2d) == 0: + success = False + pr_querycam_to_world = None + else: + query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32) + query_pts3d = np.concatenate(query_pts3d, axis=0) + if len(query_pts2d) > pnp_max_points: + idxs = random.sample(range(len(query_pts2d)), pnp_max_points) + query_pts3d = query_pts3d[idxs] + query_pts2d = 
query_pts2d[idxs] + + W, H = query_view['rgb'].size + if reprojection_error_diag_ratio is not None: + reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2) + else: + reprojection_error_img = reprojection_error + success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d, + query_view['intrinsics'], query_view['distortion'], + pnp_mode, reprojection_error_img, img_size=[W, H]) + + if not success: + abs_transl_error = float('inf') + abs_angular_error = float('inf') + else: + abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world']) + + pose_errors.append(abs_transl_error) + angular_errors.append(abs_angular_error) + poses_pred.append(pr_querycam_to_world) + + xp_label = params_str + f'_conf_{conf_thr}' + if args.output_label: + xp_label = args.output_label + "_" + xp_label + if reprojection_error_diag_ratio is not None: + xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}' + else: + xp_label = xp_label + f'_reproj_err_{reprojection_error}' + export_results(args.output_dir, xp_label, query_names, poses_pred) + out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors) + print(out_string) diff --git a/imcui/third_party/mickey/__init__.py b/imcui/third_party/mickey/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/mickey/benchmark/__init__.py b/imcui/third_party/mickey/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/mickey/benchmark/config.py b/imcui/third_party/mickey/benchmark/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f7845f24e6d41ffa0ccb494acc3234d38a3217 --- /dev/null +++ b/imcui/third_party/mickey/benchmark/config.py @@ -0,0 +1,8 @@ +# translation and rotation thresholds [meters, degrees] +# used to compute Precision and AUC considering Pose Error +t_threshold = 0.25 +R_threshold = 5 + +# reprojection (VCRE) threshold [pixels] +# used to compute Precision and AUC considering VCRE +vcre_threshold = 90 diff --git a/imcui/third_party/mickey/benchmark/mapfree.py b/imcui/third_party/mickey/benchmark/mapfree.py new file mode 100644 index 0000000000000000000000000000000000000000..6039e537b151723e7376a32868296b51083cf9dc --- /dev/null +++ b/imcui/third_party/mickey/benchmark/mapfree.py @@ -0,0 +1,198 @@ +import argparse +from collections import defaultdict +from pathlib import Path +from zipfile import ZipFile +from io import TextIOWrapper +import json +import logging +import numpy as np + +from benchmark.utils import load_poses, subsample_poses, load_K, precision_recall +from benchmark.metrics import MetricManager, Inputs +import benchmark.config as config +from config.default import cfg + +def plot_perfect_curve(P): + total_bins = 1000 + prec_values = [] + ratio_values = [] + for i in range(total_bins): + ratio_tmp = i/total_bins + value = min(1, P / ratio_tmp) + prec_values.append(value) + ratio_values.append(ratio_tmp) + return prec_values, ratio_values + +def compute_scene_metrics(dataset_path: Path, submission_zip: ZipFile, scene: str): + metric_manager = MetricManager() + + # load intrinsics and poses + try: + K, W, H = load_K(dataset_path / scene / 'intrinsics.txt') + with (dataset_path / scene / 'poses.txt').open('r', encoding='utf-8') as gt_poses_file: + gt_poses = load_poses(gt_poses_file, load_confidence=False) + except 
FileNotFoundError as e: + logging.error(f'Could not find ground-truth dataset files: {e}') + raise + else: + logging.info( + f'Loaded ground-truth intrinsics and poses for scene {scene}') + + # try to load estimated poses from submission + try: + with submission_zip.open(f'pose_{scene}.txt') as estimated_poses_file: + estimated_poses_file_wrapper = TextIOWrapper( + estimated_poses_file, encoding='utf-8') + estimated_poses = load_poses( + estimated_poses_file_wrapper, load_confidence=True) + except KeyError as e: + logging.warning( + f'Submission does not have estimates for scene {scene}.') + return dict(), len(gt_poses) + except UnicodeDecodeError as e: + logging.error('Unsupported file encoding: please use UTF-8') + raise + else: + logging.info(f'Loaded estimated poses for scene {scene}') + + # The val/test set is subsampled by a factor of 5 + gt_poses = subsample_poses(gt_poses, subsample=5) + + # failures encode how many frames did not have an estimate + # e.g. user/method did not provide an estimate for that frame + # it's different from when an estimate is provided with low confidence! + failures = 0 + + # Results encoded as dict + # key: metric name; value: list of values (one per frame). + # e.g. results['t_err'] = [1.2, 0.3, 0.5, ...] + results = defaultdict(list) + + # compute metrics per frame + for frame_num, (q_gt, t_gt, _) in gt_poses.items(): + if frame_num not in estimated_poses: + failures += 1 + continue + + q_est, t_est, confidence = estimated_poses[frame_num] + inputs = Inputs(q_gt=q_gt, t_gt=t_gt, q_est=q_est, t_est=t_est, + confidence=confidence, K=K[frame_num], W=W, H=H) + metric_manager(inputs, results) + + return results, failures + + +def aggregate_results(all_results, all_failures): + # aggregate metrics + median_metrics = defaultdict(list) + all_metrics = defaultdict(list) + for scene_results in all_results.values(): + for metric, values in scene_results.items(): + median_metrics[metric].append(np.median(values)) + all_metrics[metric].extend(values) + all_metrics = {k: np.array(v) for k, v in all_metrics.items()} + assert all([v.ndim == 1 for v in all_metrics.values()] + ), 'invalid metrics shape' + + # compute avg median metrics + avg_median_metrics = {metric: np.mean( + values) for metric, values in median_metrics.items()} + + # compute precision/AUC for pose error and reprojection errors + accepted_poses = (all_metrics['trans_err'] < config.t_threshold) * \ + (all_metrics['rot_err'] < config.R_threshold) + accepted_vcre = all_metrics['reproj_err'] < config.vcre_threshold + total_samples = len(next(iter(all_metrics.values()))) + all_failures + + prec_pose = np.sum(accepted_poses) / total_samples + prec_vcre = np.sum(accepted_vcre) / total_samples + + # compute AUC for pose and VCRE + pose_prec_values, pose_recall_values, auc_pose = precision_recall( + inliers=all_metrics['confidence'], tp=accepted_poses, failures=all_failures) + vcre_prec_values, vcre_recall_values, auc_vcre = precision_recall( + inliers=all_metrics['confidence'], tp=accepted_vcre, failures=all_failures) + + curves_data = {} + curves_data['vcre_prec_values'], curves_data['vcre_recall_values'] = vcre_prec_values, vcre_recall_values + curves_data['pose_prec_values'], curves_data['pose_recall_values'] = pose_prec_values, pose_recall_values + + # output metrics + output_metrics = dict() + output_metrics['Average Median Translation Error'] = avg_median_metrics['trans_err'] + output_metrics['Average Median Rotation Error'] = avg_median_metrics['rot_err'] + output_metrics['Average Median Reprojection 
Error'] = avg_median_metrics['reproj_err'] + output_metrics[f'Precision @ Pose Error < ({config.t_threshold*100}cm, {config.R_threshold}deg)'] = prec_pose + output_metrics[f'AUC @ Pose Error < ({config.t_threshold*100}cm, {config.R_threshold}deg)'] = auc_pose + output_metrics[f'Precision @ VCRE < {config.vcre_threshold}px'] = prec_vcre + output_metrics[f'AUC @ VCRE < {config.vcre_threshold}px'] = auc_vcre + output_metrics[f'Estimates for % of frames'] = len(all_metrics['trans_err']) / total_samples + return output_metrics, curves_data + + +def count_unexpected_scenes(scenes: tuple, submission_zip: ZipFile): + submission_scenes = [fname[5:-4] + for fname in submission_zip.namelist() if fname.startswith("pose_")] + return len(set(submission_scenes) - set(scenes)) + +def main(args): + dataset_path = args.dataset_path / args.split + scenes = tuple(f.name for f in dataset_path.iterdir() if f.is_dir()) + + try: + submission_zip = ZipFile(args.submission_path, 'r') + except FileNotFoundError as e: + logging.error(f'Could not find ZIP file in path {args.submission_path}') + return + + all_results = dict() + all_failures = 0 + for scene in scenes: + metrics, failures = compute_scene_metrics( + dataset_path, submission_zip, scene) + all_results[scene] = metrics + all_failures += failures + + if all_failures > 0: + logging.warning( + f'Submission is missing pose estimates for {all_failures} frames') + + unexpected_scene_count = count_unexpected_scenes(scenes, submission_zip) + if unexpected_scene_count > 0: + logging.warning( + f'Submission contains estimates for {unexpected_scene_count} scenes outside the {args.split} set') + + if all((len(metrics) == 0 for metrics in all_results.values())): + logging.error( + f'Submission does not have any valid pose estimates') + return + + output_metrics, curves_data = aggregate_results(all_results, all_failures) + output_json = json.dumps(output_metrics, indent=2) + print(output_json) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + 'eval', description='Evaluate submissions for the MapFree dataset benchmark') + parser.add_argument('--submission_path', type=Path, default='', + help='Path to the submission ZIP file') + parser.add_argument('--split', choices=('val', 'test'), default='test', + help='Dataset split to use for evaluation. Default: test') + parser.add_argument('--log', choices=('warning', 'info', 'error'), + default='warning', help='Logging level. Default: warning') + parser.add_argument('--dataset_path', type=Path, default=None, + help='Path to the dataset folder') + + args = parser.parse_args() + + if args.dataset_path is None: + cfg.merge_from_file('config/datasets/mapfree.yaml') + args.dataset_path = Path(cfg.DATASET.DATA_ROOT) + + logging.basicConfig(level=args.log.upper()) + try: + main(args) + except Exception: + logging.error("Unexpected behaviour. 
Exiting.") + diff --git a/imcui/third_party/mickey/benchmark/metrics.py b/imcui/third_party/mickey/benchmark/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..99fb1cf271bc7f35809e50f9a28a3966340ce998 --- /dev/null +++ b/imcui/third_party/mickey/benchmark/metrics.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from typing import Callable + +import numpy as np + +from benchmark.reprojection import reprojection_error +from benchmark.utils import VARIANTS_ANGLE_SIN, quat_angle_error + + +@dataclass +class Inputs: + q_gt: np.array + t_gt: np.array + q_est: np.array + t_est: np.array + confidence: float + K: np.array + W: int + H: int + + def __post_init__(self): + assert self.q_gt.shape == (4,), 'invalid gt quaternion shape' + assert self.t_gt.shape == (3,), 'invalid gt translation shape' + assert self.q_est.shape == (4,), 'invalid estimated quaternion shape' + assert self.t_est.shape == (3,), 'invalid estimated translation shape' + assert self.confidence >= 0, 'confidence must be non negative' + assert self.K.shape == (3, 3), 'invalid K shape' + assert self.W > 0, 'invalid image width' + assert self.H > 0, 'invalid image height' + + +class MyDict(dict): + def register(self, fn) -> Callable: + """Registers a function within dict(fn_name -> fn_ref). + This is used to evaluate all registered metrics in MetricManager.__call__()""" + self[fn.__name__] = fn + return fn + + +class MetricManager: + _metrics = MyDict() + + def __call__(self, inputs: Inputs, results: dict) -> None: + for metric, metric_fn in self._metrics.items(): + results[metric].append(metric_fn(inputs)) + + @staticmethod + @_metrics.register + def trans_err(inputs: Inputs) -> np.float64: + return np.linalg.norm(inputs.t_est - inputs.t_gt) + + @staticmethod + @_metrics.register + def rot_err(inputs: Inputs, variant: str = VARIANTS_ANGLE_SIN) -> np.float64: + return quat_angle_error(label=inputs.q_est, pred=inputs.q_gt, variant=variant)[0, 0] + + @staticmethod + @_metrics.register + def reproj_err(inputs: Inputs) -> float: + return reprojection_error( + q_est=inputs.q_est, t_est=inputs.t_est, q_gt=inputs.q_gt, t_gt=inputs.t_gt, K=inputs.K, + W=inputs.W, H=inputs.H) + + @staticmethod + @_metrics.register + def confidence(inputs: Inputs) -> float: + return inputs.confidence diff --git a/imcui/third_party/mickey/benchmark/reprojection.py b/imcui/third_party/mickey/benchmark/reprojection.py new file mode 100644 index 0000000000000000000000000000000000000000..ebff993ed0d45379a838045a6fa916006751b5e2 --- /dev/null +++ b/imcui/third_party/mickey/benchmark/reprojection.py @@ -0,0 +1,86 @@ +from typing import List, Tuple + +import numpy as np +from transforms3d.quaternions import quat2mat + + +def project(pts: np.ndarray, K: np.ndarray, img_size: List[int] or Tuple[int] = None) -> np.ndarray: + """Projects 3D points to image plane. 
+ + Args: + - pts [N, 3/4]: points in camera coordinates (homogeneous or non-homogeneous) + - K [3, 3]: intrinsic matrix + - img_size (width, height): optional, clamp projection to image borders + Outputs: + - uv [N, 2]: coordinates of projected points + """ + + assert len(pts.shape) == 2, 'incorrect number of dimensions' + assert pts.shape[1] in [3, 4], 'invalid dimension size' + assert K.shape == (3, 3), 'incorrect intrinsic shape' + + uv_h = (K @ pts[:, :3].T).T + uv = uv_h[:, :2] / uv_h[:, -1:] + + if img_size is not None: + uv[:, 0] = np.clip(uv[:, 0], 0, img_size[0]) + uv[:, 1] = np.clip(uv[:, 1], 0, img_size[1]) + + return uv + + +def get_grid_multipleheight() -> np.ndarray: + # create grid of points + ar_grid_step = 0.3 + ar_grid_num_x = 7 + ar_grid_num_y = 4 + ar_grid_num_z = 7 + ar_grid_z_offset = 1.8 + ar_grid_y_offset = 0 + + ar_grid_x_pos = np.arange(0, ar_grid_num_x)-(ar_grid_num_x-1)/2 + ar_grid_x_pos *= ar_grid_step + + ar_grid_y_pos = np.arange(0, ar_grid_num_y)-(ar_grid_num_y-1)/2 + ar_grid_y_pos *= ar_grid_step + ar_grid_y_pos += ar_grid_y_offset + + ar_grid_z_pos = np.arange(0, ar_grid_num_z).astype(float) + ar_grid_z_pos *= ar_grid_step + ar_grid_z_pos += ar_grid_z_offset + + xx, yy, zz = np.meshgrid(ar_grid_x_pos, ar_grid_y_pos, ar_grid_z_pos) + ones = np.ones(xx.shape[0]*xx.shape[1]*xx.shape[2]) + eye_coords = np.concatenate([c.reshape(-1, 1) + for c in (xx, yy, zz, ones)], axis=-1) + return eye_coords + + +# global variable, avoids creating it again +eye_coords_glob = get_grid_multipleheight() + + +def reprojection_error( + q_est: np.ndarray, t_est: np.ndarray, q_gt: np.ndarray, t_gt: np.ndarray, K: np.ndarray, + W: int, H: int) -> float: + eye_coords = eye_coords_glob + + # obtain ground-truth position of projected points + uv_gt = project(eye_coords, K, (W, H)) + + # residual transformation + cam2w_est = np.eye(4) + cam2w_est[:3, :3] = quat2mat(q_est) + cam2w_est[:3, -1] = t_est + cam2w_gt = np.eye(4) + cam2w_gt[:3, :3] = quat2mat(q_gt) + cam2w_gt[:3, -1] = t_gt + + # residual reprojection + eyes_residual = (np.linalg.inv(cam2w_est) @ cam2w_gt @ eye_coords.T).T + uv_pred = project(eyes_residual, K, (W, H)) + + # get reprojection error + repr_err = np.linalg.norm(uv_gt - uv_pred, ord=2, axis=1) + mean_repr_err = float(repr_err.mean().item()) + return mean_repr_err diff --git a/imcui/third_party/mickey/benchmark/test_metrics.py b/imcui/third_party/mickey/benchmark/test_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ad37da787ad1841679fbf152a4d5740c0233dc --- /dev/null +++ b/imcui/third_party/mickey/benchmark/test_metrics.py @@ -0,0 +1,174 @@ +import numpy as np +import pytest +from transforms3d.euler import euler2quat +from transforms3d.quaternions import axangle2quat, qmult, quat2mat, rotate_vector + +from benchmark.metrics import Inputs, MetricManager +from benchmark.reprojection import project +from benchmark.utils import VARIANTS_ANGLE_COS, VARIANTS_ANGLE_SIN + + +def createInput(q_gt=None, t_gt=None, q_est=None, t_est=None, confidence=None, K=None, W=None, H=None): + q_gt = np.zeros(4) if q_gt is None else q_gt + t_gt = np.zeros(3) if t_gt is None else t_gt + q_est = np.zeros(4) if q_est is None else q_est + t_est = np.zeros(3) if t_est is None else t_est + confidence = 0. 
if confidence is None else confidence + K = np.eye(3) if K is None else K + H = 1 if H is None else H + W = 1 if W is None else W + return Inputs(q_gt=q_gt, t_gt=t_gt, q_est=q_est, t_est=t_est, confidence=confidence, K=K, W=W, H=H) + + +def randomQuat(): + angles = np.random.uniform(0, 2*np.pi, 3) + q = euler2quat(*angles) + return q + + +class TestMetrics: + @pytest.mark.parametrize('run_number', range(50)) + def test_t_err_tinvariance(self, run_number: int) -> None: + """Computes the translation error given an initial translation and displacement of this + translation. The translation error must be equal to the norm of the displacement.""" + mean, var = 5, 10 + t0 = np.random.normal(mean, var, (3,)) + displacement = np.random.normal(mean, var, (3,)) + + i = createInput(t_gt=t0, t_est=t0+displacement) + trans_err = MetricManager.trans_err(i) + assert np.isclose(trans_err, np.linalg.norm(displacement)) + + @pytest.mark.parametrize('run_number', range(50)) + def test_trans_err_rinvariance(self, run_number: int) -> None: + """Computes the translation error given estimated and gt vectors. + The translation error must be the same for a rotated version of those vectors + (same random rotation)""" + mean, var = 5, 10 + t0 = np.random.normal(mean, var, (3,)) + t1 = np.random.normal(mean, var, (3,)) + q = randomQuat() + + i = createInput(t_gt=t0, t_est=t1) + trans_err = MetricManager.trans_err(i) + + ir = createInput(t_gt=rotate_vector(t0, q), t_est=rotate_vector(t1, q)) + trans_err_r = MetricManager.trans_err(ir) + + assert np.isclose(trans_err, trans_err_r) + + @pytest.mark.parametrize('run_number', range(50)) + @pytest.mark.parametrize('dtype', (np.float64, np.float32)) + def test_rot_err_raxis(self, run_number: int, dtype: type) -> None: + """Test rotation error for rotations around a random axis. + + Note: We create GT as high precision, and only downcast when calling rot_err. + """ + q = randomQuat().astype(np.float64) + + axis = np.random.uniform(low=-1, high=1, size=3).astype(np.float64) + angle = np.float64(np.random.uniform(low=-np.pi, high=np.pi)) + qres = axangle2quat(vector=axis, theta=angle, is_normalized=False).astype(np.float64) + + i = createInput(q_gt=q.astype(dtype), q_est=qmult(q, qres).astype(dtype)) + rot_err = MetricManager.rot_err(i) + assert isinstance(rot_err, np.float64) + rot_err_expected = np.abs(np.degrees(angle)) + # if we add up errors, we want them to be positive + assert 0. 
<= rot_err + rtol = 1.e-5 # numpy default + atol = 1.e-8 # numpy default + if isinstance(dtype, np.float32): + atol = 1.e-7 # 1/50 test might fail at 1.e-8 + assert np.isclose(rot_err, rot_err_expected, rtol=rtol, atol=atol) + + @pytest.mark.parametrize('run_number', range(50)) + def test_r_err_mat(self, run_number: int) -> None: + q0 = randomQuat() + q1 = randomQuat() + + i = createInput(q_gt=q0, q_est=q1) + rot_err = MetricManager.rot_err(i) + + R0 = quat2mat(q0) + R1 = quat2mat(q1) + Rres = R1 @ R0.T + theta = (np.trace(Rres) - 1)/2 + theta = np.clip(theta, -1, 1) + angle = np.degrees(np.arccos(theta)) + + assert np.isclose(angle, rot_err) + + def test_reproj_error_identity(self): + """Test that reprojection error is zero if poses match""" + q = randomQuat() + t = np.random.normal(0, 10, (3,)) + i = createInput(q_gt=q, t_gt=t, q_est=q, t_est=t) + + reproj_err = MetricManager.reproj_err(i) + assert np.isclose(reproj_err, 0) + + @pytest.mark.parametrize('run_number', range(10)) + @pytest.mark.parametrize('variant', (VARIANTS_ANGLE_SIN,)) + @pytest.mark.parametrize('dtype', (np.float64,)) + def test_r_err_small(self, run_number: int, variant: str, dtype: type) -> None: + """Test rotation error for small angle differences. + + Note: We create GT as high precision, and only downcast when calling rot_err. + """ + scales_failed = [] + for scale in np.logspace(start=-1, stop=-9, num=9, base=10, dtype=dtype): + q = randomQuat().astype(np.float64) + angle = np.float64(np.random.uniform(low=-np.pi, high=np.pi)) * scale + assert isinstance(angle, np.float64) + axis = np.random.uniform(low=-1., high=1., size=3).astype(np.float64) + assert axis.dtype == np.float64 + qres = axangle2quat(vector=axis, theta=angle, is_normalized=False).astype(np.float64) + assert qres.dtype == np.float64 + + i = createInput(q_gt=q.astype(dtype), q_est=qmult(q, qres).astype(dtype)) + + # We expect the error to always be np.float64 for highest acc. + rot_err = MetricManager.rot_err(i, variant=variant) + assert isinstance(rot_err, np.float64) + rot_err_expected = np.abs(np.degrees(angle)) + assert isinstance(rot_err_expected, type(rot_err)) + + # if we add up errors, we want them to be positive + assert 0. 
<= rot_err + + # check accuracy for one magnitude higher tolerance than the angle + tol = 0.1 * scale + # need to be more permissive for lower precision + if dtype == np.float32: + tol = 1.e3 * scale + + # cast to dtype for checking + rot_err = rot_err.astype(dtype) + rot_err_expected = rot_err_expected.astype(dtype) + + if variant == VARIANTS_ANGLE_SIN: + assert np.isclose(rot_err, rot_err_expected, rtol=tol, atol=tol) + elif variant == VARIANTS_ANGLE_COS: + if not np.isclose(rot_err, rot_err_expected, rtol=tol, atol=tol): + print(f"[variant '{variant}'] raises an error for\n" + f"\trot_err: {rot_err}" + f"\trot_err_expected: {rot_err_expected}" + f"\trtol: {tol}" + f"\tatol: {tol}") + scales_failed.append(scale) + if len(scales_failed): + pytest.fail(f"Variant {variant} failed at scales {scales_failed}") + + +def test_projection() -> None: + xyz = np.array(((10, 20, 30), (10, 30, 50), (-20, -15, 5), + (-20, -50, 10)), dtype=np.float32) + K = np.eye(3) + + uv = np.array(((1/3, 2/3), (1/5, 3/5), (-4, -3), + (-2, -5)), dtype=np.float32) + assert np.allclose(uv, project(xyz, K)) + + uv = np.array(((1/3, 2/3), (1/5, 3/5), (0, 0), (0, 0)), dtype=np.float32) + assert np.allclose(uv, project(xyz, K, img_size=(5, 5))) diff --git a/imcui/third_party/mickey/benchmark/utils.py b/imcui/third_party/mickey/benchmark/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6faad88942f588d64272166726afaa0bd398c5 --- /dev/null +++ b/imcui/third_party/mickey/benchmark/utils.py @@ -0,0 +1,186 @@ +from pathlib import Path +import typing +import logging + +import numpy as np +from transforms3d.quaternions import qinverse, rotate_vector, qmult + +VARIANTS_ANGLE_SIN = 'sin' +VARIANTS_ANGLE_COS = 'cos' + + +def convert_world2cam_to_cam2world(q, t): + qinv = qinverse(q) + tinv = -rotate_vector(t, qinv) + return qinv, tinv + + +def load_poses(file: typing.IO, load_confidence: bool = False): + """Load poses from text file and converts them to cam2world convention (t is the camera center in world coordinates) + + The text file encodes world2cam poses with the format: + imgpath qw qx qy qz tx ty tz [confidence] + where qw qx qy qz is the quaternion encoding rotation, + and tx ty tz is the translation vector, + and confidence is a float encoding confidence, for estimated poses + """ + + expected_parts = 9 if load_confidence else 8 + + poses = dict() + for line_number, line in enumerate(file.readlines()): + parts = tuple(line.strip().split(' ')) + + # if 'tensor' in parts[-1]: + # print('ERROR: confidence is a tensor') + # parts = list(parts) + # parts[-1] = parts[-1].split('[')[-1].split(']')[0] + if len(parts) != expected_parts: + logging.warning( + f'Invalid number of fields in file {file.name} line {line_number}.' + f' Expected {expected_parts}, received {len(parts)}. Ignoring line.') + continue + + try: + name = parts[0] + if '#' in name: + logging.info(f'Ignoring comment line in {file.name} line {line_number}') + continue + frame_num = int(name[-9:-4]) + except ValueError: + logging.warning( + f'Invalid frame number in file {file.name} line {line_number}.' + f' Expected formatting "seq1/frame_00000.jpg". Ignoring line.') + continue + + try: + parts_float = tuple(map(float, parts[1:])) + if any(np.isnan(v) or np.isinf(v) for v in parts_float): + raise ValueError() + qw, qx, qy, qz, tx, ty, tz = parts_float[:7] + confidence = parts_float[7] if load_confidence else None + except ValueError: + logging.warning( + f'Error parsing pose in file {file.name} line {line_number}. 
Ignoring line.') + continue + + q = np.array((qw, qx, qy, qz), dtype=np.float64) + t = np.array((tx, ty, tz), dtype=np.float64) + + if np.isclose(np.linalg.norm(q), 0): + logging.warning( + f'Error parsing pose in file {file.name} line {line_number}. ' + 'Quaternion must have non-zero norm. Ignoring line.') + continue + + q, t = convert_world2cam_to_cam2world(q, t) + poses[frame_num] = (q, t, confidence) + return poses + + +def subsample_poses(poses: dict, subsample: int = 1): + return {k: v for i, (k, v) in enumerate(poses.items()) if i % subsample == 0} + + +def load_K(file_path: Path): + K = dict() + with file_path.open('r', encoding='utf-8') as f: + for line in f.readlines(): + if '#' in line: + continue + line = line.strip().split(' ') + + frame_num = int(line[0][-9:-4]) + fx, fy, cx, cy, W, H = map(float, line[1:]) + K[frame_num] = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + return K, W, H + + +def quat_angle_error(label, pred, variant=VARIANTS_ANGLE_SIN) -> np.ndarray: + assert label.shape == (4,) + assert pred.shape == (4,) + assert variant in (VARIANTS_ANGLE_SIN, VARIANTS_ANGLE_COS), \ + f"Need variant to be in ({VARIANTS_ANGLE_SIN}, {VARIANTS_ANGLE_COS})" + + if len(label.shape) == 1: + label = np.expand_dims(label, axis=0) + if len(label.shape) != 2 or label.shape[0] != 1 or label.shape[1] != 4: + raise RuntimeError(f"Unexpected shape of label: {label.shape}, expected: (1, 4)") + + if len(pred.shape) == 1: + pred = np.expand_dims(pred, axis=0) + if len(pred.shape) != 2 or pred.shape[0] != 1 or pred.shape[1] != 4: + raise RuntimeError(f"Unexpected shape of pred: {pred.shape}, expected: (1, 4)") + + label = label.astype(np.float64) + pred = pred.astype(np.float64) + + q1 = pred / np.linalg.norm(pred, axis=1, keepdims=True) + q2 = label / np.linalg.norm(label, axis=1, keepdims=True) + if variant == VARIANTS_ANGLE_COS: + d = np.abs(np.sum(np.multiply(q1, q2), axis=1, keepdims=True)) + d = np.clip(d, a_min=-1, a_max=1) + angle = 2. * np.degrees(np.arccos(d)) + elif variant == VARIANTS_ANGLE_SIN: + if q1.shape[0] != 1 or q2.shape[0] != 1: + raise NotImplementedError(f"Multiple angles is todo") + # https://www.researchgate.net/post/How_do_I_calculate_the_smallest_angle_between_two_quaternions/5d6ed4a84f3a3e1ed3656616/citation/download + sine = qmult(q1[0], qinverse(q2[0])) # note: takes first element in 2D array + # 114.59 = 2. * 180. / pi + angle = np.arcsin(np.linalg.norm(sine[1:], keepdims=True)) * 114.59155902616465 + angle = np.expand_dims(angle, axis=0) + + return angle.astype(np.float64) + + +def precision_recall(inliers, tp, failures): + """ + Computes Precision/Recall plot for a set of poses given inliers (confidence) and wether the + estimated pose error (whatever it may be) is within a threshold. + Each point in the plot is obtained by choosing a threshold for inliers (i.e. inlier_thr). 
+ Recall measures how many images have inliers >= inlier_thr + Precision measures how many images that have inliers >= inlier_thr have + estimated pose error <= pose_threshold (measured by counting tps) + Where pose_threshold is (trans_thr[m], rot_thr[deg]) + + Inputs: + - inliers [N] + - terr [N] + - rerr [N] + - failures (int) + - pose_threshold (tuple float) + Output + - precision [N] + - recall [N] + - average_precision (scalar) + """ + + assert len(inliers) == len(tp), 'unequal shapes' + + # sort by inliers (descending order) + inliers = np.array(inliers) + sort_idx = np.argsort(inliers)[::-1] + inliers = inliers[sort_idx] + tp = np.array(tp).reshape(-1)[sort_idx] + + # get idxs where inliers change (avoid tied up values) + distinct_value_indices = np.where(np.diff(inliers))[0] + threshold_idxs = np.r_[distinct_value_indices, inliers.size - 1] + + # compute prec/recall + N = inliers.shape[0] + rec = np.arange(N, dtype=np.float32) + 1 + cum_tp = np.cumsum(tp) + prec = cum_tp[threshold_idxs] / rec[threshold_idxs] + rec = rec[threshold_idxs] / (float(N) + float(failures)) + + # invert order and ensures (prec=1, rec=0) point + last_ind = rec.searchsorted(rec[-1]) + sl = slice(last_ind, None, -1) + prec = np.r_[prec[sl], 1] + rec = np.r_[rec[sl], 0] + + # compute average precision (AUC) as the weighted average of precisions + average_precision = np.abs(np.sum(np.diff(rec) * np.array(prec)[:-1])) + + return prec, rec, average_precision diff --git a/imcui/third_party/mickey/config/MicKey/curriculum_learning.yaml b/imcui/third_party/mickey/config/MicKey/curriculum_learning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9892b8b84a0c04c977032804e937eca0255129a9 --- /dev/null +++ b/imcui/third_party/mickey/config/MicKey/curriculum_learning.yaml @@ -0,0 +1,97 @@ + +MODEL: 'MicKey' +DEBUG: False +MICKEY: + DINOV2: + DOWN_FACTOR: 14 + CHANNEL_DIM: 1024 + FLOAT16: True + + KP_HEADS: + BLOCKS_DIM: [512, 256, 128, 64] + BN: True + USE_SOFTMAX: True + USE_DEPTHSIGMOID: False + MAX_DEPTH: 60 + POS_ENCODING: True + + DSC_HEAD: + LAST_DIM: 128 + BLOCKS_DIM: [512, 256, 128] + BN: True + NORM_DSC: True + POS_ENCODING: True + +FEATURE_MATCHER: + TYPE: 'DualSoftmax' + DUAL_SOFTMAX: + TEMPERATURE: 0.1 + USE_DUSTBIN: True + SINKHORN: + NUM_IT: 10 + DUSTBIN_SCORE_INIT: 1. + USE_TRANSFORMER: False + +TRAINING: + NUM_GPUS: 4 + BATCH_SIZE: 12 # BS for each dataloader (in every GPU) + NUM_WORKERS: 12 + SAMPLER: 'scene_balance' + N_SAMPLES_SCENE: 100 + SAMPLE_WITH_REPLACEMENT: True + LR: 1e-4 + LOG_INTERVAL: 50 + VAL_INTERVAL: 0.5 + VAL_BATCHES: 100 + EPOCHS: 100 + +DATASET: + HEIGHT: 720 + WIDTH: 540 + + MIN_OVERLAP_SCORE: 0.0 # [train only] discard data with overlap_score < min_overlap_score + MAX_OVERLAP_SCORE: 1.0 # [train only] discard data with overlap_score < min_overlap_score + +LOSS_CLASS: + + LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR + SOFT_CLIPPING: True # It indicates if it soft-clips the loss values. 
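+  # Note: MAX_LOSS_VALUE below presumably caps the raw loss (pixels for VCRE, metric
+  # pose error for POSE_ERR) and MAX_LOSS_SOFTVALUE marks where the soft clipping starts
+  # to saturate; the exact formula is defined in the training code, not in this config.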
+ + POSE_ERR: + MAX_LOSS_VALUE: 1.5 + MAX_LOSS_SOFTVALUE: 0.8 + VCRE: + MAX_LOSS_VALUE: 90 + MAX_LOSS_SOFTVALUE: 0.8 + + GENERATE_HYPOTHESES: + SCORE_TEMPERATURE: 20 + IT_MATCHES: 20 + IT_RANSAC: 20 + INLIER_3D_TH: 0.3 + INLIER_REF_TH: 0.15 + NUM_REF_STEPS: 4 + NUM_CORR_3d3d: 8 # Bigger number of 3d-3d correspondences helps stability + + NULL_HYPOTHESIS: + ADD_NULL_HYPOTHESIS: True + TH_OUTLIERS: 0.35 + + CURRICULUM_LEARNING: + TRAIN_CURRICULUM: True # It indicates if MicKey should be trained with curriculum learning + TRAIN_WITH_TOPK: True # It indicates if MicKey should be trained only with top image pairs + TOPK_INIT: 30 + TOPK: 80 + + SAMPLER: + NUM_SAMPLES_MATCHES: 512 + +PROCRUSTES: + IT_MATCHES: 20 + IT_RANSAC: 100 + NUM_SAMPLED_MATCHES: 2048 + NUM_CORR_3D_3D: 3 + NUM_REFINEMENTS: 4 + TH_INLIER: 0.15 + TH_SOFT_INLIER: 0.3 + diff --git a/imcui/third_party/mickey/config/MicKey/overlap_score.yaml b/imcui/third_party/mickey/config/MicKey/overlap_score.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e5dd5060dc7f4228c22b9662d2844739bb50c196 --- /dev/null +++ b/imcui/third_party/mickey/config/MicKey/overlap_score.yaml @@ -0,0 +1,96 @@ + +MODEL: 'MicKey' +DEBUG: False +MICKEY: + DINOV2: + DOWN_FACTOR: 14 + CHANNEL_DIM: 1024 + FLOAT16: True + + KP_HEADS: + BLOCKS_DIM: [512, 256, 128, 64] + BN: True + USE_SOFTMAX: True + USE_DEPTHSIGMOID: False + MAX_DEPTH: 60 + POS_ENCODING: True + + DSC_HEAD: + LAST_DIM: 128 + BLOCKS_DIM: [512, 256, 128] + BN: True + NORM_DSC: True + POS_ENCODING: True + +FEATURE_MATCHER: + TYPE: 'DualSoftmax' + DUAL_SOFTMAX: + TEMPERATURE: 0.1 + USE_DUSTBIN: True + SINKHORN: + NUM_IT: 10 + DUSTBIN_SCORE_INIT: 1. + USE_TRANSFORMER: False + +TRAINING: + NUM_GPUS: 4 + BATCH_SIZE: 12 # BS for each dataloader (in every GPU) + NUM_WORKERS: 12 + SAMPLER: 'scene_balance' + N_SAMPLES_SCENE: 100 + SAMPLE_WITH_REPLACEMENT: True + LR: 1e-4 + LOG_INTERVAL: 50 + VAL_INTERVAL: 0.5 + VAL_BATCHES: 100 + EPOCHS: 100 + +DATASET: + HEIGHT: 720 + WIDTH: 540 + + MIN_OVERLAP_SCORE: 0.4 # [train only] discard data with overlap_score < min_overlap_score + MAX_OVERLAP_SCORE: 0.8 # [train only] discard data with overlap_score < min_overlap_score + +LOSS_CLASS: + + LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR + SOFT_CLIPPING: True # It indicates if it soft-clips the loss values. 
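+  # Note: unlike curriculum_learning.yaml, this variant keeps TRAIN_CURRICULUM and
+  # TRAIN_WITH_TOPK disabled (see CURRICULUM_LEARNING below) and instead restricts
+  # training pairs to overlap scores in [0.4, 0.8] via the DATASET section above.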
+ + POSE_ERR: + MAX_LOSS_VALUE: 1.5 + MAX_LOSS_SOFTVALUE: 0.8 + VCRE: + MAX_LOSS_VALUE: 90 + MAX_LOSS_SOFTVALUE: 0.8 + + GENERATE_HYPOTHESES: + SCORE_TEMPERATURE: 20 + IT_MATCHES: 20 + IT_RANSAC: 20 + INLIER_3D_TH: 0.3 + INLIER_REF_TH: 0.15 + NUM_REF_STEPS: 4 + NUM_CORR_3d3d: 8 # Bigger number of 3d-3d correspondences helps stability + + NULL_HYPOTHESIS: + ADD_NULL_HYPOTHESIS: True + TH_OUTLIERS: 0.35 + + CURRICULUM_LEARNING: + TRAIN_CURRICULUM: False # It indicates if MicKey should be trained with curriculum learning + TRAIN_WITH_TOPK: False # It indicates if MicKey should be trained only with top image pairs + TOPK_INIT: 30 + TOPK: 80 + + SAMPLER: + NUM_SAMPLES_MATCHES: 512 + +PROCRUSTES: + IT_MATCHES: 20 + IT_RANSAC: 100 + NUM_SAMPLED_MATCHES: 2048 + NUM_CORR_3D_3D: 3 + NUM_REFINEMENTS: 4 + TH_INLIER: 0.15 + TH_SOFT_INLIER: 0.3 \ No newline at end of file diff --git a/imcui/third_party/mickey/config/datasets/mapfree.yaml b/imcui/third_party/mickey/config/datasets/mapfree.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f44c5c88515fa92e1c769dbdf0af61ce85414d6b --- /dev/null +++ b/imcui/third_party/mickey/config/datasets/mapfree.yaml @@ -0,0 +1,10 @@ +DATASET: + DATA_SOURCE: 'MapFree' + DATA_ROOT: 'data/' + SCENES: None # should be a list [] or None. If none, use all scenes. + AUGMENTATION_TYPE: None + HEIGHT: 720 + WIDTH: 540 + MIN_OVERLAP_SCORE: 0.2 # [train only] discard data with overlap_score < min_overlap_score + MAX_OVERLAP_SCORE: 0.7 # [train only] discard data with overlap_score < min_overlap_score + SEED: 66 \ No newline at end of file diff --git a/imcui/third_party/mickey/config/default.py b/imcui/third_party/mickey/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..ce57235bf9cd851891c2780e4aa20784a19f7a6f --- /dev/null +++ b/imcui/third_party/mickey/config/default.py @@ -0,0 +1,141 @@ +from yacs.config import CfgNode as CN + +_CN = CN() + +############## Model ############## +_CN.MODEL = None # options: ['MicKey'] +_CN.DEBUG = False + +# MicKey configuration +_CN.MICKEY = CN() + +_CN.MICKEY.DINOV2 = CN() +_CN.MICKEY.DINOV2.DOWN_FACTOR = None +_CN.MICKEY.DINOV2.CHANNEL_DIM = None +_CN.MICKEY.DINOV2.FLOAT16 = None + +_CN.MICKEY.KP_HEADS = CN() +_CN.MICKEY.KP_HEADS.BLOCKS_DIM = None +_CN.MICKEY.KP_HEADS.BN = None +_CN.MICKEY.KP_HEADS.USE_SOFTMAX = None +_CN.MICKEY.KP_HEADS.USE_DEPTHSIGMOID = None +_CN.MICKEY.KP_HEADS.MAX_DEPTH = None +_CN.MICKEY.KP_HEADS.POS_ENCODING = None + +_CN.MICKEY.DSC_HEAD = CN() +_CN.MICKEY.DSC_HEAD.LAST_DIM = None +_CN.MICKEY.DSC_HEAD.BLOCKS_DIM = None +_CN.MICKEY.DSC_HEAD.BN = None +_CN.MICKEY.DSC_HEAD.NORM_DSC = None +_CN.MICKEY.DSC_HEAD.POS_ENCODING = None + + +_CN.FEATURE_MATCHER = CN() +_CN.FEATURE_MATCHER.TYPE = None +_CN.FEATURE_MATCHER.DUAL_SOFTMAX = CN() +_CN.FEATURE_MATCHER.DUAL_SOFTMAX.TEMPERATURE = None +_CN.FEATURE_MATCHER.DUAL_SOFTMAX.USE_DUSTBIN = None +_CN.FEATURE_MATCHER.SINKHORN = CN() +_CN.FEATURE_MATCHER.SINKHORN.NUM_IT = None +_CN.FEATURE_MATCHER.SINKHORN.DUSTBIN_SCORE_INIT = None +_CN.FEATURE_MATCHER.USE_TRANSFORMER = None +_CN.FEATURE_MATCHER.TOP_KEYPOINTS = False + +# LOSS_CLASS +_CN.LOSS_CLASS = CN() +_CN.LOSS_CLASS.LOSS_FUNCTION = None +_CN.LOSS_CLASS.SOFT_CLIPPING = None + +_CN.LOSS_CLASS.POSE_ERR = CN() +_CN.LOSS_CLASS.POSE_ERR.MAX_LOSS_VALUE = None +_CN.LOSS_CLASS.POSE_ERR.MAX_LOSS_SOFTVALUE = None + +_CN.LOSS_CLASS.VCRE = CN() +_CN.LOSS_CLASS.VCRE.MAX_LOSS_VALUE = None +_CN.LOSS_CLASS.VCRE.MAX_LOSS_SOFTVALUE = None + +_CN.LOSS_CLASS.GENERATE_HYPOTHESES = CN() 
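+# Example (illustrative): the None placeholders in this file are filled by merging the
+# YAML configs at run time, e.g.
+#   from config.default import cfg
+#   cfg.merge_from_file('config/datasets/mapfree.yaml')
+#   cfg.merge_from_file('config/MicKey/curriculum_learning.yaml')
+#   print(cfg.LOSS_CLASS.LOSS_FUNCTION)  # -> 'VCRE'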
+_CN.LOSS_CLASS.GENERATE_HYPOTHESES.SCORE_TEMPERATURE = None +_CN.LOSS_CLASS.GENERATE_HYPOTHESES.IT_MATCHES = None +_CN.LOSS_CLASS.GENERATE_HYPOTHESES.IT_RANSAC = None +_CN.LOSS_CLASS.GENERATE_HYPOTHESES.INLIER_3D_TH = None +_CN.LOSS_CLASS.GENERATE_HYPOTHESES.INLIER_REF_TH = None +_CN.LOSS_CLASS.GENERATE_HYPOTHESES.NUM_REF_STEPS = None +_CN.LOSS_CLASS.GENERATE_HYPOTHESES.NUM_CORR_3d3d = None + +_CN.LOSS_CLASS.CURRICULUM_LEARNING = CN() +_CN.LOSS_CLASS.CURRICULUM_LEARNING.TRAIN_CURRICULUM = None +_CN.LOSS_CLASS.CURRICULUM_LEARNING.TRAIN_WITH_TOPK = None +_CN.LOSS_CLASS.CURRICULUM_LEARNING.TOPK_INIT = None +_CN.LOSS_CLASS.CURRICULUM_LEARNING.TOPK = None + +_CN.LOSS_CLASS.NULL_HYPOTHESIS = CN() +_CN.LOSS_CLASS.NULL_HYPOTHESIS.ADD_NULL_HYPOTHESIS = None +_CN.LOSS_CLASS.NULL_HYPOTHESIS.TH_OUTLIERS = None + +_CN.LOSS_CLASS.SAMPLER = CN() +_CN.LOSS_CLASS.SAMPLER.NUM_SAMPLES_MATCHES = None + + +# Procrustes RANSAC options +_CN.PROCRUSTES = CN() +_CN.PROCRUSTES.IT_MATCHES = None +_CN.PROCRUSTES.IT_RANSAC = None +_CN.PROCRUSTES.NUM_SAMPLED_MATCHES = None +_CN.PROCRUSTES.NUM_CORR_3D_3D = None +_CN.PROCRUSTES.NUM_REFINEMENTS = None +_CN.PROCRUSTES.TH_INLIER = None +_CN.PROCRUSTES.TH_SOFT_INLIER = None + + + + +# Training Procrustes RANSAC options +_CN.PROCRUSTES_TRAINING = CN() +_CN.PROCRUSTES_TRAINING.MAX_CORR_DIST = None +_CN.PROCRUSTES_TRAINING.REFINE = False #refine pose with ICP + + +############## Dataset ############## +_CN.DATASET = CN() +# 1. data config +_CN.DATASET.DATA_SOURCE = None # options: ['ScanNet', '7Scenes', 'MapFree'] +_CN.DATASET.SCENES = None # scenes to use (for 7Scenes/MapFree); should be a list []; If none, use all scenes. +_CN.DATASET.DATA_ROOT = None # path to dataset folder +_CN.DATASET.SEED = None # SEED for dataset generation +_CN.DATASET.NPZ_ROOT = None # path to npz files containing pairs of frame indices per sample +_CN.DATASET.MIN_OVERLAP_SCORE = None # discard data with overlap_score < min_overlap_score +_CN.DATASET.MAX_OVERLAP_SCORE = None # discard data with overlap_score > max_overlap_score +_CN.DATASET.CONSECUTIVE_PAIRS = None # options: [None, 'colorjitter'] +_CN.DATASET.FRAME_RATE = None # options: [None, 'colorjitter'] +_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'colorjitter'] +_CN.DATASET.BLACK_WHITE = False # if true, transform images to black & white +_CN.DATASET.PAIRS_TXT = CN() # Path to text file defining the train/val/test pairs (7Scenes) +_CN.DATASET.PAIRS_TXT.TRAIN = None +_CN.DATASET.PAIRS_TXT.VAL = None +_CN.DATASET.PAIRS_TXT.TEST = None +_CN.DATASET.PAIRS_TXT.ONE_NN = False # If true, keeps only reference image w/ highest similarity to each query +_CN.DATASET.HEIGHT = None +_CN.DATASET.WIDTH = None + +############# TRAINING ############# +_CN.TRAINING = CN() +# Data Loader settings +_CN.TRAINING.BATCH_SIZE = None +_CN.TRAINING.NUM_WORKERS = None +_CN.TRAINING.NUM_GPUS = None +_CN.TRAINING.SAMPLER = None # options: ['random', 'scene_balance'] +_CN.TRAINING.N_SAMPLES_SCENE = None # if 'scene_balance' sampler, the number of samples to get per scene +_CN.TRAINING.SAMPLE_WITH_REPLACEMENT = None # if 'scene_balance' sampler, whether to sample with replacement + +# Training settings +_CN.TRAINING.LR = None +_CN.TRAINING.LR_STEP_INTERVAL = None +_CN.TRAINING.LR_STEP_GAMMA = None # multiplicative factor of LR every LR_STEP_ITERATIONS +_CN.TRAINING.VAL_INTERVAL = None +_CN.TRAINING.VAL_BATCHES = None +_CN.TRAINING.LOG_INTERVAL = None +_CN.TRAINING.EPOCHS = None +_CN.TRAINING.GRAD_CLIP = 0. # Indicates the L2 norm at which to clip the gradient. 
Disabled if 0 + +cfg = _CN \ No newline at end of file diff --git a/imcui/third_party/mickey/demo_inference.py b/imcui/third_party/mickey/demo_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..760c734efee56b0a7378878f9472027d5667e9be --- /dev/null +++ b/imcui/third_party/mickey/demo_inference.py @@ -0,0 +1,130 @@ +import torch +import argparse +from lib.models.builder import build_model +from lib.datasets.utils import correct_intrinsic_scale +from lib.models.MicKey.modules.utils.training_utils import colorize, generate_heat_map +from config.default import cfg +import numpy as np +from pathlib import Path +import cv2 + +def prepare_score_map(scs, img, temperature=0.5): + + score_map = generate_heat_map(scs, img, temperature) + + score_map = 255 * score_map.permute(1, 2, 0).numpy() + + return score_map + +def colorize_depth(value, vmin=None, vmax=None, cmap='magma_r', invalid_val=-99, invalid_mask=None, background_color=(0, 0, 0, 255), gamma_corrected=False, value_transform=None): + + img = colorize(value, vmin, vmax, cmap, invalid_val, invalid_mask, background_color, gamma_corrected, value_transform) + + shape_im = img.shape + img = np.asarray(img, np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA) + img = cv2.resize(img, (shape_im[1]*14, shape_im[0]*14), interpolation=cv2.INTER_LINEAR) + + return img + +def read_color_image(path, resize=(540, 720)): + """ + Args: + resize (tuple): align image to depthmap, in (w, h). + Returns: + image (torch.tensor): (3, h, w) + """ + # read and resize image + cv_type = cv2.IMREAD_COLOR + image = cv2.imread(str(path), cv_type) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if resize: + image = cv2.resize(image, resize) + + # (h, w, 3) -> (3, h, w) and normalized + image = torch.from_numpy(image).float().permute(2, 0, 1) / 255 + + return image.unsqueeze(0) + +def read_intrinsics(path_intrinsics, resize=None): + Ks = {} + with Path(path_intrinsics).open('r') as f: + for line in f.readlines(): + if '#' in line: + continue + + line = line.strip().split(' ') + img_name = line[0] + fx, fy, cx, cy, W, H = map(float, line[1:]) + + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + if resize is not None: + K = correct_intrinsic_scale(K, resize[0] / W, resize[1] / H) + Ks[img_name] = K + return Ks + +def run_demo_inference(args): + + # Select device + use_cuda = torch.cuda.is_available() + device = torch.device('cuda:0' if use_cuda else 'cpu') + + print('Preparing data...') + + # Prepare config file + cfg.merge_from_file(args.config) + + # Prepare the model + model = build_model(cfg, checkpoint=args.checkpoint) + + # Load demo images + im0 = read_color_image(args.im_path_ref).to(device) + im1 = read_color_image(args.im_path_dst).to(device) + + # Load intrinsics + K = read_intrinsics(args.intrinsics) + + # Prepare data for MicKey + data = {} + data['image0'] = im0 + data['image1'] = im1 + data['K_color0'] = torch.from_numpy(K['im0.jpg']).unsqueeze(0).to(device) + data['K_color1'] = torch.from_numpy(K['im1.jpg']).unsqueeze(0).to(device) + + # Run inference + print('Running MicKey relative pose estimation...') + model(data) + + # Pose, inliers and score are stored in: + # data['R'] = R + # data['t'] = t + # data['inliers'] = inliers + # data['inliers_list'] = inliers_list + + print('Saving depth and score maps in image directory ...') + depth0_map = colorize_depth(data['depth0_map'][0], invalid_mask=(data['depth0_map'][0] < 0.001).cpu()[0]) + depth1_map = colorize_depth(data['depth1_map'][0], 
invalid_mask=(data['depth1_map'][0] < 0.001).cpu()[0]) + score0_map = prepare_score_map(data['scr0'][0], data['image0'][0], temperature=0.5) + score1_map = prepare_score_map(data['scr1'][0], data['image1'][0], temperature=0.5) + + ext_im0 = args.im_path_ref.split('.')[-1] + ext_im1 = args.im_path_dst.split('.')[-1] + + cv2.imwrite(args.im_path_ref.replace(ext_im0, 'score.jpg'), score0_map) + cv2.imwrite(args.im_path_dst.replace(ext_im1, 'score.jpg'), score1_map) + + cv2.imwrite(args.im_path_ref.replace(ext_im0, 'depth.jpg'), depth0_map) + cv2.imwrite(args.im_path_dst.replace(ext_im1, 'depth.jpg'), depth1_map) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--im_path_ref', help='path to reference image', default='data/toy_example/im0.jpg') + parser.add_argument('--im_path_dst', help='path to destination image', default='data/toy_example/im1.jpg') + parser.add_argument('--intrinsics', help='path to intrinsics file', default='data/toy_example/intrinsics.txt') + parser.add_argument('--config', help='path to config file', default='weights/mickey_weights/config.yaml') + parser.add_argument('--checkpoint', help='path to model checkpoint', + default='weights/mickey_weights/mickey.ckpt') + args = parser.parse_args() + + run_demo_inference(args) + diff --git a/imcui/third_party/mickey/resources/environment.yml b/imcui/third_party/mickey/resources/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..6b0039dfc4dae504750d54f8539ffedf8c046242 --- /dev/null +++ b/imcui/third_party/mickey/resources/environment.yml @@ -0,0 +1,28 @@ +name: mickey +channels: + - conda-forge + - defaults +dependencies: + - python=3.8.17 + - pip=23.2.1 + - pip: + - einops==0.6.1 + - lazy-loader==0.3 + - lightning-utilities==0.9.0 + - matplotlib==3.7.2 + - numpy==1.24.4 + - omegaconf==2.3.0 + - open3d==0.17.0 + - opencv-python==4.8.0.74 + - protobuf==4.23.4 + - pytorch-lightning==2.0.6 + - tensorboard==2.13.0 + - tensorboard-data-server==0.7.1 + - timm==0.6.7 + - torch==2.0.1 + - torchmetrics==1.0.2 + - torchvision==0.15.2 + - tqdm==4.65.1 + - transforms3d==0.4.1 + - xformers==0.0.20 + - yacs==0.1.8 diff --git a/imcui/third_party/mickey/submission.py b/imcui/third_party/mickey/submission.py new file mode 100644 index 0000000000000000000000000000000000000000..56f1170dc039a8eb68d5cd10b3293cc0079b75a5 --- /dev/null +++ b/imcui/third_party/mickey/submission.py @@ -0,0 +1,107 @@ +import argparse +from pathlib import Path +from collections import defaultdict +from dataclasses import dataclass +from zipfile import ZipFile + +import torch +import numpy as np +from tqdm import tqdm + +from config.default import cfg +from lib.datasets.datamodules import DataModule +from lib.models.builder import build_model +from lib.utils.data import data_to_model_device +from transforms3d.quaternions import mat2quat + +@dataclass +class Pose: + image_name: str + q: np.ndarray + t: np.ndarray + inliers: float + + def __str__(self) -> str: + formatter = {'float': lambda v: f'{v:.6f}'} + max_line_width = 1000 + q_str = np.array2string(self.q, formatter=formatter, max_line_width=max_line_width)[1:-1] + t_str = np.array2string(self.t, formatter=formatter, max_line_width=max_line_width)[1:-1] + return f'{self.image_name} {q_str} {t_str} {self.inliers}' + + +def predict(loader, model): + results_dict = defaultdict(list) + + for data in tqdm(loader): + + # run inference + data = data_to_model_device(data, model) + with torch.no_grad(): + R_batched, t_batched = model(data) + + for i_batch 
in range(len(data['scene_id'])): + R = R_batched[i_batch].unsqueeze(0).detach().cpu().numpy() + t = t_batched[i_batch].reshape(-1).detach().cpu().numpy() + inliers = data['inliers'][i_batch].item() + + scene = data['scene_id'][i_batch] + query_img = data['pair_names'][1][i_batch] + + # ignore frames without poses (e.g. not enough feature matches) + if np.isnan(R).any() or np.isnan(t).any() or np.isinf(t).any(): + continue + + # populate results_dict + estimated_pose = Pose(image_name=query_img, + q=mat2quat(R).reshape(-1), + t=t.reshape(-1), + inliers=inliers) + results_dict[scene].append(estimated_pose) + + return results_dict + + +def save_submission(results_dict: dict, output_path: Path): + with ZipFile(output_path, 'w') as zip: + for scene, poses in results_dict.items(): + poses_str = '\n'.join((str(pose) for pose in poses)) + zip.writestr(f'pose_{scene}.txt', poses_str.encode('utf-8')) + + +def eval(args): + # Load configs + cfg.merge_from_file('config/datasets/mapfree.yaml') + cfg.merge_from_file(args.config) + + # Create dataloader + if args.split == 'test': + cfg.TRAINING.BATCH_SIZE = 8 + cfg.TRAINING.NUM_WORKERS = 8 + dataloader = DataModule(cfg, drop_last_val=False).test_dataloader() + elif args.split == 'val': + cfg.TRAINING.BATCH_SIZE = 16 + cfg.TRAINING.NUM_WORKERS = 8 + dataloader = DataModule(cfg, drop_last_val=False).val_dataloader() + else: + raise NotImplemented(f'Invalid split: {args.split}') + + # Create model + model = build_model(cfg, args.checkpoint) + + # Get predictions from model + results_dict = predict(dataloader, model) + + # Save predictions to txt per scene within zip + args.output_root.mkdir(parents=True, exist_ok=True) + save_submission(results_dict, args.output_root / 'submission.zip') + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', help='path to config file') + parser.add_argument('--checkpoint', + help='path to model checkpoint (models with learned parameters)', default='') + parser.add_argument('--output_root', '-o', type=Path, default=Path('results/')) + parser.add_argument('--split', choices=('val', 'test'), default='test', + help='Dataset split to use for evaluation. Choose from test or val. Default: test') + args = parser.parse_args() + eval(args) diff --git a/imcui/third_party/mickey/train.py b/imcui/third_party/mickey/train.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2452ef107d834b711da5ba90fc3782e9c36a88 --- /dev/null +++ b/imcui/third_party/mickey/train.py @@ -0,0 +1,91 @@ +import argparse +import os +# do this before importing numpy! (doing it right up here in case numpy is dependency of e.g. 
json) +os.environ["MKL_NUM_THREADS"] = "1" # noqa: E402 +os.environ["NUMEXPR_NUM_THREADS"] = "1" # noqa: E402 +os.environ["OMP_NUM_THREADS"] = "1" # noqa: E402 +os.environ["OPENBLAS_NUM_THREADS"] = "1" # noqa: E402 + +import pytorch_lightning as pl +import torch +from pytorch_lightning.loggers import TensorBoardLogger + +from config.default import cfg +from lib.datasets.datamodules import DataModuleTraining +from lib.models.MicKey.model import MicKeyTrainingModel +from lib.models.MicKey.modules.utils.training_utils import create_exp_name, create_result_dir +import random +import shutil + +def train_model(args): + + cfg.merge_from_file(args.dataset_config) + cfg.merge_from_file(args.config) + + exp_name = create_exp_name(args.experiment, cfg) + print('Start training of ' + exp_name) + + cfg.DATASET.SEED = random.randint(0, 1000000) + + model = MicKeyTrainingModel(cfg) + + checkpoint_vcre_callback = pl.callbacks.ModelCheckpoint( + filename='{epoch}-best_vcre', + save_last=True, + save_top_k=1, + verbose=True, + monitor='val_vcre/auc_vcre', + mode='max' + ) + + checkpoint_pose_callback = pl.callbacks.ModelCheckpoint( + filename='{epoch}-best_pose', + save_last=True, + save_top_k=1, + verbose=True, + monitor='val_AUC_pose/auc_pose', + mode='max' + ) + + epochend_callback = pl.callbacks.ModelCheckpoint( + filename='e{epoch}-last', + save_top_k=1, + every_n_epochs=1, + save_on_train_epoch_end=True + ) + + lr_monitoring_callback = pl.callbacks.LearningRateMonitor(logging_interval='step') + logger = TensorBoardLogger(save_dir=args.path_weights, name=exp_name) + + trainer = pl.Trainer(devices=cfg.TRAINING.NUM_GPUS, + log_every_n_steps=cfg.TRAINING.LOG_INTERVAL, + val_check_interval=cfg.TRAINING.VAL_INTERVAL, + limit_val_batches=cfg.TRAINING.VAL_BATCHES, + max_epochs=cfg.TRAINING.EPOCHS, + logger=logger, + callbacks=[checkpoint_pose_callback, lr_monitoring_callback, epochend_callback, checkpoint_vcre_callback], + num_sanity_val_steps=0, + gradient_clip_val=cfg.TRAINING.GRAD_CLIP) + + datamodule_end = DataModuleTraining(cfg) + print('Training with {:.2f}/{:.2f} image overlap'.format(cfg.DATASET.MIN_OVERLAP_SCORE, cfg.DATASET.MAX_OVERLAP_SCORE)) + + create_result_dir(logger.log_dir + '/config.yaml') + shutil.copyfile(args.config, logger.log_dir + '/config.yaml') + + if args.resume: + ckpt_path = args.resume + else: + ckpt_path = None + + trainer.fit(model, datamodule_end, ckpt_path=ckpt_path) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', help='path to config file', default='config/MicKey/curriculum_learning.yaml') + parser.add_argument('--dataset_config', help='path to dataset config file', default='config/datasets/mapfree.yaml') + parser.add_argument('--experiment', help='experiment name', default='MicKey_default') + parser.add_argument('--path_weights', help='path to the directory to save the weights', default='weights/') + parser.add_argument('--resume', help='resume from checkpoint path', default=None) + args = parser.parse_args() + train_model(args) \ No newline at end of file diff --git a/imcui/third_party/omniglue/__init__.py b/imcui/third_party/omniglue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e445b0a1160f3f06664873a39e892958b8d3511 --- /dev/null +++ b/imcui/third_party/omniglue/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""omniglue API.""" + +# A new PyPI release will be pushed every time `__version__` is increased. +# When changing this, also update the CHANGELOG.md. +__version__ = "0.1.0" diff --git a/imcui/third_party/omniglue/demo.py b/imcui/third_party/omniglue/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d763ee6c74e31c218cc3f0f8f97dc44d90c127 --- /dev/null +++ b/imcui/third_party/omniglue/demo.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Demo script for performing OmniGlue inference.""" + +import sys +import time +import matplotlib.pyplot as plt +import numpy as np +from src import omniglue +from src.omniglue import utils +from PIL import Image + + +def main(argv) -> None: + if len(argv) != 3: + print("error - usage: python demo.py ") + return + + # Load images. + print("> Loading images...") + image0 = np.array(Image.open(argv[1])) + image1 = np.array(Image.open(argv[2])) + + # Load models. + print("> Loading OmniGlue (and its submodules: SuperPoint & DINOv2)...") + start = time.time() + og = omniglue.OmniGlue( + og_export="./models/omniglue.onnx", + sp_export="./models/sp_v6.onnx", + dino_export="./models/dinov2_vitb14_pretrain.pth", + ) + print(f"> \tTook {time.time() - start} seconds.") + + # Perform inference. + print("> Finding matches...") + start = time.time() + match_kp0, match_kp1, match_confidences = og.FindMatches(image0, image1) + num_matches = match_kp0.shape[0] + print(f"> \tFound {num_matches} matches.") + print(f"> \tTook {time.time() - start} seconds.") + + # Filter by confidence (0.02). + print("> Filtering matches...") + match_threshold = 0.02 # Choose any value [0.0, 1.0). + keep_idx = [] + for i in range(match_kp0.shape[0]): + if match_confidences[i] > match_threshold: + keep_idx.append(i) + num_filtered_matches = len(keep_idx) + match_kp0 = match_kp0[keep_idx] + match_kp1 = match_kp1[keep_idx] + match_confidences = match_confidences[keep_idx] + print( + f"> \tFound {num_filtered_matches}/{num_matches} above threshold {match_threshold}" + ) + + # Visualize. 
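The confidence filtering loop above can equivalently be expressed as a single boolean mask; a minimal sketch (same strict `>` comparison and the same 0.02 default threshold):

```python
import numpy as np

def filter_matches_by_confidence(match_kp0, match_kp1, match_confidences, threshold=0.02):
    """Vectorized equivalent of the per-match filtering loop above."""
    keep = match_confidences > threshold  # boolean mask over all matches
    return match_kp0[keep], match_kp1[keep], match_confidences[keep]
```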
+ print("> Visualizing matches...") + viz = utils.visualize_matches( + image0, + image1, + match_kp0, + match_kp1, + np.eye(num_filtered_matches), + show_keypoints=True, + highlight_unmatched=True, + title=f"{num_filtered_matches} matches", + line_width=2, + ) + plt.figure(figsize=(20, 10), dpi=100, facecolor="w", edgecolor="k") + plt.axis("off") + plt.imshow(viz) + plt.imsave("./demo_output.png", viz) + print("> \tSaved visualization to ./demo_output.png") + + +if __name__ == "__main__": + main(sys.argv) diff --git a/imcui/third_party/omniglue/src/__init__.py b/imcui/third_party/omniglue/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/omniglue/src/omniglue/__init__.py b/imcui/third_party/omniglue/src/omniglue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce2972c9f07023e8809f1515ee489cc61d86a3e5 --- /dev/null +++ b/imcui/third_party/omniglue/src/omniglue/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import omniglue_extract + +OmniGlue = omniglue_extract.OmniGlue diff --git a/imcui/third_party/omniglue/src/omniglue/dino_extract.py b/imcui/third_party/omniglue/src/omniglue/dino_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..5dea2a6b42b93f9f20c4c81b52533b49cfab18fb --- /dev/null +++ b/imcui/third_party/omniglue/src/omniglue/dino_extract.py @@ -0,0 +1,210 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper for performing DINOv2 inference.""" + +import cv2 +import sys +import numpy as np +from pathlib import Path +sys.path.append(str(Path(__file__).parent.parent.parent / "third_party")) +from dinov2 import dino + +from . 
import utils +import torch + + +class DINOExtract: + """Class to initialize DINO model and extract features from an image.""" + + def __init__(self, cpt_path: str, feature_layer: int = 1): + self.feature_layer = feature_layer + self.model = dino.vit_base() + state_dict_raw = torch.load(cpt_path, map_location="cpu") + + # state_dict = {} + # for k, v in state_dict_raw.items(): + # state_dict[k.replace('blocks', 'blocks.0')] = v + + self.model.load_state_dict(state_dict_raw) + self.model.eval() + + self.image_size_max = 630 + + self.h_down_rate = self.model.patch_embed.patch_size[0] + self.w_down_rate = self.model.patch_embed.patch_size[1] + + def __call__(self, image) -> np.ndarray: + return self.forward(image) + + def forward(self, image: np.ndarray) -> np.ndarray: + """Feeds image through DINO ViT model to extract features. + + Args: + image: (H, W, 3) numpy array, decoded image bytes, value range [0, 255]. + + Returns: + features: (H // 14, W // 14, C) numpy array image features. + """ + image = self._resize_input_image(image) + image_processed = self._process_image(image) + image_processed = image_processed.unsqueeze(0).float() + features = self.extract_feature(image_processed) + features = features.squeeze(0).permute(1, 2, 0).cpu().numpy() + return features + + def _resize_input_image( + self, image: np.ndarray, interpolation=cv2.INTER_LINEAR + ): + """Resizes image such that both dimensions are divisble by down_rate.""" + h_image, w_image = image.shape[:2] + h_larger_flag = h_image > w_image + large_side_image = max(h_image, w_image) + + # resize the image with the largest side length smaller than a threshold + # to accelerate ViT backbone inference (which has quadratic complexity). + if large_side_image > self.image_size_max: + if h_larger_flag: + h_image_target = self.image_size_max + w_image_target = int(self.image_size_max * w_image / h_image) + else: + w_image_target = self.image_size_max + h_image_target = int(self.image_size_max * h_image / w_image) + else: + h_image_target = h_image + w_image_target = w_image + + h, w = ( + h_image_target // self.h_down_rate, + w_image_target // self.w_down_rate, + ) + h_resize, w_resize = h * self.h_down_rate, w * self.w_down_rate + image = cv2.resize( + image, (w_resize, h_resize), interpolation=interpolation + ) + return image + + def _process_image(self, image: np.ndarray) -> torch.Tensor: + """Turn image into pytorch tensor and normalize it.""" + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + + image_processed = image / 255.0 + image_processed = (image_processed - mean) / std + image_processed = torch.from_numpy(image_processed).permute(2, 0, 1) + return image_processed + + def extract_feature(self, image): + """Extracts features from image. + + Args: + image: (B, 3, H, W) torch tensor, normalized with ImageNet mean/std. + + Returns: + features: (B, C, H//14, W//14) torch tensor image features. 
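A minimal usage sketch of `DINOExtract` (the checkpoint path follows demo.py and the import path assumes the package layout above, so both may need adjusting):

```python
import numpy as np
from src.omniglue import dino_extract  # import style follows demo.py

extractor = dino_extract.DINOExtract("./models/dinov2_vitb14_pretrain.pth", feature_layer=1)
image = np.zeros((480, 640, 3), dtype=np.uint8)  # any RGB image with values in [0, 255]
features = extractor(image)
print(features.shape)  # (33, 45, 768): long side capped at 630 px, then divided by the 14 px patch size
```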
+ """ + b, _, h_origin, w_origin = image.shape + out = self.model.get_intermediate_layers(image, n=self.feature_layer)[0] + h = int(h_origin / self.h_down_rate) + w = int(w_origin / self.w_down_rate) + dim = out.shape[-1] + out = out.reshape(b, h, w, dim).permute(0, 3, 1, 2).detach() + return out + + +def _preprocess_shape( + h_image, w_image, image_size_max=630, h_down_rate=14, w_down_rate=14 +): + h_image = h_image.squeeze() + w_image = w_image.squeeze() + + h_larger_flag = h_image > w_image + large_side_image = max(h_image, w_image) + + def resize_h_larger(): + h_image_target = image_size_max + w_image_target = int(image_size_max * w_image / h_image) + return h_image_target, w_image_target + + def resize_w_larger_or_equal(): + w_image_target = image_size_max + h_image_target = int(image_size_max * h_image / w_image) + return h_image_target, w_image_target + + def keep_original(): + return h_image, w_image + + if large_side_image > image_size_max: + if h_larger_flag: + h_image_target, w_image_target = resize_h_larger() + else: + h_image_target, w_image_target = resize_w_larger_or_equal() + else: + h_image_target, w_image_target = keep_original() + + h = h_image_target // h_down_rate + w = w_image_target // w_down_rate + h_resize = torch.tensor(h * h_down_rate) + w_resize = torch.tensor(w * w_down_rate) + + h_resize = h_resize.unsqueeze(0) + w_resize = w_resize.unsqueeze(0) + + return h_resize, w_resize + + +def get_dino_descriptors(dino_features, keypoints, height, width, feature_dim): + """Get DINO descriptors using Superpoint keypoints. + + Args: + dino_features: DINO features in 1-D. + keypoints: Superpoint keypoint locations, in format (x, y), in pixels, shape + (N, 2). + height: image height, type torch int32. + width: image width, type torch int32. + feature_dim: DINO feature channel size, type torch int32. + + Returns: + Interpolated DINO descriptors. + """ + height_1d = height.reshape([1]) + width_1d = width.reshape([1]) + + height_1d_resized, width_1d_resized = _preprocess_shape( + height_1d, width_1d, image_size_max=630, h_down_rate=14, w_down_rate=14 + ) + + height_feat = height_1d_resized // 14 + width_feat = width_1d_resized // 14 + feature_dim_1d = torch.tensor(feature_dim).reshape([1]) + + dino_features = dino_features.reshape( + height_feat, width_feat, feature_dim_1d + ) + + img_size = torch.cat([width_1d, height_1d], dim=0).float() + feature_size = torch.cat([width_feat, height_feat], dim=0).float() + keypoints_feature = ( + keypoints[0] / img_size.unsqueeze(0) * feature_size.unsqueeze(0) + ) + + dino_descriptors = [] + for kp in keypoints_feature: + dino_descriptors.append( + utils.lookup_descriptor_bilinear(kp.numpy(), dino_features) + ) + dino_descriptors = torch.tensor( + np.array(dino_descriptors), dtype=torch.float32 + ) + return dino_descriptors diff --git a/imcui/third_party/omniglue/src/omniglue/omniglue_extract.py b/imcui/third_party/omniglue/src/omniglue/omniglue_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..e7dd6cfd6e18cf045e78c4c31ee834617247a76f --- /dev/null +++ b/imcui/third_party/omniglue/src/omniglue/omniglue_extract.py @@ -0,0 +1,183 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper for performing OmniGlue inference, plus (optionally) SP/DINO.""" +import cv2 +import torch +import numpy as np +import onnxruntime + +from . import dino_extract +from . import superpoint_extract +from . import utils + + +DINO_FEATURE_DIM = 768 +MATCH_THRESHOLD = 1e-3 + + +class OmniGlue: + # TODO(omniglue): class docstring + + def __init__( + self, + og_export: str, + sp_export: str = None, + dino_export: str = None, + max_keypoints: int = 1024, + ) -> None: + self.max_keypoints = max_keypoints + self.matcher = onnxruntime.InferenceSession(og_export) + if sp_export is not None: + self.sp_extract = superpoint_extract.SuperPointExtract(sp_export) + if dino_export is not None: + self.dino_extract = dino_extract.DINOExtract( + dino_export, feature_layer=1 + ) + + def FindMatches( + self, + image0: np.ndarray, + image1: np.ndarray, + max_keypoints: int = 1024, + ): + """TODO(omniglue): docstring.""" + height0, width0 = image0.shape[:2] + height1, width1 = image1.shape[:2] + # TODO: numpy to torch inputs + sp_features0 = self.sp_extract(image0, num_features=max_keypoints) + sp_features1 = self.sp_extract(image1, num_features=max_keypoints) + dino_features0 = self.dino_extract(image0) + dino_features1 = self.dino_extract(image1) + dino_descriptors0 = dino_extract.get_dino_descriptors( + dino_features0, + sp_features0, + torch.tensor(height0), + torch.tensor(width0), + DINO_FEATURE_DIM, + ) + dino_descriptors1 = dino_extract.get_dino_descriptors( + dino_features1, + sp_features1, + torch.tensor(height1), + torch.tensor(width1), + DINO_FEATURE_DIM, + ) + + inputs = self._construct_inputs( + width0, + height0, + width1, + height1, + sp_features0, + sp_features1, + dino_descriptors0, + dino_descriptors1, + ) + + og_outputs = self.matcher.run(None, inputs) + soft_assignment = torch.from_numpy(og_outputs[0][:, :-1, :-1]) + + match_matrix = ( + utils.soft_assignment_to_match_matrix( + soft_assignment, MATCH_THRESHOLD + ) + .numpy() + .squeeze() + ) + + # Filter out any matches with 0.0 confidence keypoints. + match_indices = np.argwhere(match_matrix) + keep = [] + for i in range(match_indices.shape[0]): + match = match_indices[i, :] + if (sp_features0[2][match[0]] > 0.0) and ( + sp_features1[2][match[1]] > 0.0 + ): + keep.append(i) + match_indices = match_indices[keep] + + # Format matches in terms of keypoint locations. 
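Before building the feed dict, the exported matcher's expected input and output names can be inspected directly with ONNX Runtime (model path as in demo.py; shown only as a quick sanity check):

```python
import onnxruntime

session = onnxruntime.InferenceSession("./models/omniglue.onnx")
for node in session.get_inputs():
    print("input ", node.name, node.shape, node.type)
for node in session.get_outputs():
    print("output", node.name, node.shape, node.type)
```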
+ match_kp0s = [] + match_kp1s = [] + match_confidences = [] + for match in match_indices: + match_kp0s.append(sp_features0[0][match[0], :]) + match_kp1s.append(sp_features1[0][match[1], :]) + match_confidences.append(soft_assignment[0, match[0], match[1]]) + match_kp0s = np.array(match_kp0s) + match_kp1s = np.array(match_kp1s) + match_confidences = np.array(match_confidences) + return match_kp0s, match_kp1s, match_confidences + + ### Private methods ### + + def _construct_inputs( + self, + width0, + height0, + width1, + height1, + sp_features0, + sp_features1, + dino_descriptors0, + dino_descriptors1, + ): + keypoints0 = sp_features0[0] + keypoints1 = sp_features1[0] + descriptors0 = sp_features0[1] + descriptors1 = sp_features1[1] + scores0 = sp_features0[2] + scores1 = sp_features1[2] + descriptors0_dino = dino_descriptors0 + descriptors1_dino = dino_descriptors1 + if isinstance(keypoints0, torch.Tensor): + keypoints0 = keypoints0.detach().numpy() + if isinstance(keypoints1, torch.Tensor): + keypoints1 = keypoints1.detach().numpy() + if isinstance(descriptors0, torch.Tensor): + descriptors0 = descriptors0.detach().numpy() + if isinstance(descriptors1, torch.Tensor): + descriptors1 = descriptors1.detach().numpy() + if isinstance(scores0, torch.Tensor): + scores0 = scores0.detach().numpy() + if isinstance(scores1, torch.Tensor): + scores1 = scores1.detach().numpy() + if isinstance(descriptors0_dino, torch.Tensor): + descriptors0_dino = descriptors0_dino.detach().numpy() + if isinstance(descriptors1_dino, torch.Tensor): + descriptors1_dino = descriptors1_dino.detach().numpy() + inputs = { + "keypoints0": np.expand_dims(keypoints0, axis=0).astype(np.float32), + "keypoints1": np.expand_dims(keypoints1, axis=0).astype(np.float32), + "descriptors0": np.expand_dims(descriptors0, axis=0).astype( + np.float32 + ), + "descriptors1": np.expand_dims(descriptors1, axis=0).astype( + np.float32 + ), + "scores0": np.expand_dims( + np.expand_dims(scores0, axis=0), axis=-1 + ).astype(np.float32), + "scores1": np.expand_dims( + np.expand_dims(scores1, axis=0), axis=-1 + ).astype(np.float32), + "descriptors0_dino": np.expand_dims(descriptors0_dino, axis=0), + "descriptors1_dino": np.expand_dims(descriptors1_dino, axis=0), + "width0": np.expand_dims(width0, axis=0).astype(np.int32), + "width1": np.expand_dims(width1, axis=0).astype(np.int32), + "height0": np.expand_dims(height0, axis=0).astype(np.int32), + "height1": np.expand_dims(height1, axis=0).astype(np.int32), + } + return inputs diff --git a/imcui/third_party/omniglue/src/omniglue/superpoint_extract.py b/imcui/third_party/omniglue/src/omniglue/superpoint_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..1554de303f8b95c66150da537bc6fcf3aba58ee1 --- /dev/null +++ b/imcui/third_party/omniglue/src/omniglue/superpoint_extract.py @@ -0,0 +1,212 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
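For reference, the feed dict built by `_construct_inputs` above ends up with the following shapes for N keypoints per image (256-D SuperPoint descriptors, 768-D DINOv2 descriptors); the dummy values below only illustrate the shape contract:

```python
import numpy as np

N = 1024
dummy_inputs = {
    "keypoints0": np.zeros((1, N, 2), np.float32),
    "keypoints1": np.zeros((1, N, 2), np.float32),
    "descriptors0": np.zeros((1, N, 256), np.float32),
    "descriptors1": np.zeros((1, N, 256), np.float32),
    "scores0": np.zeros((1, N, 1), np.float32),
    "scores1": np.zeros((1, N, 1), np.float32),
    "descriptors0_dino": np.zeros((1, N, 768), np.float32),
    "descriptors1_dino": np.zeros((1, N, 768), np.float32),
    "width0": np.array([640], np.int32),
    "width1": np.array([640], np.int32),
    "height0": np.array([480], np.int32),
    "height1": np.array([480], np.int32),
}
```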
+ +"""Wrapper for performing SuperPoint inference.""" + +import math +from typing import Optional, Tuple + +import cv2 +import numpy as np +from . import utils +import onnxruntime + + +class SuperPointExtract: + """Class to initialize SuperPoint model and extract features from an image. + + To stay consistent with SuperPoint training and eval configurations, resize + images to (320x240) or (640x480). + + Attributes + model_path: string, filepath to saved SuperPoint ONNX model weights. + """ + + def __init__(self, model_path: str): + self.model_path = model_path + self.net = onnxruntime.InferenceSession(self.model_path) + + def __call__( + self, + image, + segmentation_mask=None, + num_features=1024, + pad_random_features=False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + return self.compute( + image, + segmentation_mask=segmentation_mask, + num_features=num_features, + pad_random_features=pad_random_features, + ) + + def compute( + self, + image: np.ndarray, + segmentation_mask: Optional[np.ndarray] = None, + num_features: int = 1024, + pad_random_features: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Feeds image through SuperPoint model to extract keypoints and features. + + Args: + image: (H, W, 3) numpy array, decoded image bytes. + segmentation_mask: (H, W) binary numpy array or None. If not None, + extracted keypoints are restricted to being within the mask. + num_features: max number of features to extract (or 0 to indicate keeping + all extracted features). + pad_random_features: if True, adds randomly sampled keypoints to the + output such that there are exactly 'num_features' keypoints. Descriptors + for these sampled keypoints are taken from the network's descriptor map + output, and scores are set to 0. No action taken if num_features = 0. + + Returns: + keypoints: (N, 2) numpy array, coordinates of keypoints as floats. + descriptors: (N, 256) numpy array, descriptors for keypoints as floats. + scores: (N, 1) numpy array, confidence values for keypoints as floats. + """ + + # Resize image so both dimensions are divisible by 8. + image, keypoint_scale_factors = self._resize_input_image(image) + if segmentation_mask is not None: + segmentation_mask, _ = self._resize_input_image( + segmentation_mask, interpolation=cv2.INTER_NEAREST + ) + assert ( + segmentation_mask is None + or image.shape[:2] == segmentation_mask.shape[:2] + ) + + # Preprocess and feed-forward image. + image_preprocessed = self._preprocess_image(image) + out = self.net.run( + None, + { + self.net.get_inputs()[0].name: np.expand_dims( + image_preprocessed, 0 + ) + }, + ) + # Format output from network. + keypoint_map = np.squeeze(out[5]) + descriptor_map = np.squeeze(out[0]) + if segmentation_mask is not None: + keypoint_map = np.where(segmentation_mask, keypoint_map, 0.0) + keypoints, descriptors, scores = self._extract_superpoint_output( + keypoint_map, descriptor_map, num_features, pad_random_features + ) + + # Rescale keypoint locations to match original input image size, and return. + keypoints = keypoints / keypoint_scale_factors + return (keypoints, descriptors, scores) + + def _resize_input_image(self, image, interpolation=cv2.INTER_LINEAR): + """Resizes image such that both dimensions are divisble by 8.""" + + # Calculate new image dimensions and per-dimension resizing scale factor. 
+ new_dim = [-1, -1] + keypoint_scale_factors = [1.0, 1.0] + for i in range(2): + dim_size = image.shape[i] + mod_eight = dim_size % 8 + if mod_eight < 4: + # Round down to nearest multiple of 8. + new_dim[i] = dim_size - mod_eight + elif mod_eight >= 4: + # Round up to nearest multiple of 8. + new_dim[i] = dim_size + (8 - mod_eight) + keypoint_scale_factors[i] = (new_dim[i] - 1) / (dim_size - 1) + + # Resize and return image + scale factors. + new_dim = new_dim[::-1] # Convert from (row, col) to (x,y). + keypoint_scale_factors = keypoint_scale_factors[::-1] + image = cv2.resize(image, tuple(new_dim), interpolation=interpolation) + return image, keypoint_scale_factors + + def _preprocess_image(self, image): + """Converts image to grayscale and normalizes values for model input.""" + image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + image = np.expand_dims(image, 2) + image = image.astype(np.float32) + image = image / 255.0 + return image + + def _extract_superpoint_output( + self, + keypoint_map, + descriptor_map, + keep_k_points=512, + pad_random_features=False, + ): + """Converts from raw SuperPoint output (feature maps) into numpy arrays. + + If keep_k_points is 0, then keep all detected keypoints. Otherwise, sort by + confidence and keep only the top k confidence keypoints. + + Args: + keypoint_map: (H, W, 1) numpy array, raw output confidence values from + SuperPoint model. + descriptor_map: (H, W, 256) numpy array, raw output descriptors from + SuperPoint model. + keep_k_points: int, number of keypoints to keep (or 0 to indicate keeping + all detected keypoints). + pad_random_features: if True, adds randomly sampled keypoints to the + output such that there are exactly 'num_features' keypoints. Descriptors + for these sampled keypoints are taken from the network's descriptor map + output, and scores are set to 0. No action taken if keep_k_points = 0. + + Returns: + keypoints: (N, 2) numpy array, image coordinates (x, y) of keypoints as + floats. + descriptors: (N, 256) numpy array, descriptors for keypoints as floats. + scores: (N, 1) numpy array, confidence values for keypoints as floats. + """ + + def _select_k_best(points, k): + sorted_prob = points[points[:, 2].argsort(), :] + start = min(k, points.shape[0]) + return sorted_prob[-start:, :2], sorted_prob[-start:, 2] + + keypoints = np.where(keypoint_map > 0) + prob = keypoint_map[keypoints[0], keypoints[1]] + keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1) + + # Keep only top k points, or all points if keep_k_points param is 0. + if keep_k_points == 0: + keep_k_points = keypoints.shape[0] + keypoints, scores = _select_k_best(keypoints, keep_k_points) + + # Optionally, pad with random features (and confidence scores of 0). + image_shape = np.array(keypoint_map.shape[:2]) + if pad_random_features and (keep_k_points > keypoints.shape[0]): + num_pad = keep_k_points - keypoints.shape[0] + keypoints_pad = (image_shape - 1) * np.random.uniform( + size=(num_pad, 2) + ) + keypoints = np.concatenate((keypoints, keypoints_pad)) + scores_pad = np.zeros((num_pad)) + scores = np.concatenate((scores, scores_pad)) + + # Lookup descriptors via bilinear interpolation. + # TODO: batch descriptor lookup with bilinear interpolation. + keypoints[:, [0, 1]] = keypoints[ + :, [1, 0] + ] # Swap from (row,col) to (x,y). 
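The resizing above rounds each image dimension to the nearest multiple of 8 (down when the remainder is below 4, up otherwise) and rescales keypoints accordingly; a standalone sketch of that rounding:

```python
def nearest_multiple_of_8(dim_size: int) -> int:
    mod_eight = dim_size % 8
    return dim_size - mod_eight if mod_eight < 4 else dim_size + (8 - mod_eight)

for size in (480, 482, 485, 639):
    new_size = nearest_multiple_of_8(size)
    scale = (new_size - 1) / (size - 1)  # factor later divided out of the keypoint coordinates
    print(size, "->", new_size, "scale", round(scale, 4))
```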
+ descriptors = [] + for kp in keypoints: + descriptors.append( + utils.lookup_descriptor_bilinear(kp, descriptor_map) + ) + descriptors = np.array(descriptors) + return keypoints, descriptors, scores diff --git a/imcui/third_party/omniglue/src/omniglue/utils.py b/imcui/third_party/omniglue/src/omniglue/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ecb40b9d9f5dfc4f3fe9c83efd801495cc7029 --- /dev/null +++ b/imcui/third_party/omniglue/src/omniglue/utils.py @@ -0,0 +1,282 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utility functions for OmniGlue.""" +import cv2 +import torch +import math +import numpy as np +from typing import Optional + + +def lookup_descriptor_bilinear( + keypoint: np.ndarray, descriptor_map: np.ndarray +) -> np.ndarray: + """Looks up descriptor value for keypoint from a dense descriptor map. + + Uses bilinear interpolation to find descriptor value at non-integer + positions. + + Args: + keypoint: 2-dim numpy array containing (x, y) keypoint image coordinates. + descriptor_map: (H, W, D) numpy array representing a dense descriptor map. + + Returns: + D-dim descriptor value at the input 'keypoint' location. + + Raises: + ValueError, if kepoint position is out of bounds. + """ + height, width = descriptor_map.shape[:2] + if ( + keypoint[0] < 0 + or keypoint[0] > width + or keypoint[1] < 0 + or keypoint[1] > height + ): + raise ValueError( + "Keypoint position (%f, %f) is out of descriptor map bounds (%i w x" + " %i h)." % (keypoint[0], keypoint[1], width, height) + ) + + x_range = [math.floor(keypoint[0])] + if not keypoint[0].is_integer() and keypoint[0] < width - 1: + x_range.append(x_range[0] + 1) + y_range = [math.floor(keypoint[1])] + if not keypoint[1].is_integer() and keypoint[1] < height - 1: + y_range.append(y_range[0] + 1) + + bilinear_descriptor = np.zeros(descriptor_map.shape[2]) + for curr_x in x_range: + for curr_y in y_range: + curr_descriptor = descriptor_map[curr_y, curr_x, :] + bilinear_scalar = (1.0 - abs(keypoint[0] - curr_x)) * ( + 1.0 - abs(keypoint[1] - curr_y) + ) + bilinear_descriptor += bilinear_scalar * curr_descriptor + return bilinear_descriptor + + +def soft_assignment_to_match_matrix( + soft_assignment: torch.Tensor, match_threshold: float +) -> torch.Tensor: + """Converts a matrix of soft assignment values to binary yes/no match matrix. + + Searches soft_assignment for row- and column-maximum values, which indicate + mutual nearest neighbor matches between two unique sets of keypoints. Also, + ensures that score values for matches are above the specified threshold. + + Args: + soft_assignment: (B, N, M) tensor, contains matching likelihood value + between features of different sets. N is number of features in image0, and + M is number of features in image1. Higher value indicates more likely to + match. + match_threshold: float, thresholding value to consider a match valid. + + Returns: + (B, N, M) tensor of binary values. 
A value of 1 at index (x, y) indicates + a match between index 'x' (out of N) in image0 and index 'y' (out of M) in + image 1. + """ + + def _range_like(x, dim): + return torch.arange(x.shape[dim], dtype=x.dtype) + + matches = [] + for i in range(soft_assignment.shape[0]): + scores = soft_assignment[i, :].unsqueeze(0) + + max0 = torch.max(scores, dim=2)[0] + indices0 = torch.argmax(scores, dim=2) + indices1 = torch.argmax(scores, dim=1) + + mutual = _range_like(indices0, 1).unsqueeze(0) == indices1.gather( + 1, indices0 + ) + + kp_ind_pairs = torch.stack( + [_range_like(indices0, 1), indices0.squeeze()], dim=1 + ) + mutual_max0 = torch.where( + mutual, max0, torch.zeros_like(max0) + ).squeeze() + sparse = torch.sparse_coo_tensor( + kp_ind_pairs.t(), mutual_max0, scores.shape[1:] + ) + match_matrix = sparse.to_dense() + matches.append(match_matrix) + + match_matrix = torch.stack(matches) + match_matrix = match_matrix > match_threshold + return match_matrix + + +def visualize_matches( + image0: np.ndarray, + image1: np.ndarray, + kp0: np.ndarray, + kp1: np.ndarray, + match_matrix: np.ndarray, + match_labels: Optional[np.ndarray] = None, + show_keypoints: bool = False, + highlight_unmatched: bool = False, + title: Optional[str] = None, + line_width: int = 1, + circle_radius: int = 4, + circle_thickness: int = 2, + rng: Optional["np.random.Generator"] = None, +): + """Generates visualization of keypoints and matches for two images. + + Stacks image0 and image1 horizontally. In case the two images have different + heights, scales image1 (and its keypoints) to match image0's height. Note + that keypoints must be in (x, y) format, NOT (row, col). If match_matrix + includes unmatched dustbins, the dustbins will be removed before visualizing + matches. + + Args: + image0: (H, W, 3) array containing image0 contents. + image1: (H, W, 3) array containing image1 contents. + kp0: (N, 2) array where each row represents (x, y) coordinates of keypoints + in image0. + kp1: (M, 2) array, where each row represents (x, y) coordinates of keypoints + in image1. + match_matrix: (N, M) binary array, where values are non-zero for keypoint + indices making up a match. + match_labels: (N, M) binary array, where values are non-zero for keypoint + indices making up a ground-truth match. When None, matches from + 'match_matrix' are colored randomly. Otherwise, matches from + 'match_matrix' are colored according to accuracy (compared to labels). + show_keypoints: if True, all image0 and image1 keypoints (including + unmatched ones) are visualized. + highlight_unmatched: if True, highlights unmatched keypoints in blue. + title: if not None, adds title text to top left of visualization. + line_width: width of correspondence line, in pixels. + circle_radius: radius of keypoint circles, if visualized. + circle_thickness: thickness of keypoint circles, if visualized. + rng: np random number generator to generate the line colors. + + Returns: + Numpy array of image0 and image1 side-by-side, with lines between matches + according to match_matrix. If show_keypoints is True, keypoints from both + images are also visualized. + """ + # initialize RNG + if rng is None: + rng = np.random.default_rng() + + # Make copy of input param that may be modified in this function. + kp1 = np.copy(kp1) + + # Detect unmatched dustbins. + has_unmatched_dustbins = (match_matrix.shape[0] == kp0.shape[0] + 1) and ( + match_matrix.shape[1] == kp1.shape[0] + 1 + ) + + # If necessary, resize image1 so that the pair can be stacked horizontally. 
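A toy check of `soft_assignment_to_match_matrix` above: only mutual nearest neighbours whose score clears the threshold become matches (import path as used in demo.py; adjust to your layout):

```python
import torch
from src.omniglue import utils

soft_assignment = torch.tensor([[[0.9, 0.1],
                                 [0.2, 0.8],
                                 [0.3, 0.4]]])  # (B=1, N=3, M=2)
match_matrix = utils.soft_assignment_to_match_matrix(soft_assignment, match_threshold=1e-3)
print(match_matrix.squeeze(0).nonzero())  # expected: keypoint pairs (0, 0) and (1, 1)
```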
+ height0 = image0.shape[0] + height1 = image1.shape[0] + if height0 != height1: + scale_factor = height0 / height1 + if scale_factor <= 1.0: + interp_method = cv2.INTER_AREA + else: + interp_method = cv2.INTER_LINEAR + new_dim1 = (int(image1.shape[1] * scale_factor), height0) + image1 = cv2.resize(image1, new_dim1, interpolation=interp_method) + kp1 *= scale_factor + + # Create side-by-side image and add lines for all matches. + viz = cv2.hconcat([image0, image1]) + w0 = image0.shape[1] + matches = np.argwhere( + match_matrix[:-1, :-1] if has_unmatched_dustbins else match_matrix + ) + for match in matches: + mpt0 = kp0[match[0]] + mpt1 = kp1[match[1]] + if isinstance(mpt0, torch.Tensor): + mpt0 = mpt0.numpy() + if isinstance(mpt1, torch.Tensor): + mpt1 = mpt1.numpy() + pt0 = (int(mpt0[0]), int(mpt0[1])) + pt1 = (int(mpt1[0] + w0), int(mpt1[1])) + if match_labels is None: + color = tuple(rng.integers(0, 255, size=3).tolist()) + else: + if match_labels[match[0], match[1]]: + color = (0, 255, 0) + else: + color = (255, 0, 0) + cv2.line(viz, pt0, pt1, color, line_width) + + # Optionally, add circles to output image to represent each keypoint. + if show_keypoints: + for i in range(np.shape(kp0)[0]): + kp = kp0[i].numpy() if isinstance(kp0[i], torch.Tensor) else kp0[i] + if ( + highlight_unmatched + and has_unmatched_dustbins + and match_matrix[i, -1] + ): + cv2.circle( + viz, + tuple(kp.astype(np.int32).tolist()), + circle_radius, + (255, 0, 0), + circle_thickness, + ) + else: + cv2.circle( + viz, + tuple(kp.astype(np.int32).tolist()), + circle_radius, + (0, 0, 255), + circle_thickness, + ) + for j in range(np.shape(kp1)[0]): + kp = kp1[j].numpy() if isinstance(kp1[j], torch.Tensor) else kp1[j] + kp[0] += w0 + if ( + highlight_unmatched + and has_unmatched_dustbins + and match_matrix[-1, j] + ): + cv2.circle( + viz, + tuple(kp.astype(np.int32).tolist()), + circle_radius, + (255, 0, 0), + circle_thickness, + ) + else: + cv2.circle( + viz, + tuple(kp.astype(np.int32).tolist()), + circle_radius, + (0, 0, 255), + circle_thickness, + ) + if title is not None: + viz = cv2.putText( + viz, + title, + (5, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 0, 255), + 2, + cv2.LINE_AA, + ) + return viz diff --git a/imcui/third_party/omniglue/third_party/dinov2/__init__.py b/imcui/third_party/omniglue/third_party/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/omniglue/third_party/dinov2/dino.py b/imcui/third_party/omniglue/third_party/dinov2/dino.py new file mode 100644 index 0000000000000000000000000000000000000000..793d630c4d61a9c30135139bdc144fca389d9e3a --- /dev/null +++ b/imcui/third_party/omniglue/third_party/dinov2/dino.py @@ -0,0 +1,411 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +from typing import Callable, Sequence, Tuple, Union + +from . 
import dino_utils +import torch +from torch import nn +from torch.nn.init import trunc_normal_ +import torch.utils.checkpoint + + +def named_apply( + fn: Callable, + module: nn.Module, + name="", + depth_first=True, + include_root=False, +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + + def __init__( + self, + img_size=518, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=dino_utils.PatchEmbed, + act_layer=nn.GELU, + block_fn=dino_utils.Block, + ffn_layer="mlp", + block_chunks=0, + ): + """Args: + + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for + FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if ffn_layer == "mlp": + ffn_layer = dino_utils.Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + # ffn_layer = SwiGLUFFNFused + raise NotImplementedError("FFN only support mlp but using swiglu") + elif ffn_layer == "identity": + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + 
act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append( + [nn.Identity()] * i + blocks_list[i : i + chunksize] + ) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape( + 1, int(math.sqrt(N)), int(math.sqrt(N)), dim + ).permute(0, 3, 1, 2), + size=None, + scale_factor=[w0 / math.sqrt(N), h0 / math.sqrt(N)], + mode="bicubic", + ) + + assert ( + int(w0) == patch_pos_embed.shape[-2] + and int(h0) == patch_pos_embed.shape[-1] + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where( + masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x + ) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [ + self.prepare_tokens_with_masks(x, masks) + for x, masks in zip(x_list, masks_list) + ] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append({ + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + }) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. 
If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len( + blocks_to_take + ), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len( + blocks_to_take + ), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.get_intermediate_layers( + x, n=1, reshape=True, return_class_token=False, norm=True + )[0] + + # def forward(self, *args, is_training=False, **kwargs): + # ret = self.forward_features(*args, **kwargs) + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=14, **kwargs): + model = DinoVisionTransformer( + img_size=518, + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + init_values=1e-5, + block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=14, **kwargs): + model = DinoVisionTransformer( + img_size=518, + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + init_values=1e-5, + block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=14, **kwargs): + model = DinoVisionTransformer( + img_size=518, + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + init_values=1e-5, + block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention), + **kwargs, + ) + return model + + +def vit_giant2(patch_size=14, **kwargs): + """Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64""" + model = 
DinoVisionTransformer( + img_size=518, + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + init_values=1e-5, + block_fn=partial(dino_utils.Block, attn_class=dino_utils.MemEffAttention), + **kwargs, + ) + return model diff --git a/imcui/third_party/omniglue/third_party/dinov2/dino_utils.py b/imcui/third_party/omniglue/third_party/dinov2/dino_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71768b27c45ac82cd3162184c60fa6e4dd7fd0eb --- /dev/null +++ b/imcui/third_party/omniglue/third_party/dinov2/dino_utils.py @@ -0,0 +1,341 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. +# +# References: +# https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/eval/segmentation_m2f/models/backbones/vit.py + +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import nn + + +class Mlp(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
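A quick shape check of `PatchEmbed` under the DINOv2 defaults (the import assumes the `third_party` directory is on `sys.path`, as `dino_extract.py` arranges):

```python
import torch
from dinov2 import dino_utils

embed = dino_utils.PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=768)
tokens = embed(torch.zeros(1, 3, 518, 518))
print(tokens.shape)  # torch.Size([1, 1369, 768]): (518 / 14) ** 2 = 37 * 37 patches
```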
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert ( + H % patch_H == 0 + ), f"Input image height {H} is not a multiple of patch height {patch_H}" + assert ( + W % patch_W == 0 + ), f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + else: + raise NotImplementedError("MemEffAttention do not support xFormer") + # B, N, C = x.shape + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + # q, k, v = unbind(qkv, 2) + + # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + # x = x.reshape([B, N, C]) + + # x = self.proj(x) + # x = self.proj_drop(x) + # return x + + +class LayerScale(nn.Module): + + def __init__( + self, + dim: int, + init_values: Union[float, torch.Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: 
torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Block(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) + if init_values + else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) + if init_values + else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def attn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: torch.Tensor, + residual_func: Callable[[torch.Tensor], torch.Tensor], + sample_drop_ratio: float = 0.0, +) -> torch.Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = 
residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) diff --git a/imcui/third_party/pram/colmap_utils/camera_intrinsics.py b/imcui/third_party/pram/colmap_utils/camera_intrinsics.py new file mode 100644 index 0000000000000000000000000000000000000000..41bdc5055dfb451fa1f4dac3f27931675b68333f --- /dev/null +++ b/imcui/third_party/pram/colmap_utils/camera_intrinsics.py @@ -0,0 +1,30 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File localizer -> camera_intrinsics +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 15/08/2023 12:33 +==================================================''' +import numpy as np + + +def intrinsics_from_camera(camera_model, params): + if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = params[0] + cx = params[1] + cy = params[2] + elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = params[0] + fy = params[1] + cx = params[2] + cy = params[3] + else: + raise Exception("Camera model not supported") + + # intrinsics + K = np.identity(3) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + return K diff --git a/imcui/third_party/pram/colmap_utils/database.py b/imcui/third_party/pram/colmap_utils/database.py new file mode 100644 index 0000000000000000000000000000000000000000..37638347834f4b0b1432846adf9a83693b509a7f --- /dev/null +++ b/imcui/third_party/pram/colmap_utils/database.py @@ -0,0 +1,352 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +# This script is based on an original implementation by True Price. 
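Hypothetical usage of `intrinsics_from_camera` above: a COLMAP `SIMPLE_RADIAL` camera stores `(f, cx, cy, k)`, so only the first three parameters enter K:

```python
from colmap_utils.camera_intrinsics import intrinsics_from_camera  # adjust import to your layout

K = intrinsics_from_camera("SIMPLE_RADIAL", params=[1200.0, 320.0, 240.0, 0.01])
# K == [[1200., 0., 320.],
#       [0., 1200., 240.],
#       [0., 0., 1.]]
```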
+ +import sys +import sqlite3 +import numpy as np + + +IS_PYTHON3 = sys.version_info[0] >= 3 + +MAX_IMAGE_ID = 2**31 - 1 + +CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( + camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + model INTEGER NOT NULL, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + params BLOB, + prior_focal_length INTEGER NOT NULL)""" + +CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" + +CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( + image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + name TEXT NOT NULL UNIQUE, + camera_id INTEGER NOT NULL, + prior_qw REAL, + prior_qx REAL, + prior_qy REAL, + prior_qz REAL, + prior_tx REAL, + prior_ty REAL, + prior_tz REAL, + CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}), + FOREIGN KEY(camera_id) REFERENCES cameras(camera_id)) +""".format(MAX_IMAGE_ID) + +CREATE_TWO_VIEW_GEOMETRIES_TABLE = """ +CREATE TABLE IF NOT EXISTS two_view_geometries ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + config INTEGER NOT NULL, + F BLOB, + E BLOB, + H BLOB) +""" + +CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE) +""" + +CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB)""" + +CREATE_NAME_INDEX = \ + "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" + +CREATE_ALL = "; ".join([ + CREATE_CAMERAS_TABLE, + CREATE_IMAGES_TABLE, + CREATE_KEYPOINTS_TABLE, + CREATE_DESCRIPTORS_TABLE, + CREATE_MATCHES_TABLE, + CREATE_TWO_VIEW_GEOMETRIES_TABLE, + CREATE_NAME_INDEX +]) + + +def image_ids_to_pair_id(image_id1, image_id2): + if image_id1 > image_id2: + image_id1, image_id2 = image_id2, image_id1 + return image_id1 * MAX_IMAGE_ID + image_id2 + + +def pair_id_to_image_ids(pair_id): + image_id2 = pair_id % MAX_IMAGE_ID + image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID + return image_id1, image_id2 + + +def array_to_blob(array): + if IS_PYTHON3: + return array.tostring() + else: + return np.getbuffer(array) + + +def blob_to_array(blob, dtype, shape=(-1,)): + if IS_PYTHON3: + return np.fromstring(blob, dtype=dtype).reshape(*shape) + else: + return np.frombuffer(blob, dtype=dtype).reshape(*shape) + + +class COLMAPDatabase(sqlite3.Connection): + + @staticmethod + def connect(database_path): + return sqlite3.connect(str(database_path), factory=COLMAPDatabase) + + + def __init__(self, *args, **kwargs): + super(COLMAPDatabase, self).__init__(*args, **kwargs) + + self.create_tables = lambda: self.executescript(CREATE_ALL) + self.create_cameras_table = \ + lambda: self.executescript(CREATE_CAMERAS_TABLE) + self.create_descriptors_table = \ + lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) + self.create_images_table = \ + lambda: self.executescript(CREATE_IMAGES_TABLE) + self.create_two_view_geometries_table = \ + lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE) + self.create_keypoints_table = \ + lambda: self.executescript(CREATE_KEYPOINTS_TABLE) + self.create_matches_table = \ + lambda: self.executescript(CREATE_MATCHES_TABLE) + 
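A standalone sketch (not part of the patch) of two conventions the database helpers above rely on: the symmetric pair-id encoding and the keypoint/descriptor BLOB round trip. The `//` division and the `tobytes()`/`frombuffer()` calls are suggested modernizations rather than changes to the vendored script: under Python 3 the original `pair_id_to_image_ids` returns floats because of the `/` division, and recent NumPy versions deprecate `tostring()`/`fromstring()`.

```python
import numpy as np

MAX_IMAGE_ID = 2**31 - 1

def image_ids_to_pair_id(image_id1, image_id2):
    # The smaller id always goes into the high part, so the pair id is
    # independent of argument order.
    if image_id1 > image_id2:
        image_id1, image_id2 = image_id2, image_id1
    return image_id1 * MAX_IMAGE_ID + image_id2

def pair_id_to_image_ids(pair_id):
    # Integer division keeps the ids exact; the vendored version uses "/",
    # which returns floats under Python 3.
    image_id2 = pair_id % MAX_IMAGE_ID
    image_id1 = (pair_id - image_id2) // MAX_IMAGE_ID
    return image_id1, image_id2

assert pair_id_to_image_ids(image_ids_to_pair_id(42, 7)) == (7, 42)

# On recent NumPy, tostring()/fromstring() are deprecated or removed; an
# equivalent blob round trip uses tobytes()/frombuffer().
keypoints = np.random.rand(10, 2).astype(np.float32)
blob = keypoints.tobytes()
restored = np.frombuffer(blob, dtype=np.float32).reshape(-1, 2)
assert np.array_equal(keypoints, restored)
```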
self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) + + def add_camera(self, model, width, height, params, + prior_focal_length=False, camera_id=None): + params = np.asarray(params, np.float64) + cursor = self.execute( + "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", + (camera_id, model, width, height, array_to_blob(params), + prior_focal_length)) + return cursor.lastrowid + + def add_image(self, name, camera_id, + prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None): + cursor = self.execute( + "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2], + prior_q[3], prior_t[0], prior_t[1], prior_t[2])) + return cursor.lastrowid + + def add_keypoints(self, image_id, keypoints): + assert(len(keypoints.shape) == 2) + assert(keypoints.shape[1] in [2, 4, 6]) + + keypoints = np.asarray(keypoints, np.float32) + self.execute( + "INSERT INTO keypoints VALUES (?, ?, ?, ?)", + (image_id,) + keypoints.shape + (array_to_blob(keypoints),)) + + def add_descriptors(self, image_id, descriptors): + descriptors = np.ascontiguousarray(descriptors, np.uint8) + self.execute( + "INSERT INTO descriptors VALUES (?, ?, ?, ?)", + (image_id,) + descriptors.shape + (array_to_blob(descriptors),)) + + def add_matches(self, image_id1, image_id2, matches): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + self.execute( + "INSERT INTO matches VALUES (?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches),)) + + def add_two_view_geometry(self, image_id1, image_id2, matches, + F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + F = np.asarray(F, dtype=np.float64) + E = np.asarray(E, dtype=np.float64) + H = np.asarray(H, dtype=np.float64) + self.execute( + "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches), config, + array_to_blob(F), array_to_blob(E), array_to_blob(H))) + + +def example_usage(): + import os + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--database_path", default="database.db") + args = parser.parse_args() + + if os.path.exists(args.database_path): + print("ERROR: database path already exists -- will not modify it.") + return + + # Open the database. + + db = COLMAPDatabase.connect(args.database_path) + + # For convenience, try creating all the tables upfront. + + db.create_tables() + + # Create dummy cameras. + + model1, width1, height1, params1 = \ + 0, 1024, 768, np.array((1024., 512., 384.)) + model2, width2, height2, params2 = \ + 2, 1024, 768, np.array((1024., 512., 384., 0.1)) + + camera_id1 = db.add_camera(model1, width1, height1, params1) + camera_id2 = db.add_camera(model2, width2, height2, params2) + + # Create dummy images. + + image_id1 = db.add_image("image1.png", camera_id1) + image_id2 = db.add_image("image2.png", camera_id1) + image_id3 = db.add_image("image3.png", camera_id2) + image_id4 = db.add_image("image4.png", camera_id2) + + # Create dummy keypoints. 
+ # + # Note that COLMAP supports: + # - 2D keypoints: (x, y) + # - 4D keypoints: (x, y, theta, scale) + # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) + + num_keypoints = 1000 + keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2) + keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2) + + db.add_keypoints(image_id1, keypoints1) + db.add_keypoints(image_id2, keypoints2) + db.add_keypoints(image_id3, keypoints3) + db.add_keypoints(image_id4, keypoints4) + + # Create dummy matches. + + M = 50 + matches12 = np.random.randint(num_keypoints, size=(M, 2)) + matches23 = np.random.randint(num_keypoints, size=(M, 2)) + matches34 = np.random.randint(num_keypoints, size=(M, 2)) + + db.add_matches(image_id1, image_id2, matches12) + db.add_matches(image_id2, image_id3, matches23) + db.add_matches(image_id3, image_id4, matches34) + + # Commit the data to the file. + + db.commit() + + # Read and check cameras. + + rows = db.execute("SELECT * FROM cameras") + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id1 + assert model == model1 and width == width1 and height == height1 + assert np.allclose(params, params1) + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id2 + assert model == model2 and width == width2 and height == height2 + assert np.allclose(params, params2) + + # Read and check keypoints. + + keypoints = dict( + (image_id, blob_to_array(data, np.float32, (-1, 2))) + for image_id, data in db.execute( + "SELECT image_id, data FROM keypoints")) + + assert np.allclose(keypoints[image_id1], keypoints1) + assert np.allclose(keypoints[image_id2], keypoints2) + assert np.allclose(keypoints[image_id3], keypoints3) + assert np.allclose(keypoints[image_id4], keypoints4) + + # Read and check matches. + + pair_ids = [image_ids_to_pair_id(*pair) for pair in + ((image_id1, image_id2), + (image_id2, image_id3), + (image_id3, image_id4))] + + matches = dict( + (pair_id_to_image_ids(pair_id), + blob_to_array(data, np.uint32, (-1, 2))) + for pair_id, data in db.execute("SELECT pair_id, data FROM matches") + ) + + assert np.all(matches[(image_id1, image_id2)] == matches12) + assert np.all(matches[(image_id2, image_id3)] == matches23) + assert np.all(matches[(image_id3, image_id4)] == matches34) + + # Clean up. 
+ + db.close() + + if os.path.exists(args.database_path): + os.remove(args.database_path) + + +if __name__ == "__main__": + example_usage() \ No newline at end of file diff --git a/imcui/third_party/pram/colmap_utils/geometry.py b/imcui/third_party/pram/colmap_utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..0d48f0a9545f04300f0f914515e650bb60957296 --- /dev/null +++ b/imcui/third_party/pram/colmap_utils/geometry.py @@ -0,0 +1,17 @@ +# -*- coding: UTF-8 -*- +import numpy as np +import pycolmap + + +def to_homogeneous(p): + return np.pad(p, ((0, 0),) * (p.ndim - 1) + ((0, 1),), constant_values=1) + + +def compute_epipolar_errors(j_from_i: pycolmap.Rigid3d, p2d_i, p2d_j): + j_E_i = j_from_i.essential_matrix() + l2d_j = to_homogeneous(p2d_i) @ j_E_i.T + l2d_i = to_homogeneous(p2d_j) @ j_E_i + dist = np.abs(np.sum(to_homogeneous(p2d_i) * l2d_i, axis=1)) + errors_i = dist / np.linalg.norm(l2d_i[:, :2], axis=1) + errors_j = dist / np.linalg.norm(l2d_j[:, :2], axis=1) + return errors_i, errors_j diff --git a/imcui/third_party/pram/colmap_utils/io.py b/imcui/third_party/pram/colmap_utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad46c685ca2a2fbb166d22884948f3fd6547368 --- /dev/null +++ b/imcui/third_party/pram/colmap_utils/io.py @@ -0,0 +1,78 @@ +# -*- coding: UTF-8 -*- +from pathlib import Path +from typing import Tuple + +import cv2 +import h5py +import numpy as np + +from .parsers import names_to_pair, names_to_pair_old + + +def read_image(path, grayscale=False): + if grayscale: + mode = cv2.IMREAD_GRAYSCALE + else: + mode = cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise ValueError(f"Cannot read image {path}.") + if not grayscale and len(image.shape) == 3: + image = image[:, :, ::-1] # BGR to RGB + return image + + +def list_h5_names(path): + names = [] + with h5py.File(str(path), "r", libver="latest") as fd: + def visit_fn(_, obj): + if isinstance(obj, h5py.Dataset): + names.append(obj.parent.name.strip("/")) + + fd.visititems(visit_fn) + return list(set(names)) + + +def get_keypoints( + path: Path, name: str, return_uncertainty: bool = False +) -> np.ndarray: + with h5py.File(str(path), "r", libver="latest") as hfile: + dset = hfile[name]["keypoints"] + p = dset.__array__() + uncertainty = dset.attrs.get("uncertainty") + if return_uncertainty: + return p, uncertainty + return p + + +def find_pair(hfile: h5py.File, name0: str, name1: str): + pair = names_to_pair(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair(name1, name0) + if pair in hfile: + return pair, True + # older, less efficient format + pair = names_to_pair_old(name0, name1) + if pair in hfile: + return pair, False + pair = names_to_pair_old(name1, name0) + if pair in hfile: + return pair, True + raise ValueError( + f"Could not find pair {(name0, name1)}... " + "Maybe you matched with a different list of pairs? 
" + ) + + +def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]: + with h5py.File(str(path), "r", libver="latest") as hfile: + pair, reverse = find_pair(hfile, name0, name1) + matches = hfile[pair]["matches0"].__array__() + scores = hfile[pair]["matching_scores0"].__array__() + idx = np.where(matches != -1)[0] + matches = np.stack([idx, matches[idx]], -1) + if reverse: + matches = np.flip(matches, -1) + scores = scores[idx] + return matches, scores diff --git a/imcui/third_party/pram/colmap_utils/parsers.py b/imcui/third_party/pram/colmap_utils/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9087d78cc8cf7f1e81ab8359862227c3882786 --- /dev/null +++ b/imcui/third_party/pram/colmap_utils/parsers.py @@ -0,0 +1,73 @@ +# -*- coding: UTF-8 -*- + +from pathlib import Path +import logging +import numpy as np +from collections import defaultdict + + +def parse_image_lists_with_intrinsics(paths): + results = [] + files = list(Path(paths.parent).glob(paths.name)) + assert len(files) > 0 + + for lfile in files: + with open(lfile, 'r') as f: + raw_data = f.readlines() + + logging.info(f'Importing {len(raw_data)} queries in {lfile.name}') + for data in raw_data: + data = data.strip('\n').split(' ') + name, camera_model, width, height = data[:4] + params = np.array(data[4:], float) + info = (camera_model, int(width), int(height), params) + results.append((name, info)) + + assert len(results) > 0 + return results + + +def parse_img_lists_for_extended_cmu_seaons(paths): + Ks = { + "c0": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571", + "c1": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571" + } + + results = [] + files = list(Path(paths.parent).glob(paths.name)) + assert len(files) > 0 + + for lfile in files: + with open(lfile, 'r') as f: + raw_data = f.readlines() + + logging.info(f'Importing {len(raw_data)} queries in {lfile.name}') + for name in raw_data: + name = name.strip('\n') + camera = name.split('_')[2] + K = Ks[camera].split(' ') + camera_model, width, height = K[:3] + params = np.array(K[3:], float) + # print("camera: ", camera_model, width, height, params) + info = (camera_model, int(width), int(height), params) + results.append((name, info)) + + assert len(results) > 0 + return results + + +def parse_retrieval(path): + retrieval = defaultdict(list) + with open(path, 'r') as f: + for p in f.read().rstrip('\n').split('\n'): + q, r = p.split(' ') + retrieval[q].append(r) + return dict(retrieval) + + +def names_to_pair_old(name0, name1): + return '_'.join((name0.replace('/', '-'), name1.replace('/', '-'))) + + +def names_to_pair(name0, name1, separator="/"): + return separator.join((name0.replace("/", "-"), name1.replace("/", "-"))) diff --git a/imcui/third_party/pram/colmap_utils/read_write_model.py b/imcui/third_party/pram/colmap_utils/read_write_model.py new file mode 100644 index 0000000000000000000000000000000000000000..eddbeb7edd364c27c54029fa81077ea4f75d2700 --- /dev/null +++ b/imcui/third_party/pram/colmap_utils/read_write_model.py @@ -0,0 +1,627 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +import os +import sys +import collections +import numpy as np +import struct +import argparse + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) + for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) + for camera_model in CAMERA_MODELS]) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. 
+ :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for camera_line_index in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8 * num_params, + format_char_sequence="d" * num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = '# Camera list with one line of data per camera:\n' + '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n' + '# Number of cameras: {}\n'.format(len(cameras)) + with open(path, "w") as fid: + fid.write(HEADER) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, + model_id, + cam.width, + cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const 
std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for image_index in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def write_images_text(images, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum((len(img.point3D_ids) for _, img in images.items())) / len(images) + HEADER = '# Image list with two lines of data per image:\n' + '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n' + '# POINTS2D[] as (X, Y, POINT3D_ID)\n' + '# Number of images: {}, mean observations per image: {}\n'.format(len(images), mean_observations) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, img in images.items(): + image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + 
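For reference, a standalone sketch (not part of the patch) of the 64-byte binary image header that `read_images_binary()` and `write_images_binary()` agree on: `image_id` (i), `qvec` (dddd), `tvec` (ddd), `camera_id` (i), followed by a null-terminated name and a uint64 count of 2D observations. The field values below are invented for illustration.

```python
import struct

image_id, camera_id = 7, 1
qvec = (1.0, 0.0, 0.0, 0.0)        # identity rotation as a quaternion (w, x, y, z)
tvec = (0.1, -0.2, 1.5)
name = b"image7.png\x00"           # null terminator marks the end of the name

# "<idddddddi" is 4 + 7*8 + 4 = 64 bytes, matching num_bytes=64 in
# read_images_binary().
header = struct.pack("<idddddddi", image_id, *qvec, *tvec, camera_id)
record = header + name + struct.pack("<Q", 0)   # zero 2D observations

unpacked = struct.unpack("<idddddddi", record[:64])
assert unpacked[0] == image_id and unpacked[8] == camera_id
```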
for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def read_points3d_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for point_line_index in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8 * track_length, + format_char_sequence="ii" * track_length) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items())) / len(points3D) + HEADER = '# 3D point list with one line of data per point:\n' + '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n' + '# Number of points: {}, mean track length: {}\n'.format(len(points3D), mean_track_length) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3d_binary(points3D, path_to_model_file): + """ + see: src/base/reconstruction.cc + void 
Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def read_model(path, ext): + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3d_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def read_compressed_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for image_index in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + # x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, + # format_char_sequence="ddq" * num_points2D) + # xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + # tuple(map(float, x_y_id_s[1::3]))]) + x_y_id_s = read_next_bytes(fid, num_bytes=8 * num_points2D, + format_char_sequence="q" * num_points2D) + point3D_ids = np.array(x_y_id_s) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=np.array([]), point3D_ids=point3D_ids) + return images + + +def write_compressed_images_binary(images, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, 
img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for p3d_id in img.point3D_ids: + write_next_bytes(fid, p3d_id, "q") + # for xy, p3d_id in zip(img.xys, img.point3D_ids): + # write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_compressed_points3d_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for point_line_index in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=4 * track_length, + format_char_sequence="i" * track_length) + image_ids = np.array(track_elems) + # point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=np.array([])) + return points3D + + +def write_compressed_points3d_binary(points3D, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + # for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + # write_next_bytes(fid, [image_id, point2D_id], "ii") + for image_id in pt.image_ids: + write_next_bytes(fid, image_id, "i") + + +def read_compressed_model(path, ext): + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_compressed_images_binary(os.path.join(path, "images" + ext)) + points3D = read_compressed_points3d_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]]) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = 
R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def intrinsics_from_camera(camera_model, params): + if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = params[0] + cx = params[1] + cy = params[2] + elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = params[0] + fy = params[1] + cx = params[2] + cy = params[3] + else: + raise Exception("Camera model not supported") + + # intrinsics + K = np.identity(3) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + return K + + +def main(): + parser = argparse.ArgumentParser(description='Read and write COLMAP binary and text models') + parser.add_argument('input_model', help='path to input model folder') + parser.add_argument('input_format', choices=['.bin', '.txt'], + help='input model format') + parser.add_argument('--output_model', metavar='PATH', + help='path to output model folder') + parser.add_argument('--output_format', choices=['.bin', '.txt'], + help='outut model format', default='.txt') + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format) + + +if __name__ == "__main__": + main() diff --git a/imcui/third_party/pram/colmap_utils/utils.py b/imcui/third_party/pram/colmap_utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d98fed2dfc5789b650144caa3a4bac8cfe6a2fb --- /dev/null +++ b/imcui/third_party/pram/colmap_utils/utils.py @@ -0,0 +1 @@ +# -*- coding: UTF-8 -*- diff --git a/imcui/third_party/pram/configs/config_train_12scenes_sfd2.yaml b/imcui/third_party/pram/configs/config_train_12scenes_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e6e7fb7c851edb8bd6e26e8d4806cadeb5977d5 --- /dev/null +++ b/imcui/third_party/pram/configs/config_train_12scenes_sfd2.yaml @@ -0,0 +1,102 @@ +dataset: [ '12Scenes' ] + +network_1: "segnet" +network: "segnetvit" + +local_rank: 0 +gpu: [ 0 ] + +feature: "sfd2" +save_path: '/scratches/flyer_2/fx221/exp/pram' +landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml" +dataset_path: "/scratches/flyer_3/fx221/dataset" +config_path: 'configs/datasets' + +image_dim: 3 +feat_dim: 128 +min_inliers: 32 +max_inliers: 512 +random_inliers: true +max_keypoints: 512 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +train: true +inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +optimizer: "adamw" +seg_loss: "cew" +seg_loss_nx: "cei" +cls_loss: "ce" +cls_loss_: "bce" +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 +do_eval: false + +use_mid_feature: true +norm_desc: false +with_score: false +with_aug: true +with_dist: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 60000 +epochs: 500 + +cluster_method: 'birch' + +weight_path: null +weight_path_1: 
'20230719_220620_segnet_L15_T_resnet4x_B32_K1024_relu_bn_od1024_nc193_adamw_cew_md_A_birch/segnet.499.pth' +weight_path_2: '20240202_145337_segnetvit_L15_T_resnet4x_B32_K512_relu_bn_od1024_nc193_adam_cew_md_A_birch/segnetvit.499.pth' + +resume_path: null + +n_class: 193 + +eval_max_keypoints: 1024 + +localization: + loc_scene_name: [ 'apt1/kitchen' ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + seg_k: 20 + threshold: 8 + min_kpts: 128 + min_matches: 4 + min_inliers: 64 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method_2: "gm" + matching_method: "gml" + matching_method_5: "adagml" + save: false + show: true + show_time: 1 + max_vrf: 1 + with_original: true + with_extra: false + with_compress: true + semantic_matching: true + do_refinement: true + refinement_method_: 'matching' + refinement_method: 'projection' + pre_filtering_th: 0.95 + covisibility_frame: 20 + refinement_radius: 20 + refinement_nn_ratio: 0.9 + refinement_max_matches: 0 diff --git a/imcui/third_party/pram/configs/config_train_7scenes_sfd2.yaml b/imcui/third_party/pram/configs/config_train_7scenes_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19b0635c9ad4ebcf0a085a759640e4a149a75009 --- /dev/null +++ b/imcui/third_party/pram/configs/config_train_7scenes_sfd2.yaml @@ -0,0 +1,104 @@ +dataset: [ '7Scenes' ] + +network: "segnetvit" + +local_rank: 0 +gpu: [ 0 ] +# when using ddp, set gpu: [0,1,2,3] +with_dist: true + +feature: "sfd2" +save_path_: '/scratches/flyer_2/fx221/exp/pram' +save_path: '/scratches/flyer_2/fx221/publications/test_pram/exp' +landmark_path_: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml" +landmark_path: "/scratches/flyer_2/fx221/publications/test_pram/landmakrs/sfd2-gml" +dataset_path: "/scratches/flyer_3/fx221/dataset" +config_path: 'configs/datasets' + +image_dim: 3 +feat_dim: 128 + +min_inliers: 32 +max_inliers: 256 +random_inliers: 1 +max_keypoints: 512 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +train: true +inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +cls_loss: "ce" +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 +do_eval: false + +use_mid_feature: true +norm_desc: false +with_cls: false +with_score: false +with_aug: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 80000 +epochs: 200 + +cluster_method: 'birch' + +weight_path: null +weight_path_1: '20230724_203230_segnet_L15_S_resnet4x_B32_K1024_relu_bn_od1024_nc113_adam_cew_md_A_birch/segnet.180.pth' +weight_path_2: '20240202_152519_segnetvit_L15_S_resnet4x_B32_K512_relu_bn_od1024_nc113_adamw_cew_md_A_birch/segnetvit.199.pth' + +# used for resuming training +resume_path: null + +# used for localization +n_class: 113 + +eval_max_keypoints: 1024 + +localization: + loc_scene_name: [ 'chess' ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + + seg_k: 20 + threshold: 8 + min_kpts: 128 + min_matches: 16 + min_inliers: 32 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method_2: "gm" + matching_method: "gml" + matching_method_4: "adagml" + save: false + show: true + show_time: 1 + with_original: true + max_vrf: 1 + with_compress: true + semantic_matching: true + do_refinement: true + pre_filtering_th: 0.95 + refinement_method_: 'matching' + refinement_method: 'projection' + covisibility_frame: 20 + refinement_radius: 20 + 
refinement_nn_ratio: 0.9 + refinement_max_matches: 0 diff --git a/imcui/third_party/pram/configs/config_train_aachen_sfd2.yaml b/imcui/third_party/pram/configs/config_train_aachen_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e2111377ed9d6cff38efd69bc397487ecfb33fb --- /dev/null +++ b/imcui/third_party/pram/configs/config_train_aachen_sfd2.yaml @@ -0,0 +1,104 @@ +dataset: [ 'Aachen' ] + +network_: "segnet" +network: "segnetvit" +local_rank: 0 +gpu: [ 0 ] + +feature: "sfd2" +save_path: '/scratches/flyer_2/fx221/exp/pram' +landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml" +dataset_path: "/scratches/flyer_3/fx221/dataset" + +config_path: 'configs/datasets' + +image_dim: 3 +feat_dim: 128 + +min_inliers: 32 +max_inliers: 512 +random_inliers: true +max_keypoints: 1024 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +do_eval: true +train: true +inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +optimizer: "adam" +seg_loss: "cew" +seg_loss_nx: "cei" +cls_loss: "ce" +cls_loss_: "bce" +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 + +use_mid_feature: true +norm_desc: false +with_sc: false +with_cls: true +with_score: false +with_aug: true +with_dist: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 80000 +epochs: 800 + +cluster_method: 'birch' + +weight_path: null +weight_path_1: '20230719_221442_segnet_L15_A_resnet4x_B32_K1024_relu_bn_od1024_nc513_adamw_cew_md_A_birch/segnet.899.pth' +weight_path_2: '20240211_142623_segnetvit_L15_A_resnet4x_B32_K1024_relu_bn_od1024_nc513_adam_cew_md_A_birch/segnetvit.799.pth' +resume_path: null + +n_class: 513 + +eval_max_keypoints: 4096 + +localization: + loc_scene_name: [ ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + seg_k: 10 + threshold: 12 + min_kpts: 256 + min_matches: 8 + min_inliers: 128 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method_2: "gm" + matching_method: "gml" + matching_method_4: "adagml" + save: false + show: true + show_time: 1 + with_original: true + with_extra: false + max_vrf: 1 + with_compress: true + semantic_matching: true + refinement_method_: 'matching' + refinement_method: 'projection' + pre_filtering_th: 0.95 + do_refinement: true + covisibility_frame: 50 + refinement_radius: 30 + refinement_nn_ratio: 0.9 + refinement_max_matches: 0 diff --git a/imcui/third_party/pram/configs/config_train_cambridge_sfd2.yaml b/imcui/third_party/pram/configs/config_train_cambridge_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8cc843ee963dc5c0041954790d7e622e24aefe16 --- /dev/null +++ b/imcui/third_party/pram/configs/config_train_cambridge_sfd2.yaml @@ -0,0 +1,103 @@ +dataset: [ 'CambridgeLandmarks' ] + +network_: "segnet" +network: "segnetvit" + +local_rank: 0 +gpu: [ 0 ] + +feature: "sfd2" +save_path: '/scratches/flyer_2/fx221/exp/pram' +landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml" +dataset_path: "/scratches/flyer_3/fx221/dataset" +config_path: 'configs/datasets' + +image_dim: 3 +feat_dim: 128 + +min_inliers: 32 +max_inliers: 512 +random_inliers: 1 +max_keypoints: 1024 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +do_eval: false +train: true 
+inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +epochs: 300 +seg_loss: "cew" +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 + +use_mid_feature: true +norm_desc: false +with_score: false +with_aug: true +with_dist: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 60000 + +cluster_method: 'birch' + +weight_path: null +weight_path_1: '20230725_144044_segnet_L15_C_resnet4x_B32_K1024_relu_bn_od1024_nc161_adam_cew_md_A_birch/segnet.260.pth' +weight_path_2: '20240204_130323_segnetvit_L15_C_resnet4x_B32_K1024_relu_bn_od1024_nc161_adamw_cew_md_A_birch/segnetvit.399.pth' + +resume_path: null + +n_class: 161 + +eval_max_keypoints: 2048 + +localization: + loc_scene_name_1: [ 'GreatCourt' ] + loc_scene_name_2: [ 'KingsCollege' ] + loc_scene_name: [ 'StMarysChurch' ] + loc_scene_name_4: [ 'OldHospital' ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + seg_k: 30 + threshold: 12 + min_kpts: 256 + min_matches: 16 + min_inliers_gm: 128 + min_inliers: 128 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method_2: "gm" + matching_method: "gml" + matching_method_4: "adagml" + show: true + show_time: 1 + save: false + with_original: true + max_vrf: 1 + with_extra: false + with_compress: true + semantic_matching: true + do_refinement: true + pre_filtering_th: 0.95 + refinement_method_: 'matching' + refinement_method: 'projection' + covisibility_frame: 20 + refinement_radius: 20 + refinement_nn_ratio: 0.9 + refinement_max_matches: 0 diff --git a/imcui/third_party/pram/configs/config_train_multiset_sfd2.yaml b/imcui/third_party/pram/configs/config_train_multiset_sfd2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..90618e0812c2321ba05fbe3ab9a12d52ec447e99 --- /dev/null +++ b/imcui/third_party/pram/configs/config_train_multiset_sfd2.yaml @@ -0,0 +1,100 @@ +dataset: [ 'S', 'T', 'C', 'A' ] + +network: "segnet" +network_: "gsegnet3" + +local_rank: 0 +gpu: [ 4 ] + +feature: "resnet4x" +save_path: '/scratches/flyer_2/fx221/exp/localizer' +landmark_path: "/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gm" +dataset_path: "/scratches/flyer_3/fx221/dataset" +config_path: 'configs/datasets' + +image_dim: 3 +min_inliers: 32 +max_inliers: 512 +random_inliers: 1 +max_keypoints: 1024 +ignore_index: -1 +output_dim: 1024 +output_dim_: 2048 +jitter_params: + brightness: 0.5 + contrast: 0.5 + saturation: 0.25 + hue: 0.15 + blur: 0 + +scale_params: [ 0.5, 1.0 ] +pre_load: false +do_eval: true +train: true +inlier_th: 0.5 +lr: 0.0001 +min_lr: 0.00001 +optimizer: "adam" +seg_loss: "cew" +seg_loss_nx: "cei" +cls_loss: "ce" +cls_loss_: "bce" +sc_loss: 'l1g' +ac_fn: "relu" +norm_fn: "bn" +workers: 8 +layers: 15 +log_intervals: 50 +eval_n_epoch: 10 + +use_mid_feature: true +norm_desc: false +with_sc: false +with_cls: true +with_score: false +with_aug: true +with_dist: true + +batch_size: 32 +its_per_epoch: 1000 +decay_rate: 0.999992 +decay_iter: 150000 +epochs: 1500 + +cluster_method_: 'kmeans' +cluster_method: 'birch' + +weight_path_: null +weight_path: '20230805_132653_segnet_L15_STCA_resnet4x_B32_K1024_relu_bn_od1024_nc977_adam_cew_md_A_birch/segnet.485.pth' +resume_path: null + +eval: false +#loc: false +loc: true +#n_class: 977 +online: false + +eval_max_keypoints: 4096 + +localization: + loc_scene_name: [ ] + save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results' + dataset: [ 'T' ] + seg_k: 50 + threshold: 8 # 8 for indoor, 12 for outdoor + min_kpts: 256 + min_matches: 
4 + min_inliers: 64 + matching_method_: "mnn" + matching_method_1: "spg" + matching_method: "gm" + save: false + show: true + show_time: 1 + do_refinement: true + with_original: true + with_extra: false + max_vrf: 1 + with_compress: false + covisibility_frame: 20 + observation_threshold: 3 diff --git a/imcui/third_party/pram/configs/datasets/12Scenes.yaml b/imcui/third_party/pram/configs/datasets/12Scenes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e950aca2ff25526af622fec779e9bb6a07eaea6b --- /dev/null +++ b/imcui/third_party/pram/configs/datasets/12Scenes.yaml @@ -0,0 +1,166 @@ +dataset: '12Scenes' +scenes: [ 'apt1/kitchen', + 'apt1/living', + 'apt2/bed', + 'apt2/kitchen', + 'apt2/living', + 'apt2/luke', + 'office1/gates362', + 'office1/gates381', + 'office1/lounge', + 'office1/manolis', + 'office2/5a', + 'office2/5b' +] + +apt1/kitchen: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + image_path_prefix: '' + + +apt1/living: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + +apt2/bed: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +apt2/kitchen: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +apt2/living: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +apt2/luke: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office1/gates362: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 3 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office1/gates381: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 3 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office1/lounge: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office1/manolis: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office2/5a: + n_cluster: 16 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +office2/5b: + n_cluster: 16 + cluster_mode: 'xy' + 
cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 5 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' diff --git a/imcui/third_party/pram/configs/datasets/7Scenes.yaml b/imcui/third_party/pram/configs/datasets/7Scenes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd68181fbc0ed96ccb3e464d94a5346183c1dfe3 --- /dev/null +++ b/imcui/third_party/pram/configs/datasets/7Scenes.yaml @@ -0,0 +1,96 @@ +dataset: '7Scenes' +scenes: [ 'chess', 'heads', 'office', 'fire', 'stairs', 'redkitchen', 'pumpkin' ] + + +chess: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 2 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + + +heads: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 2 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + +office: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 3 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + +fire: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 2 + eval_sample_ratio: 5 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + +stairs: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + +redkitchen: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 3 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + + + + +pumpkin: + n_cluster: 16 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + + training_sample_ratio: 2 + eval_sample_ratio: 10 + gt_pose_path: 'queries_poses.txt' + query_path: 'queries_with_intrinsics.txt' + image_path_prefix: '' + diff --git a/imcui/third_party/pram/configs/datasets/Aachen.yaml b/imcui/third_party/pram/configs/datasets/Aachen.yaml new file mode 100644 index 0000000000000000000000000000000000000000..49477afbe569cb0fc4317b6c1a98c30f261ee7e0 --- /dev/null +++ b/imcui/third_party/pram/configs/datasets/Aachen.yaml @@ -0,0 +1,15 @@ +dataset: 'Aachen' + +scenes: [ 'Aachenv11' ] + +Aachenv11: + n_cluster: 512 + cluster_mode: 'xz' + cluster_method_: 'kmeans' + cluster_method: 'birch' + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: 'images/images_upright' + query_path_: 'queries_with_intrinsics.txt' + query_path: 'queries_with_intrinsics_demo.txt' + gt_pose_path: 'queries_pose_spp_spg.txt' diff --git a/imcui/third_party/pram/configs/datasets/CambridgeLandmarks.yaml b/imcui/third_party/pram/configs/datasets/CambridgeLandmarks.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3a757898db1e772b593059d2c21ef1eaaa825ea --- /dev/null +++ b/imcui/third_party/pram/configs/datasets/CambridgeLandmarks.yaml @@ -0,0 +1,67 @@ +dataset: 'CambridgeLandmarks' +scenes: [ 'GreatCourt', 
'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch' ] + +GreatCourt: + n_cluster: 32 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +KingsCollege: + n_cluster: 32 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +OldHospital: + n_cluster: 32 + cluster_mode: 'xz' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +ShopFacade: + n_cluster: 32 + cluster_mode: 'xy' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + +StMarysChurch: + n_cluster: 32 + cluster_mode: 'xz' + cluster_method: 'birch' + + training_sample_ratio: 1 + eval_sample_ratio: 1 + image_path_prefix: '' + + query_path: 'queries_with_intrinsics.txt' + gt_pose_path: 'queries_poses.txt' + + + diff --git a/imcui/third_party/pram/dataset/aachen.py b/imcui/third_party/pram/dataset/aachen.py new file mode 100644 index 0000000000000000000000000000000000000000..d57efd8e4460f943d66b2d8b92e57d7cd7f7f75a --- /dev/null +++ b/imcui/third_party/pram/dataset/aachen.py @@ -0,0 +1,119 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> aachen +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:33 +==================================================''' +import os.path as osp +import numpy as np +import cv2 +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class Aachen(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='Aachen', + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = 'images/images_upright' + + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing 
of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + self.img_fns = [] + if train: + with open(osp.join(self.dataset_path, 'aachen_db_imglist.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip() + if l not in self.name_to_id.keys(): + continue + self.img_fns.append(l) + else: + with open(osp.join(self.dataset_path, 'queries', 'day_time_queries_with_intrinsics.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split()[0] + if l not in self.img_p3d.keys(): + continue + self.img_fns.append(l) + with open(osp.join(self.dataset_path, 'queries', 'night_time_queries_with_intrinsics.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split()[0] + if l not in self.img_p3d.keys(): + continue + self.img_fns.append(l) + + print( + 'Load {} images from {} for {}...'.format(len(self.img_fns), self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split() + self.mean_xyz = np.array([float(v) for v in l[:3]]) + self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} + + def read_image(self, image_name): + return cv2.imread(osp.join(self.dataset_path, 'images/images_upright/', image_name)) diff --git a/imcui/third_party/pram/dataset/basicdataset.py b/imcui/third_party/pram/dataset/basicdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c77c32ca010e99d14ddd8643c2ff07789bd75851 --- /dev/null +++ b/imcui/third_party/pram/dataset/basicdataset.py @@ -0,0 +1,477 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> basicdataset +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:27 +==================================================''' +import torchvision.transforms.functional as tvf +import torchvision.transforms as tvt +import os.path as osp +import numpy as np +import cv2 +from colmap_utils.read_write_model import qvec2rotmat, read_model +from dataset.utils import normalize_size + + +class BasicDataset: + def __init__(self, + img_list_fn, + feature_dir, + sfm_path, + seg_fn, + dataset_path, + n_class, + dataset, + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=1, + pre_load=False, + query_info_path=None, + sc_mean_scale_fn=None, + ): + self.n_class = n_class + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.dataset_path = dataset_path + self.with_aug = with_aug + self.dataset = dataset + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.image_prefix = '' + + train_transforms = 
[] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + self.img_fns = [] + with open(img_list_fn, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip() + self.img_fns.append(l) + print('Load {} images from {} for {}...'.format(len(self.img_fns), dataset, 'training' if train else 'eval')) + self.feats = {} + if train: + self.cameras, self.images, point3Ds = read_model(path=sfm_path, ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items()} + + data = np.load(seg_fn, allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + self.p3d_xyzs = {} + + for pid in self.p3d_seg.keys(): + p3d = point3Ds[pid] + self.p3d_xyzs[pid] = p3d.xyz + + with open(sc_mean_scale_fn, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split() + self.mean_xyz = np.array([float(v) for v in l[:3]]) + self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = feature_dir + print('Pre loaded {} feats, mean xyz {}, scale xyz {}'.format(len(self.feats.keys()), self.mean_xyz, + self.scale_xyz)) + + def normalize_p3ds(self, p3ds): + mean_p3ds = np.ceil(np.mean(p3ds, axis=0)) + p3ds_ = p3ds - mean_p3ds + dx = np.max(abs(p3ds_[:, 0])) + dy = np.max(abs(p3ds_[:, 1])) + dz = np.max(abs(p3ds_[:, 2])) + scale_p3ds = np.ceil(np.array([dx, dy, dz], dtype=float).reshape(3, )) + scale_p3ds[scale_p3ds < 1] = 1 + scale_p3ds[scale_p3ds == 0] = 1 + return mean_p3ds, scale_p3ds + + def read_query_info(self, path): + query_info = {} + with open(path, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split() + image_name = l[0] + cam_model = l[1] + h, w = int(l[2]), int(l[3]) + params = np.array([float(v) for v in l[4:]]) + query_info[image_name] = { + 'width': w, + 'height': h, + 'model': cam_model, + 'params': params, + } + return query_info + + def extract_intrinsic_extrinsic_params(self, image_id): + cam = self.cameras[self.images[image_id].camera_id] + params = cam.params + model = cam.model + if model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = params[0] + cx = params[1] + cy = params[2] + elif model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = params[0] + fy = params[1] + cx = params[2] + cy = params[3] + else: + raise Exception("Camera model not supported") + K = np.eye(3, dtype=float) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + + qvec = self.images[image_id].qvec + tvec = self.images[image_id].tvec + R = qvec2rotmat(qvec=qvec) + P = np.eye(4, dtype=float) + P[:3, :3] = R + P[:3, 3] = tvec.reshape(3, ) + + return {'K': K, 'P': P} + + def get_item_train(self, idx): + img_name = self.img_fns[idx] + if img_name in self.feats.keys(): + feat_data = self.feats[img_name] + else: + feat_data = np.load(osp.join(self.feature_dir, img_name.replace('/', '+') + '.npy'), allow_pickle=True)[()] + # descs = 
feat_data['descriptors'] # [N, D] + scores = feat_data['scores'] # [N, 1] + kpts = feat_data['keypoints'] # [N, 2] + image_size = feat_data['image_size'] + + nfeat = kpts.shape[0] + + # print(img_name, self.name_to_id[img_name]) + p3d_ids = self.images[self.name_to_id[img_name]].point3D_ids + p3d_xyzs = np.zeros(shape=(nfeat, 3), dtype=float) + + seg_ids = np.zeros(shape=(nfeat,), dtype=int) # + self.n_class - 1 + for i in range(nfeat): + p3d = p3d_ids[i] + if p3d in self.p3d_seg.keys(): + seg_ids[i] = self.p3d_seg[p3d] + 1 # 0 for invalid + if seg_ids[i] == -1: + seg_ids[i] = 0 + + if p3d in self.p3d_xyzs.keys(): + p3d_xyzs[i] = self.p3d_xyzs[p3d] + + seg_ids = np.array(seg_ids).reshape(-1, ) + + n_inliers = np.sum(seg_ids > 0) + n_outliers = np.sum(seg_ids == 0) + inlier_ids = np.where(seg_ids > 0)[0] + outlier_ids = np.where(seg_ids == 0)[0] + + if n_inliers <= self.min_inliers: + sel_inliers = n_inliers + sel_outliers = self.nfeatures - sel_inliers + + out_ids = np.arange(n_outliers) + np.random.shuffle(out_ids) + sel_ids = np.hstack([inlier_ids, outlier_ids[out_ids[:self.nfeatures - n_inliers]]]) + else: + sel_inliers = np.random.randint(self.min_inliers, self.max_inliers) + if sel_inliers > n_inliers: + sel_inliers = n_inliers + + if sel_inliers + n_outliers < self.nfeatures: + sel_inliers = self.nfeatures - n_outliers + + sel_outliers = self.nfeatures - sel_inliers + + in_ids = np.arange(n_inliers) + np.random.shuffle(in_ids) + sel_inlier_ids = inlier_ids[in_ids[:sel_inliers]] + + out_ids = np.arange(n_outliers) + np.random.shuffle(out_ids) + sel_outlier_ids = outlier_ids[out_ids[:sel_outliers]] + + sel_ids = np.hstack([sel_inlier_ids, sel_outlier_ids]) + + # sel_descs = descs[sel_ids] + sel_scores = scores[sel_ids] + sel_kpts = kpts[sel_ids] + sel_seg_ids = seg_ids[sel_ids] + sel_xyzs = p3d_xyzs[sel_ids] + + shuffle_ids = np.arange(sel_ids.shape[0]) + np.random.shuffle(shuffle_ids) + # sel_descs = sel_descs[shuffle_ids] + sel_scores = sel_scores[shuffle_ids] + sel_kpts = sel_kpts[shuffle_ids] + sel_seg_ids = sel_seg_ids[shuffle_ids] + sel_xyzs = sel_xyzs[shuffle_ids] + + if sel_kpts.shape[0] < self.nfeatures: + # print(sel_descs.shape, sel_kpts.shape, sel_scores.shape, sel_seg_ids.shape, sel_xyzs.shape) + valid_sel_ids = np.array([v for v in range(sel_kpts.shape[0]) if sel_seg_ids[v] > 0], dtype=int) + # ref_sel_id = np.random.choice(valid_sel_ids, size=1)[0] + if valid_sel_ids.shape[0] == 0: + valid_sel_ids = np.array([v for v in range(sel_kpts.shape[0])], dtype=int) + random_n = self.nfeatures - sel_kpts.shape[0] + random_scores = np.random.random((random_n,)) + random_kpts, random_seg_ids, random_xyzs = self.random_points_from_reference( + n=random_n, + ref_kpts=sel_kpts[valid_sel_ids], + ref_segs=sel_seg_ids[valid_sel_ids], + ref_xyzs=sel_xyzs[valid_sel_ids], + radius=5, + ) + # sel_descs = np.vstack([sel_descs, random_descs]) + sel_scores = np.hstack([sel_scores, random_scores]) + sel_kpts = np.vstack([sel_kpts, random_kpts]) + sel_seg_ids = np.hstack([sel_seg_ids, random_seg_ids]) + sel_xyzs = np.vstack([sel_xyzs, random_xyzs]) + + gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float) + uids = np.unique(sel_seg_ids).tolist() + for uid in uids: + if uid == 0: + continue + gt_cls[uid] = 1 + gt_n_seg[uid] = np.sum(sel_seg_ids == uid) + gt_cls_dist[uid] = np.sum(seg_ids == uid) / np.sum(seg_ids > 0) # [valid_id / total_valid_id] + + param_out = 
self.extract_intrinsic_extrinsic_params(image_id=self.name_to_id[img_name]) + + img = self.read_image(image_name=img_name) + image_size = img.shape[:2] + if self.image_dim == 1: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if self.with_aug: + nh = img.shape[0] + nw = img.shape[1] + if self.scale_params is not None: + do_scale = np.random.random() + if do_scale <= 0.25: + p = np.random.randint(0, 11) + s = self.scale_params[0] + (self.scale_params[1] - self.scale_params[0]) / 10 * p + nh = int(img.shape[0] * s) + nw = int(img.shape[1] * s) + sh = nh / img.shape[0] + sw = nw / img.shape[1] + sel_kpts[:, 0] = sel_kpts[:, 0] * sw + sel_kpts[:, 1] = sel_kpts[:, 1] * sh + img = cv2.resize(img, dsize=(nw, nh)) + + brightness = np.random.uniform(-self.jitter_params['brightness'], self.jitter_params['brightness']) * 255 + contrast = 1 + np.random.uniform(-self.jitter_params['contrast'], self.jitter_params['contrast']) + img = cv2.addWeighted(img, contrast, img, 0, brightness) + img = np.clip(img, a_min=0, a_max=255) + if self.image_dim == 1: + img = img[..., None] + img = img.astype(float) / 255. + image_size = np.array([nh, nw], dtype=int) + else: + if self.image_dim == 1: + img = img[..., None].astype(float) / 255. + + output = { + # 'descriptors': sel_descs, # may not be used + 'scores': sel_scores, + 'keypoints': sel_kpts, + 'norm_keypoints': normalize_size(x=sel_kpts, size=image_size), + 'image': [img], + 'gt_seg': sel_seg_ids, + 'gt_cls': gt_cls, + 'gt_cls_dist': gt_cls_dist, + 'gt_n_seg': gt_n_seg, + 'file_name': img_name, + 'prefix_name': self.image_prefix, + # 'mean_xyz': self.mean_xyz, + # 'scale_xyz': self.scale_xyz, + # 'gt_sc': sel_xyzs, + # 'gt_norm_sc': (sel_xyzs - self.mean_xyz) / self.scale_xyz, + 'K': param_out['K'], + 'gt_P': param_out['P'] + } + return output + + def get_item_test(self, idx): + + # evaluation of recognition only + img_name = self.img_fns[idx] + feat_data = np.load(osp.join(self.feature_dir, img_name.replace('/', '+') + '.npy'), allow_pickle=True)[()] + descs = feat_data['descriptors'] # [N, D] + scores = feat_data['scores'] # [N, 1] + kpts = feat_data['keypoints'] # [N, 2] + image_size = feat_data['image_size'] + + nfeat = descs.shape[0] + + if img_name in self.img_p3d.keys(): + p3d_ids = self.img_p3d[img_name] + p3d_xyzs = np.zeros(shape=(nfeat, 3), dtype=float) + seg_ids = np.zeros(shape=(nfeat,), dtype=int) # attention! by default invalid!!! 
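+ # label 0 is reserved for invalid / unassigned keypoints; the loop below shifts each stored
+ # cluster label by +1 so that valid segments occupy [1, n_class - 1], resets a shifted label
+ # of -1 back to 0, and copies the 3D coordinates of keypoints whose point3D id is known.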
+ for i in range(nfeat): + p3d = p3d_ids[i] + if p3d in self.p3d_seg.keys(): + seg_ids[i] = self.p3d_seg[p3d] + 1 + if seg_ids[i] == -1: + seg_ids[i] = 0 # 0 for in valid + + if p3d in self.p3d_xyzs.keys(): + p3d_xyzs[i] = self.p3d_xyzs[p3d] + + seg_ids = np.array(seg_ids).reshape(-1, ) + + if self.nfeatures > 0: + sorted_ids = np.argsort(scores)[::-1][:self.nfeatures] # large to small + descs = descs[sorted_ids] + scores = scores[sorted_ids] + kpts = kpts[sorted_ids] + p3d_xyzs = p3d_xyzs[sorted_ids] + + seg_ids = seg_ids[sorted_ids] + + gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float) + uids = np.unique(seg_ids).tolist() + for uid in uids: + if uid == 0: + continue + gt_cls[uid] = 1 + gt_n_seg[uid] = np.sum(seg_ids == uid) + gt_cls_dist[uid] = np.sum(seg_ids == uid) / np.sum( + seg_ids < self.n_class - 1) # [valid_id / total_valid_id] + + gt_cls[0] = 0 + + img = self.read_image(image_name=img_name) + if self.image_dim == 1: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img = img[..., None].astype(float) / 255. + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(float) / 255. + return { + 'descriptors': descs, + 'scores': scores, + 'keypoints': kpts, + 'image_size': image_size, + 'norm_keypoints': normalize_size(x=kpts, size=image_size), + 'gt_seg': seg_ids, + 'gt_cls': gt_cls, + 'gt_cls_dist': gt_cls_dist, + 'gt_n_seg': gt_n_seg, + 'file_name': img_name, + 'prefix_name': self.image_prefix, + 'image': [img], + + 'mean_xyz': self.mean_xyz, + 'scale_xyz': self.scale_xyz, + 'gt_sc': p3d_xyzs, + 'gt_norm_sc': (p3d_xyzs - self.mean_xyz) / self.scale_xyz + } + + def __getitem__(self, idx): + if self.train: + return self.get_item_train(idx=idx) + else: + return self.get_item_test(idx=idx) + + def __len__(self): + return len(self.img_fns) + + def read_image(self, image_name): + return cv2.imread(osp.join(self.dataset_path, image_name)) + + def jitter_augmentation(self, img, params): + brightness, contrast, saturation, hue = params + p = np.random.randint(0, 20) / 20 + b = brightness[0] + (brightness[1] - brightness[0]) / 20 * p + img = tvf.adjust_brightness(img=img, brightness_factor=b) + + p = np.random.randint(0, 20) / 20 + c = contrast[0] + (contrast[1] - contrast[0]) / 20 * p + img = tvf.adjust_contrast(img=img, contrast_factor=c) + + p = np.random.randint(0, 20) / 20 + s = saturation[0] + (saturation[1] - saturation[0]) / 20 * p + img = tvf.adjust_saturation(img=img, saturation_factor=s) + + p = np.random.randint(0, 20) / 20 + h = hue[0] + (hue[1] - hue[0]) / 20 * p + img = tvf.adjust_hue(img=img, hue_factor=h) + + return img + + def random_points(self, n, d, h, w): + desc = np.random.random((n, d)) + desc = desc / np.linalg.norm(desc, ord=2, axis=1)[..., None] + xs = np.random.randint(0, w - 1, size=(n, 1)) + ys = np.random.randint(0, h - 1, size=(n, 1)) + kpts = np.hstack([xs, ys]) + return desc, kpts + + def random_points_from_reference(self, n, ref_kpts, ref_segs, ref_xyzs, radius=5): + n_ref = ref_kpts.shape[0] + if n_ref < n: + ref_ids = np.random.choice([i for i in range(n_ref)], size=n).tolist() + else: + ref_ids = [i for i in range(n)] + + new_xs = [] + new_ys = [] + # new_descs = [] + new_segs = [] + new_xyzs = [] + for i in ref_ids: + nx = np.random.randint(-radius, radius) + ref_kpts[i, 0] + ny = np.random.randint(-radius, radius) + ref_kpts[i, 1] + + new_xs.append(nx) + new_ys.append(ny) + # new_descs.append(ref_descs[i]) + 
new_segs.append(ref_segs[i]) + new_xyzs.append(ref_xyzs[i]) + + new_xs = np.array(new_xs).reshape(n, 1) + new_ys = np.array(new_ys).reshape(n, 1) + new_segs = np.array(new_segs).reshape(n, ) + new_kpts = np.hstack([new_xs, new_ys]) + # new_descs = np.array(new_descs).reshape(n, -1) + new_xyzs = np.array(new_xyzs) + return new_kpts, new_segs, new_xyzs diff --git a/imcui/third_party/pram/dataset/cambridge_landmarks.py b/imcui/third_party/pram/dataset/cambridge_landmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..03f30f367f4ded9ce1d7c2efbaa407ed26725a69 --- /dev/null +++ b/imcui/third_party/pram/dataset/cambridge_landmarks.py @@ -0,0 +1,101 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> cambridge_landmarks +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:41 +==================================================''' +import os.path as osp +import numpy as np +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class CambridgeLandmarks(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='CambridgeLandmarks', + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, + ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = '' + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + self.img_fns = [] + with open(osp.join(self.dataset_path, 'dataset_train.txt' if train else 'dataset_test.txt'), 'r') as f: + lines = f.readlines()[3:] # ignore the first 3 lines + for l in lines: + l = l.strip().split()[0] + if train and l not in self.name_to_id.keys(): + continue + if not train and l not in self.img_p3d.keys(): + continue + self.img_fns.append(l) + + print('Load {} images from {} for {}...'.format(len(self.img_fns), + self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = 
{p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f: + # lines = f.readlines() + # for l in lines: + # l = l.strip().split() + # self.mean_xyz = np.array([float(v) for v in l[:3]]) + # self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} diff --git a/imcui/third_party/pram/dataset/customdataset.py b/imcui/third_party/pram/dataset/customdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..41ec99ec1540868f3dfbafe00b5585398062e3f8 --- /dev/null +++ b/imcui/third_party/pram/dataset/customdataset.py @@ -0,0 +1,93 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> customdataset.py +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:38 +==================================================''' +import os.path as osp +import numpy as np +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class CustomDataset(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset, + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, + ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = '' + + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + if train: + self.img_fns = [self.images[v].name for v in self.images.keys() if + self.images[v].name in self.name_to_id.keys()] + else: + self.img_fns = [] + with open(osp.join(self.dataset_path, 'queries_with_intrinsics.txt'), 'r') as f: + lines = f.readlines() + for l in lines: + self.img_fns.append(l.strip().split()[0]) + print('Load {} images from {} for {}...'.format(len(self.img_fns), + self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 
'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} diff --git a/imcui/third_party/pram/dataset/get_dataset.py b/imcui/third_party/pram/dataset/get_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe28eaa6238b480aae4c64cd08ffe6cd2379c90 --- /dev/null +++ b/imcui/third_party/pram/dataset/get_dataset.py @@ -0,0 +1,89 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> get_dataset +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:40 +==================================================''' +import os.path as osp +import yaml +from dataset.aachen import Aachen +from dataset.twelve_scenes import TwelveScenes +from dataset.seven_scenes import SevenScenes +from dataset.cambridge_landmarks import CambridgeLandmarks +from dataset.customdataset import CustomDataset +from dataset.recdataset import RecDataset + + +def get_dataset(dataset): + if dataset in ['7Scenes', 'S']: + return SevenScenes + elif dataset in ['12Scenes', 'T']: + return TwelveScenes + elif dataset in ['Aachen', 'A']: + return Aachen + elif dataset in ['CambridgeLandmarks', 'C']: + return CambridgeLandmarks + else: + return CustomDataset + + +def compose_datasets(datasets, config, train=True, sample_ratio=None): + sub_sets = [] + for name in datasets: + if name == 'S': + ds_name = '7Scenes' + elif name == 'T': + ds_name = '12Scenes' + elif name == 'A': + ds_name = 'Aachen' + elif name == 'R': + ds_name = 'RobotCar-Seasons' + elif name == 'C': + ds_name = 'CambridgeLandmarks' + else: + ds_name = name + # raise '{} dataset does not exist'.format(name) + landmark_path = osp.join(config['landmark_path'], ds_name) + dataset_path = osp.join(config['dataset_path'], ds_name) + scene_config_path = 'configs/datasets/{:s}.yaml'.format(ds_name) + + with open(scene_config_path, 'r') as f: + scene_config = yaml.load(f, Loader=yaml.Loader) + DSet = get_dataset(dataset=ds_name) + + for scene in scene_config['scenes']: + if sample_ratio is None: + scene_sample_ratio = scene_config[scene]['training_sample_ratio'] if train else scene_config[scene][ + 'eval_sample_ratio'] + else: + scene_sample_ratio = sample_ratio + scene_set = DSet(landmark_path=landmark_path, + dataset_path=dataset_path, + scene=scene, + seg_mode=scene_config[scene]['cluster_mode'], + seg_method=scene_config[scene]['cluster_method'], + n_class=scene_config[scene]['n_cluster'] + 1, # including invalid - 0 + dataset=ds_name, + train=train, + nfeatures=config['max_keypoints'] if train else config['eval_max_keypoints'], + min_inliers=config['min_inliers'], + max_inliers=config['max_inliers'], + random_inliers=config['random_inliers'], + with_aug=config['with_aug'], + jitter_params=config['jitter_params'], + scale_params=config['scale_params'], + image_dim=config['image_dim'], + query_p3d_fn=osp.join(config['landmark_path'], ds_name, scene, + 'point3D_query_n{:d}_{:s}_{:s}.npy'.format( + scene_config[scene]['n_cluster'], + scene_config[scene]['cluster_mode'], + scene_config[scene]['cluster_method'])), + 
query_info_path=osp.join(config['dataset_path'], ds_name, scene, + 'queries_with_intrinsics.txt'), + sample_ratio=scene_sample_ratio, + ) + + sub_sets.append(scene_set) + + return RecDataset(sub_sets=sub_sets) diff --git a/imcui/third_party/pram/dataset/recdataset.py b/imcui/third_party/pram/dataset/recdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9eebd473018ad269eaa6cd8f1ffaab3f5f316ec6 --- /dev/null +++ b/imcui/third_party/pram/dataset/recdataset.py @@ -0,0 +1,95 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> recdataset +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:42 +==================================================''' +import numpy as np +from torch.utils.data import Dataset + + +class RecDataset(Dataset): + def __init__(self, sub_sets=[]): + assert len(sub_sets) >= 1 + + self.sub_sets = sub_sets + self.names = [] + + self.sub_set_index = [] + self.seg_offsets = [] + self.sub_set_item_index = [] + self.dataset_names = [] + self.scene_names = [] + start_index_valid_seg = 1 # start from 1, 0 is for invalid + + total_subset = 0 + for scene_set in sub_sets: # [0, n_class] + name = scene_set.dataset + self.names.append(name) + n_samples = len(scene_set) + + n_class = scene_set.n_class + self.seg_offsets = self.seg_offsets + [start_index_valid_seg for v in range(len(scene_set))] + start_index_valid_seg = start_index_valid_seg + n_class - 1 + + self.sub_set_index = self.sub_set_index + [total_subset for k in range(n_samples)] + self.sub_set_item_index = self.sub_set_item_index + [k for k in range(n_samples)] + + # self.dataset_names = self.dataset_names + [name for k in range(n_samples)] + self.scene_names = self.scene_names + [name for k in range(n_samples)] + total_subset += 1 + + self.n_class = start_index_valid_seg + + print('Load {} images {} segs from {} subsets from {}'.format(len(self.sub_set_item_index), self.n_class, + len(sub_sets), self.names)) + + def __len__(self): + return len(self.sub_set_item_index) + + def __getitem__(self, idx): + subset_idx = self.sub_set_index[idx] + item_idx = self.sub_set_item_index[idx] + scene_name = self.scene_names[idx] + + out = self.sub_sets[subset_idx][item_idx] + + org_gt_seg = out['gt_seg'] + org_gt_cls = out['gt_cls'] + org_gt_cls_dist = out['gt_cls_dist'] + org_gt_n_seg = out['gt_n_seg'] + offset = self.seg_offsets[idx] + org_n_class = self.sub_sets[subset_idx].n_class + + gt_seg = np.zeros(shape=(org_gt_seg.shape[0],), dtype=int) # [0, ..., n_features] + gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls = np.zeros(shape=(self.n_class,), dtype=int) + gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float) + + # copy invalid segments + gt_n_seg[0] = org_gt_n_seg[0] + gt_cls[0] = org_gt_cls[0] + gt_cls_dist[0] = org_gt_cls_dist[0] + # print('org: ', org_n_class, org_gt_seg.shape, org_gt_n_seg.shape, org_gt_seg) + + # copy valid segments + gt_seg[org_gt_seg > 0] = org_gt_seg[org_gt_seg > 0] + offset - 1 # [0, ..., 1023] + gt_n_seg[offset:offset + org_n_class - 1] = org_gt_n_seg[1:] # [0...,n_seg] + gt_cls[offset:offset + org_n_class - 1] = org_gt_cls[1:] # [0, ..., n_seg] + gt_cls_dist[offset:offset + org_n_class - 1] = org_gt_cls_dist[1:] # [0, ..., n_seg] + + out['gt_seg'] = gt_seg + out['gt_cls'] = gt_cls + out['gt_cls_dist'] = gt_cls_dist + out['gt_n_seg'] = gt_n_seg + + # print('gt: ', org_n_class, gt_seg.shape, gt_n_seg.shape, gt_seg) + out['scene_name'] = scene_name + + # out['org_gt_seg'] = org_gt_seg + # 
out['org_gt_n_seg'] = org_gt_n_seg + # out['org_gt_cls'] = org_gt_cls + # out['org_gt_cls_dist'] = org_gt_cls_dist + + return out diff --git a/imcui/third_party/pram/dataset/seven_scenes.py b/imcui/third_party/pram/dataset/seven_scenes.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc29b29d3b935e45129a35b502117067816433a --- /dev/null +++ b/imcui/third_party/pram/dataset/seven_scenes.py @@ -0,0 +1,115 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> seven_scenes +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:36 +==================================================''' +import os +import os.path as osp +import numpy as np +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class SevenScenes(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='7Scenes', + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, + ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = '' + + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + if self.train: + split_fn = osp.join(self.dataset_path, 'TrainSplit.txt') + else: + split_fn = osp.join(self.dataset_path, 'TestSplit.txt') + + self.img_fns = [] + with open(split_fn, 'r') as f: + lines = f.readlines() + for l in lines: + seq = int(l.strip()[8:]) + fns = os.listdir(osp.join(self.dataset_path, osp.join('seq-{:02d}'.format(seq)))) + fns = sorted(fns) + nf = 0 + for fn in fns: + if fn.find('png') >= 0: + if train and 'seq-{:02d}'.format(seq) + '/' + fn not in self.name_to_id.keys(): + continue + if not train and 'seq-{:02d}'.format(seq) + '/' + fn not in self.img_p3d.keys(): + continue + if nf % sample_ratio == 0: + self.img_fns.append('seq-{:02d}'.format(seq) + '/' + fn) + nf += 1 + + print('Load {} images from {} for {}...'.format(len(self.img_fns), + self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + 
allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f: + # lines = f.readlines() + # for l in lines: + # l = l.strip().split() + # self.mean_xyz = np.array([float(v) for v in l[:3]]) + # self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} diff --git a/imcui/third_party/pram/dataset/twelve_scenes.py b/imcui/third_party/pram/dataset/twelve_scenes.py new file mode 100644 index 0000000000000000000000000000000000000000..34fcc7f46b6d4315d9ebca69043a262310adc453 --- /dev/null +++ b/imcui/third_party/pram/dataset/twelve_scenes.py @@ -0,0 +1,121 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> twelve_scenes +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:37 +==================================================''' +import os +import os.path as osp +import numpy as np +from colmap_utils.read_write_model import read_model +import torchvision.transforms as tvt +from dataset.basicdataset import BasicDataset + + +class TwelveScenes(BasicDataset): + def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='12Scenes', + nfeatures=1024, + query_p3d_fn=None, + train=True, + with_aug=False, + min_inliers=0, + max_inliers=4096, + random_inliers=False, + jitter_params=None, + scale_params=None, + image_dim=3, + query_info_path=None, + sample_ratio=1, + ): + self.landmark_path = osp.join(landmark_path, scene) + self.dataset_path = osp.join(dataset_path, scene) + self.n_class = n_class + self.dataset = dataset + '/' + scene + self.nfeatures = nfeatures + self.with_aug = with_aug + self.jitter_params = jitter_params + self.scale_params = scale_params + self.image_dim = image_dim + self.train = train + self.min_inliers = min_inliers + self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures + self.random_inliers = random_inliers + self.image_prefix = '' + + train_transforms = [] + if self.with_aug: + train_transforms.append(tvt.ColorJitter( + brightness=jitter_params['brightness'], + contrast=jitter_params['contrast'], + saturation=jitter_params['saturation'], + hue=jitter_params['hue'])) + if jitter_params['blur'] > 0: + train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur']))) + self.train_transforms = tvt.Compose(train_transforms) + + if train: + self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin') + self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0} + + # only for testing of query images + if not self.train: + data = np.load(query_p3d_fn, allow_pickle=True)[()] + self.img_p3d = data + else: + self.img_p3d = {} + + with open(osp.join(self.dataset_path, 'split.txt'), 'r') as f: + l = f.readline() + l = l.strip().split(' ') # sequence0 [frames=357] [start=0 ; end=356], first sequence for testing + start_img_id = l[-3].split('=')[-1] + end_img_id = l[-1].split('=')[-1][:-1] + test_start_img_id = int(start_img_id) + test_end_img_id = int(end_img_id) + + self.img_fns = [] + fns = os.listdir(osp.join(self.dataset_path, 
'data')) + fns = sorted(fns) + nf = 0 + for fn in fns: + if fn.find('jpg') >= 0: # frame-001098.color.jpg + frame_id = int(fn.split('.')[0].split('-')[-1]) + if not train and frame_id > test_end_img_id: + continue + if train and frame_id <= test_end_img_id: + continue + + if train and 'data' + '/' + fn not in self.name_to_id.keys(): + continue + + if not train and 'data' + '/' + fn not in self.img_p3d.keys(): + continue + if nf % sample_ratio == 0: + self.img_fns.append('data' + '/' + fn) + nf += 1 + + print('Load {} images from {} for {}...'.format(len(self.img_fns), + self.dataset, 'training' if train else 'eval')) + + data = np.load(osp.join(self.landmark_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)), + allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + xyzs = data['xyz'] + self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])} + + # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f: + # lines = f.readlines() + # for l in lines: + # l = l.strip().split() + # self.mean_xyz = np.array([float(v) for v in l[:3]]) + # self.scale_xyz = np.array([float(v) for v in l[3:]]) + + if not train: + self.query_info = self.read_query_info(path=query_info_path) + + self.nfeatures = nfeatures + self.feature_dir = osp.join(self.landmark_path, 'feats') + self.feats = {} diff --git a/imcui/third_party/pram/dataset/utils.py b/imcui/third_party/pram/dataset/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb8132662c540ae28de32494a5abff6e679064f5 --- /dev/null +++ b/imcui/third_party/pram/dataset/utils.py @@ -0,0 +1,31 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> utils +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:31 +==================================================''' +import torch + + +def normalize_size(x, size, scale=0.7): + size = size.reshape([1, 2]) + norm_fac = size.max() + 0.5 + return (x - size / 2) / (norm_fac * scale) + + +def collect_batch(batch): + out = {} + # if len(batch) == 0: + # return batch + # else: + for k in batch[0].keys(): + tmp = [] + for v in batch: + tmp.append(v[k]) + if isinstance(batch[0][k], str) or isinstance(batch[0][k], list): + out[k] = tmp + else: + out[k] = torch.cat([torch.from_numpy(i)[None] for i in tmp], dim=0) + + return out diff --git a/imcui/third_party/pram/environment.yml b/imcui/third_party/pram/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..bf1c2111660046500e25c9ff28e66d470c7f68a9 --- /dev/null +++ b/imcui/third_party/pram/environment.yml @@ -0,0 +1,173 @@ +name: pram +channels: + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - binutils_impl_linux-64=2.38=h2a08ee3_1 + - bzip2=1.0.8=h5eee18b_5 + - ca-certificates=2024.3.11=h06a4308_0 + - gcc=12.1.0=h9ea6d83_10 + - gcc_impl_linux-64=12.1.0=hea43390_17 + - kernel-headers_linux-64=2.6.32=he073ed8_17 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-devel_linux-64=12.1.0=h1ec3361_17 + - libgcc-ng=13.2.0=h807b86a_5 + - libgomp=13.2.0=h807b86a_5 + - libsanitizer=12.1.0=ha89aaad_17 + - libstdcxx-ng=13.2.0=h7e041cc_5 + - libuuid=1.41.5=h5eee18b_0 + - ncurses=6.4=h6a678d5_0 + - openssl=3.2.1=hd590300_1 + - pip=23.3.1=py310h06a4308_0 + - python=3.10.14=h955ad1f_0 + - readline=8.2=h5eee18b_0 + - 
setuptools=68.2.2=py310h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - sysroot_linux-64=2.12=he073ed8_17 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.41.2=py310h06a4308_0 + - xz=5.4.6=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - addict==2.4.0 + - aiofiles==23.2.1 + - aiohttp==3.9.3 + - aioopenssl==0.6.0 + - aiosasl==0.5.0 + - aiosignal==1.3.1 + - aioxmpp==0.13.3 + - asttokens==2.4.1 + - async-timeout==4.0.3 + - attrs==23.2.0 + - babel==2.14.0 + - benbotasync==3.0.2 + - blinker==1.7.0 + - certifi==2024.2.2 + - cffi==1.16.0 + - charset-normalizer==3.3.2 + - click==8.1.7 + - colorama==0.4.6 + - comm==0.2.2 + - configargparse==1.7 + - contourpy==1.2.1 + - crayons==0.4.0 + - cryptography==42.0.5 + - cycler==0.12.1 + - dash==2.16.1 + - dash-core-components==2.0.0 + - dash-html-components==2.0.0 + - dash-table==5.0.0 + - decorator==5.1.1 + - dnspython==2.6.1 + - einops==0.7.0 + - exceptiongroup==1.2.0 + - executing==2.0.1 + - fastjsonschema==2.19.1 + - filelock==3.13.3 + - flask==3.0.2 + - fonttools==4.50.0 + - fortniteapiasync==0.1.7 + - fortnitepy==3.6.9 + - frozenlist==1.4.1 + - fsspec==2024.3.1 + - h5py==3.10.0 + - html5tagger==1.3.0 + - httptools==0.6.1 + - idna==3.6 + - importlib-metadata==7.1.0 + - ipython==8.23.0 + - ipywidgets==8.1.2 + - itsdangerous==2.1.2 + - jedi==0.19.1 + - jinja2==3.1.3 + - joblib==1.3.2 + - jsonschema==4.21.1 + - jsonschema-specifications==2023.12.1 + - jupyter-core==5.7.2 + - jupyterlab-widgets==3.0.10 + - kiwisolver==1.4.5 + - lxml==4.9.4 + - markupsafe==2.1.5 + - matplotlib==3.8.4 + - matplotlib-inline==0.1.6 + - mpmath==1.3.0 + - multidict==6.0.5 + - nbformat==5.10.4 + - nest-asyncio==1.6.0 + - networkx==3.2.1 + - numpy==1.26.4 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==8.9.2.26 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.19.3 + - nvidia-nvjitlink-cu12==12.4.127 + - nvidia-nvtx-cu12==12.1.105 + - open3d==0.18.0 + - opencv-contrib-python==4.5.5.64 + - packaging==24.0 + - pandas==2.2.1 + - parso==0.8.3 + - pexpect==4.9.0 + - pillow==10.3.0 + - platformdirs==4.2.0 + - plotly==5.20.0 + - prompt-toolkit==3.0.43 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pyasn1==0.6.0 + - pyasn1-modules==0.4.0 + - pybind11==2.12.0 + - pycolmap==0.6.1 + - pycparser==2.22 + - pygments==2.17.2 + - pyopengl==3.1.7 + - pyopengl-accelerate==3.1.7 + - pyopenssl==24.1.0 + - pyparsing==3.1.2 + - pyquaternion==0.9.9 + - python-dateutil==2.9.0.post0 + - pytz==2024.1 + - pyyaml==6.0.1 + - referencing==0.34.0 + - requests==2.31.0 + - retrying==1.3.4 + - rpds-py==0.18.0 + - sanic==23.12.1 + - sanic-routing==23.12.0 + - scikit-learn==1.4.1.post1 + - scipy==1.13.0 + - six==1.16.0 + - sortedcollections==2.1.0 + - sortedcontainers==2.4.0 + - stack-data==0.6.3 + - sympy==1.12 + - tenacity==8.2.3 + - threadpoolctl==3.4.0 + - torch==2.2.2 + - torchvision==0.17.2 + - tqdm==4.66.2 + - tracerite==1.1.1 + - traitlets==5.14.2 + - triton==2.2.0 + - typing-extensions==4.10.0 + - tzdata==2024.1 + - tzlocal==5.2 + - ujson==5.9.0 + - urllib3==2.2.1 + - uvloop==0.15.2 + - wcwidth==0.2.13 + - websockets==12.0 + - werkzeug==3.0.2 + - widgetsnbextension==4.0.10 + - yaml2==0.0.1 + - yarl==1.9.4 + - zipp==3.18.1 diff --git a/imcui/third_party/pram/inference.py b/imcui/third_party/pram/inference.py new file mode 100644 index 
0000000000000000000000000000000000000000..29ccd76911f0b2ff8dc82fc28c712cf1d19d40be --- /dev/null +++ b/imcui/third_party/pram/inference.py @@ -0,0 +1,62 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> inference +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 03/04/2024 16:06 +==================================================''' +import argparse +import torch +import torchvision.transforms.transforms as tvt +import yaml +from nets.load_segnet import load_segnet +from nets.sfd2 import load_sfd2 +from dataset.get_dataset import compose_datasets + +parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--config', type=str, required=True, help='config of specifications') +parser.add_argument('--landmark_path', type=str, required=True, help='path of landmarks') +parser.add_argument('--feat_weight_path', type=str, default='weights/sfd2_20230511_210205_resnet4x.79.pth') +parser.add_argument('--rec_weight_path', type=str, required=True, help='recognition weight') +parser.add_argument('--online', action='store_true', help='online visualization with pangolin') + +if __name__ == '__main__': + args = parser.parse_args() + with open(args.config, 'rt') as f: + config = yaml.load(f, Loader=yaml.Loader) + config['landmark_path'] = args.landmark_path + + feat_model = load_sfd2(weight_path=args.feat_weight_path).cuda().eval() + print('Load SFD2 weight from {:s}'.format(args.feat_weight_path)) + + # rec_model = get_model(config=config) + rec_model = load_segnet(network=config['network'], + n_class=config['n_class'], + desc_dim=256 if config['use_mid_feature'] else 128, + n_layers=config['layers'], + output_dim=config['output_dim']) + state_dict = torch.load(args.rec_weight_path, map_location='cpu')['model'] + rec_model.load_state_dict(state_dict, strict=True) + print('Load recognition weight from {:s}'.format(args.rec_weight_path)) + + img_transforms = [] + img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) + img_transforms = tvt.Compose(img_transforms) + + dataset = config['dataset'] + if not args.online: + from localization.loc_by_rec_eval import loc_by_rec_eval + + test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=1) + config['n_class'] = test_set.n_class + + loc_by_rec_eval(rec_model=rec_model.cuda().eval(), + loader=test_set, + local_feat=feat_model.cuda().eval(), + config=config, img_transforms=img_transforms) + else: + from localization.loc_by_rec_online import loc_by_rec_online + + loc_by_rec_online(rec_model=rec_model.cuda().eval(), + local_feat=feat_model.cuda().eval(), + config=config, img_transforms=img_transforms) diff --git a/imcui/third_party/pram/localization/base_model.py b/imcui/third_party/pram/localization/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..432f49c325d39aa44efb0c3106abf7e376c8244e --- /dev/null +++ b/imcui/third_party/pram/localization/base_model.py @@ -0,0 +1,45 @@ +from abc import ABCMeta, abstractmethod +from torch import nn +from copy import copy +import inspect + + +class BaseModel(nn.Module, metaclass=ABCMeta): + default_conf = {} + required_data_keys = [] + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + self.conf = conf = {**self.default_conf, **conf} + self.required_data_keys = copy(self.required_data_keys) + self._init(conf) + + def forward(self, 
data): + """Check the data and call the _forward method of the child model.""" + for key in self.required_data_keys: + assert key in data, 'Missing key {} in data'.format(key) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + +def dynamic_load(root, model): + module_path = f'{root.__name__}.{model}' + module = __import__(module_path, fromlist=['']) + classes = inspect.getmembers(module, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == module_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseModel)] + assert len(classes) == 1, classes + return classes[0][1] + # return getattr(module, 'Model') diff --git a/imcui/third_party/pram/localization/camera.py b/imcui/third_party/pram/localization/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d77af63bcac68b87acd6f5ddc19d92c7d99d07 --- /dev/null +++ b/imcui/third_party/pram/localization/camera.py @@ -0,0 +1,11 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> camera +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 11:27 +==================================================''' +import collections + +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) diff --git a/imcui/third_party/pram/localization/extract_features.py b/imcui/third_party/pram/localization/extract_features.py new file mode 100644 index 0000000000000000000000000000000000000000..cd3f85c53dafd33fe737fdb9e79eeee1bd1c600b --- /dev/null +++ b/imcui/third_party/pram/localization/extract_features.py @@ -0,0 +1,256 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> extract_features.py +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 14:49 +==================================================''' +import os +import os.path as osp +import h5py +import numpy as np +import progressbar +import yaml +import torch +import cv2 +import torch.utils.data as Data +from tqdm import tqdm +from types import SimpleNamespace +import logging +import pprint +from pathlib import Path +import argparse +from nets.sfd2 import ResNet4x, extract_sfd2_return +from nets.superpoint import SuperPoint, extract_sp_return + +confs = { + 'superpoint-n4096': { + 'output': 'feats-superpoint-n4096', + 'model': { + 'name': 'superpoint', + 'outdim': 256, + 'use_stability': False, + 'nms_radius': 3, + 'max_keypoints': 4096, + 'conf_th': 0.005, + 'multiscale': False, + 'scales': [1.0], + 'model_fn': osp.join(os.getcwd(), + "weights/superpoint_v1.pth"), + }, + 'preprocessing': { + 'grayscale': True, + 'resize_max': False, + }, + }, + + 'resnet4x-20230511-210205-pho-0005': { + 'output': 'feats-resnet4x-20230511-210205-pho-0005', + 'model': { + 'outdim': 128, + 'name': 'resnet4x', + 'use_stability': False, + 'max_keypoints': 4096, + 'conf_th': 0.005, + 'multiscale': False, + 'scales': [1.0], + 'model_fn': osp.join(os.getcwd(), + "weights/sfd2_20230511_210205_resnet4x.79.pth"), + }, + 'preprocessing': { + 'grayscale': False, + 'resize_max': False, + }, + 'mask': False, + }, + + 'sfd2': { + 'output': 'feats-sfd2', + 'model': { + 'outdim': 128, + 'name': 'resnet4x', + 'use_stability': False, + 
'max_keypoints': 4096, + 'conf_th': 0.005, + 'multiscale': False, + 'scales': [1.0], + 'model_fn': osp.join(os.getcwd(), + "weights/sfd2_20230511_210205_resnet4x.79.pth"), + }, + 'preprocessing': { + 'grayscale': False, + 'resize_max': False, + }, + 'mask': False, + }, +} + + +class ImageDataset(Data.Dataset): + default_conf = { + 'globs': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'], + 'grayscale': False, + 'resize_max': None, + 'resize_force': False, + } + + def __init__(self, root, conf, image_list=None, + mask_root=None): + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + self.root = root + + self.paths = [] + if image_list is None: + for g in conf.globs: + self.paths += list(Path(root).glob('**/' + g)) + if len(self.paths) == 0: + raise ValueError(f'Could not find any image in root: {root}.') + self.paths = [i.relative_to(root) for i in self.paths] + else: + with open(image_list, "r") as f: + lines = f.readlines() + for l in lines: + l = l.strip() + self.paths.append(Path(l)) + + logging.info(f'Found {len(self.paths)} images in root {root}.') + + if mask_root is not None: + self.mask_root = mask_root + else: + self.mask_root = None + + def __getitem__(self, idx): + path = self.paths[idx] + if self.conf.grayscale: + mode = cv2.IMREAD_GRAYSCALE + else: + mode = cv2.IMREAD_COLOR + image = cv2.imread(str(self.root / path), mode) + if not self.conf.grayscale: + image = image[:, :, ::-1] # BGR to RGB + if image is None: + raise ValueError(f'Cannot read image {str(path)}.') + image = image.astype(np.float32) + size = image.shape[:2][::-1] + w, h = size + + if self.conf.resize_max and (self.conf.resize_force + or max(w, h) > self.conf.resize_max): + scale = self.conf.resize_max / max(h, w) + h_new, w_new = int(round(h * scale)), int(round(w * scale)) + image = cv2.resize( + image, (w_new, h_new), interpolation=cv2.INTER_CUBIC) + + if self.conf.grayscale: + image = image[None] + else: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + image = image / 255. 
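+ # at this point `image` is a float CxHxW array in [0, 1]: shape (1, H, W) for grayscale
+ # input and (3, H, W) in RGB order otherwise; `original_size` below keeps the pre-resize
+ # (width, height) so keypoints extracted on the resized image can be rescaled back in main().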
+ + data = { + 'name': str(path), + 'image': image, + 'original_size': np.array(size), + } + + if self.mask_root is not None: + mask_path = Path(str(path).replace("jpg", "png")) + if osp.exists(mask_path): + mask = cv2.imread(str(self.mask_root / mask_path)) + mask = cv2.resize(mask, dsize=(image.shape[2], image.shape[1]), interpolation=cv2.INTER_NEAREST) + else: + mask = np.zeros(shape=(image.shape[1], image.shape[2], 3), dtype=np.uint8) + + data['mask'] = mask + + return data + + def __len__(self): + return len(self.paths) + + +def get_model(model_name, weight_path, outdim=128, **kwargs): + if model_name == 'superpoint': + model = SuperPoint(config={ + 'descriptor_dim': 256, + 'nms_radius': 4, + 'keypoint_threshold': 0.005, + 'max_keypoints': -1, + 'remove_borders': 4, + 'weight_path': weight_path, + }).eval() + + extractor = extract_sp_return + + if model_name == 'resnet4x': + model = ResNet4x(outdim=outdim).eval() + model.load_state_dict(torch.load(weight_path)['state_dict'], strict=True) + extractor = extract_sfd2_return + + return model, extractor + + +@torch.no_grad() +def main(conf, image_dir, export_dir): + logging.info('Extracting local features with configuration:' + f'\n{pprint.pformat(conf)}') + model, extractor = get_model(model_name=conf['model']['name'], weight_path=conf["model"]["model_fn"], + use_stability=conf['model']['use_stability'], outdim=conf['model']['outdim']) + model = model.cuda() + loader = ImageDataset(image_dir, + conf['preprocessing'], + image_list=args.image_list, + mask_root=None) + loader = torch.utils.data.DataLoader(loader, num_workers=4) + + os.makedirs(export_dir, exist_ok=True) + feature_path = Path(export_dir, conf['output'] + '.h5') + feature_path.parent.mkdir(exist_ok=True, parents=True) + feature_file = h5py.File(str(feature_path), 'a') + + with tqdm(total=len(loader)) as t: + for idx, data in enumerate(loader): + t.update() + pred = extractor(model, img=data["image"], + topK=conf["model"]["max_keypoints"], + mask=None, + conf_th=conf["model"]["conf_th"], + scales=conf["model"]["scales"], + ) + + # pred = {k: v[0].cpu().numpy() for k, v in pred.items()} + pred['descriptors'] = pred['descriptors'].transpose() + + t.set_postfix(npoints=pred['keypoints'].shape[0]) + # print(pred['keypoints'].shape) + + pred['image_size'] = original_size = data['original_size'][0].numpy() + # pred['descriptors'] = pred['descriptors'].T + if 'keypoints' in pred.keys(): + size = np.array(data['image'].shape[-2:][::-1]) + scales = (original_size / size).astype(np.float32) + pred['keypoints'] = (pred['keypoints'] + .5) * scales[None] - .5 + + grp = feature_file.create_group(data['name'][0]) + for k, v in pred.items(): + # print(k, v.shape) + grp.create_dataset(k, data=v) + + del pred + + feature_file.close() + logging.info('Finished exporting features.') + + return feature_path + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--image_dir', type=Path, required=True) + parser.add_argument('--image_list', type=str, default=None) + parser.add_argument('--mask_dir', type=Path, default=None) + parser.add_argument('--export_dir', type=Path, required=True) + parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys())) + args = parser.parse_args() + main(confs[args.conf], args.image_dir, args.export_dir) diff --git a/imcui/third_party/pram/localization/frame.py b/imcui/third_party/pram/localization/frame.py new file mode 100644 index 0000000000000000000000000000000000000000..467a0f31a9c62a19b4435c71add6d08e34b051f3 
--- /dev/null +++ b/imcui/third_party/pram/localization/frame.py @@ -0,0 +1,195 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> frame +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 01/03/2024 10:08 +==================================================''' +from collections import defaultdict + +import numpy as np +import torch +import pycolmap + +from localization.camera import Camera +from localization.utils import compute_pose_error + + +class Frame: + def __init__(self, image: np.ndarray, camera: pycolmap.Camera, id: int, name: str = None, qvec=None, tvec=None, + scene_name=None, + reference_frame_id=None): + self.image = image + self.camera = camera + self.id = id + self.name = name + self.image_size = np.array([camera.height, camera.width]) + self.qvec = qvec + self.tvec = tvec + self.scene_name = scene_name + self.reference_frame_id = reference_frame_id + + self.keypoints = None # [N, 3] + self.descriptors = None # [N, D] + self.segmentations = None # [N C] + self.seg_scores = None # [N C] + self.seg_ids = None # [N, 1] + self.point3D_ids = None # [N, 1] + self.xyzs = None + + self.gt_qvec = None + self.gt_tvec = None + + self.matched_scene_name = None + self.matched_keypoints = None + self.matched_keypoint_ids = None + self.matched_xyzs = None + self.matched_point3D_ids = None + self.matched_inliers = None + self.matched_sids = None + self.matched_order = None + + self.refinement_reference_frame_ids = None + self.image_rec = None + self.image_matching = None + self.image_inlier = None + self.reference_frame_name = None + self.image_matching_tmp = None + self.image_inlier_tmp = None + self.reference_frame_name_tmp = None + + self.tracking_status = None + + self.time_feat = 0 + self.time_rec = 0 + self.time_loc = 0 + self.time_ref = 0 + + def update_point3ds_old(self): + pt = torch.from_numpy(self.keypoints[:, :2]).unsqueeze(-1) # [M 2 1] + mpt = torch.from_numpy(self.matched_keypoints[:, :2].transpose()).unsqueeze(0) # [1 2 N] + dist = torch.sqrt(torch.sum((pt - mpt) ** 2, dim=1)) + values, ids = torch.topk(dist, dim=1, k=1, largest=False) + values = values[:, 0].numpy() + ids = ids[:, 0].numpy() + mask = (values < 1) # 1 pixel error + self.point3D_ids = np.zeros(shape=(self.keypoints.shape[0],), dtype=int) - 1 + self.point3D_ids[mask] = self.matched_point3D_ids[ids[mask]] + + # self.xyzs = np.zeros(shape=(self.keypoints.shape[0], 3), dtype=float) + inlier_mask = self.matched_inliers + self.xyzs[mask] = self.matched_xyzs[ids[mask]] + self.seg_ids[mask] = self.matched_sids[ids[mask]] + + def update_point3ds(self): + # print('Frame: update_point3ds: ', self.matched_keypoint_ids.shape, self.matched_xyzs.shape, + # self.matched_sids.shape, self.matched_point3D_ids.shape) + self.xyzs[self.matched_keypoint_ids] = self.matched_xyzs + self.seg_ids[self.matched_keypoint_ids] = self.matched_sids + self.point3D_ids[self.matched_keypoint_ids] = self.matched_point3D_ids + + def add_keypoints(self, keypoints: np.ndarray, descriptors: np.ndarray): + self.keypoints = keypoints + self.descriptors = descriptors + self.initialize_localization_variables() + + def add_segmentations(self, segmentations: torch.Tensor, filtering_threshold: float): + ''' + :param segmentations: [number_points number_labels] + :return: + ''' + seg_scores = torch.softmax(segmentations, dim=-1) + if filtering_threshold > 0: + scores_background = seg_scores[:, 0] + non_bg_mask = (scores_background < filtering_threshold) + print('pre filtering before: ', 
self.keypoints.shape) + if torch.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]: + self.keypoints = self.keypoints[non_bg_mask.cpu().numpy()] + self.descriptors = self.descriptors[non_bg_mask.cpu().numpy()] + # print('pre filtering after: ', self.keypoints.shape) + + # update localization variables + self.initialize_localization_variables() + + segmentations = segmentations[non_bg_mask] + seg_scores = seg_scores[non_bg_mask] + print('pre filtering after: ', self.keypoints.shape) + + # extract initial segmentation info + self.segmentations = segmentations.cpu().numpy() + self.seg_scores = seg_scores.cpu().numpy() + self.seg_ids = segmentations.max(dim=-1)[1].cpu().numpy() - 1 # should start from 0 + + def filter_keypoints(self, seg_scores: np.ndarray, filtering_threshold: float): + scores_background = seg_scores[:, 0] + non_bg_mask = (scores_background < filtering_threshold) + print('pre filtering before: ', self.keypoints.shape) + if np.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]: + self.keypoints = self.keypoints[non_bg_mask] + self.descriptors = self.descriptors[non_bg_mask] + print('pre filtering after: ', self.keypoints.shape) + + # update localization variables + self.initialize_localization_variables() + return non_bg_mask + else: + print('pre filtering after: ', self.keypoints.shape) + return None + + def compute_pose_error(self, pred_qvec=None, pred_tvec=None): + if pred_qvec is not None and pred_tvec is not None: + if self.gt_qvec is not None and self.gt_tvec is not None: + return compute_pose_error(pred_qcw=pred_qvec, pred_tcw=pred_tvec, + gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec) + else: + return 100, 100 + + if self.qvec is None or self.tvec is None or self.gt_qvec is None or self.gt_tvec is None: + return 100, 100 + else: + err_q, err_t = compute_pose_error(pred_qcw=self.qvec, pred_tcw=self.tvec, + gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec) + return err_q, err_t + + def get_intrinsics(self) -> np.ndarray: + camera_model = self.camera.model.name + params = self.camera.params + if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = params[0] + cx = params[1] + cy = params[2] + elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = params[0] + fy = params[1] + cx = params[2] + cy = params[3] + else: + raise Exception("Camera model not supported") + + # intrinsics + K = np.identity(3) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + return K + + def get_dominate_seg_id(self): + counts = np.bincount(self.seg_ids[self.seg_ids > 0]) + return np.argmax(counts) + + def clear_localization_track(self): + self.matched_scene_name = None + self.matched_keypoints = None + self.matched_xyzs = None + self.matched_point3D_ids = None + self.matched_inliers = None + self.matched_sids = None + + self.refinement_reference_frame_ids = None + + def initialize_localization_variables(self): + nkpt = self.keypoints.shape[0] + self.seg_ids = np.zeros(shape=(nkpt,), dtype=int) - 1 + self.point3D_ids = np.zeros(shape=(nkpt,), dtype=int) - 1 + self.xyzs = np.zeros(shape=(nkpt, 3), dtype=float) diff --git a/imcui/third_party/pram/localization/loc_by_rec_eval.py b/imcui/third_party/pram/localization/loc_by_rec_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..f69b4ac3fde0547947abe983b1f5a4a4af55f974 --- /dev/null +++ b/imcui/third_party/pram/localization/loc_by_rec_eval.py @@ -0,0 +1,299 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> loc_by_rec +@IDE PyCharm 
+@Author fx221@cam.ac.uk +@Date 08/02/2024 15:26 +==================================================''' +import torch +from torch.autograd import Variable +from localization.multimap3d import MultiMap3D +from localization.frame import Frame +import yaml, cv2, time +import numpy as np +import os.path as osp +import threading +import os +from tqdm import tqdm +from recognition.vis_seg import vis_seg_point, generate_color_dic +from tools.metrics import compute_iou, compute_precision +from localization.tracker import Tracker +from localization.utils import read_query_info +from localization.camera import Camera + + +def loc_by_rec_eval(rec_model, loader, config, local_feat, img_transforms=None): + n_epoch = int(config['weight_path'].split('.')[1]) + save_fn = osp.join(config['localization']['save_path'], + config['weight_path'].split('/')[0] + '_{:d}'.format(n_epoch) + '_{:d}'.format( + config['feat_dim'])) + tag = 'k{:d}_th{:d}_mm{:d}_mi{:d}'.format(config['localization']['seg_k'], config['localization']['threshold'], + config['localization']['min_matches'], + config['localization']['min_inliers']) + if config['localization']['do_refinement']: + tag += '_op{:d}'.format(config['localization']['covisibility_frame']) + if config['localization']['with_compress']: + tag += '_comp' + + save_fn = save_fn + '_' + tag + + save = config['localization']['save'] + save = config['localization']['save'] + if save: + save_dir = save_fn + os.makedirs(save_dir, exist_ok=True) + else: + save_dir = None + + seg_color = generate_color_dic(n_seg=2000) + dataset_path = config['dataset_path'] + show = config['localization']['show'] + if show: + cv2.namedWindow('img', cv2.WINDOW_NORMAL) + + locMap = MultiMap3D(config=config, save_dir=None) + # start tracker + mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config) + + dataset_name = config['dataset'][0] + all_scene_query_info = {} + with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f: + scene_config = yaml.load(f, Loader=yaml.Loader) + scenes = scene_config['scenes'] + for scene in scenes: + query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path']) + query_info = read_query_info(query_fn=query_path) + all_scene_query_info[dataset_name + '/' + scene] = query_info + # print(scene, query_info.keys()) + + tracking = False + + full_log = '' + failed_cases = [] + success_cases = [] + poses = {} + err_ths_cnt = [0, 0, 0, 0] + + seg_results = {} + time_results = { + 'feat': [], + 'rec': [], + 'loc': [], + 'ref': [], + 'total': [], + } + n_total = 0 + + loc_scene_names = config['localization']['loc_scene_name'] + # loader = loader[8990:] + for bid, pred in tqdm(enumerate(loader), total=len(loader)): + pred = loader[bid] + image_name = pred['file_name'] # [0] + scene_name = pred['scene_name'] # [0] # dataset_scene + if len(loc_scene_names) > 0: + skip = True + for loc_scene in loc_scene_names: + if scene_name.find(loc_scene) > 0: + skip = False + break + if skip: + continue + with torch.no_grad(): + for k in pred: + if k.find('name') >= 0: + continue + if k != 'image0' and k != 'image1' and k != 'depth0' and k != 'depth1': + if type(pred[k]) == np.ndarray: + pred[k] = Variable(torch.from_numpy(pred[k]).float().cuda())[None] + elif type(pred[k]) == torch.Tensor: + pred[k] = Variable(pred[k].float().cuda()) + elif type(pred[k]) == list: + continue + else: + pred[k] = Variable(torch.stack(pred[k]).float().cuda()) + print('scene: ', scene_name, image_name) + + n_total += 1 + with 
torch.no_grad(): + img = pred['image'] + while isinstance(img, list): + img = img[0] + + new_im = torch.from_numpy(img).permute(2, 0, 1).cuda().float() + if img_transforms is not None: + new_im = img_transforms(new_im)[None] + else: + new_im = new_im[None] + img = (img * 255).astype(np.uint8) + + fn = image_name + camera_model, width, height, params = all_scene_query_info[scene_name][fn] + camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params) + curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=scene_name) + gt_sub_map = locMap.sub_maps[curr_frame.scene_name] + if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys(): + curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec'] + curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec'] + + t_start = time.time() + encoder_out = local_feat.extract_local_global(data={'image': new_im}, + config= + { + # 'min_keypoints': 128, + 'max_keypoints': config['eval_max_keypoints'], + } + ) + t_feat = time.time() - t_start + # global_descriptors_cuda = encoder_out['global_descriptors'] + # scores_cuda = encoder_out['scores'][0][None] + # kpts_cuda = encoder_out['keypoints'][0][None] + # descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1) + + sparse_scores = pred['scores'] + sparse_descs = pred['descriptors'] + sparse_kpts = pred['keypoints'] + gt_seg = pred['gt_seg'] + + curr_frame.add_keypoints(keypoints=np.hstack([sparse_kpts[0].cpu().numpy(), + sparse_scores[0].cpu().numpy().reshape(-1, 1)]), + descriptors=sparse_descs[0].cpu().numpy()) + curr_frame.time_feat = t_feat + + t_start = time.time() + _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'], + semi_descs=encoder_out['mid_features'], + # kpts=kpts_cuda[0], + kpts=sparse_kpts[0], + norm_desc=config['norm_desc']) + rec_out = rec_model({'scores': sparse_scores, + 'seg_descriptors': seg_descriptors[None].permute(0, 2, 1), + 'keypoints': sparse_kpts, + 'image': new_im}) + t_rec = time.time() - t_start + curr_frame.time_rec = t_rec + + pred = { + # 'scores': scores_cuda, + # 'keypoints': kpts_cuda, + # 'descriptors': descriptors_cuda, + # 'global_descriptors': global_descriptors_cuda, + 'image_size': np.array([img.shape[1], img.shape[0]])[None], + } + + pred = {**pred, **rec_out} + pred_seg = torch.max(pred['prediction'], dim=2)[1] # [B, N, C] + + pred_seg = pred_seg[0].cpu().numpy() + kpts = sparse_kpts[0].cpu().numpy() + img_pred_seg = vis_seg_point(img=img, kpts=kpts, segs=pred_seg, seg_color=seg_color, radius=9) + show_text = 'kpts: {:d}'.format(kpts.shape[0]) + img_pred_seg = cv2.putText(img=img_pred_seg, text=show_text, + org=(50, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + curr_frame.image_rec = img_pred_seg + + if show: + cv2.imshow('img', img) + key = cv2.waitKey(1) + if key == ord('q'): + exit(0) + elif key == ord('s'): + show_time = -1 + elif key == ord('c'): + show_time = 1 + + segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C] + curr_frame.add_segmentations(segmentations=segmentations, + filtering_threshold=config['localization']['pre_filtering_th']) + + # Step1: do tracker first + success = not mTracker.lost and tracking + if success: + success = mTracker.run(frame=curr_frame) + if not success: + success = locMap.run(q_frame=curr_frame) + if success: + curr_frame.update_point3ds() + if tracking: + mTracker.lost = False + mTracker.last_frame = curr_frame + # ''' + 
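+ # Evaluate the recognition output against the ground-truth segmentation: IoU and precision are computed with the background id 0 ignored, and the scores are accumulated separately for day and night queries of the current scene.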
pred_seg = torch.max(pred['prediction'], dim=-1)[1] # [B, N, C] + pred_seg = pred_seg[0].cpu().numpy() + gt_seg = gt_seg[0].cpu().numpy() + iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=pred_seg.shape[0], + ignored_ids=[0]) # 0 - background + prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0]) + + kpts = sparse_kpts[0].cpu().numpy() + if scene not in seg_results.keys(): + seg_results[scene] = { + 'day': { + 'prec': [], + 'iou': [], + 'kpts': [], + }, + 'night': { + 'prec': [], + 'iou': [], + 'kpts': [], + + } + } + if fn.find('night') >= 0: + seg_results[scene]['night']['prec'].append(prec) + seg_results[scene]['night']['iou'].append(iou) + seg_results[scene]['night']['kpts'].append(kpts.shape[0]) + else: + seg_results[scene]['day']['prec'].append(prec) + seg_results[scene]['day']['iou'].append(iou) + seg_results[scene]['day']['kpts'].append(kpts.shape[0]) + + print_text = 'name: {:s}, kpts: {:d}, iou: {:.3f}, prec: {:.3f}'.format(fn, kpts.shape[0], iou, + prec) + print(print_text) + # ''' + + t_feat = curr_frame.time_feat + t_rec = curr_frame.time_rec + t_loc = curr_frame.time_loc + t_ref = curr_frame.time_ref + t_total = t_feat + t_rec + t_loc + t_ref + time_results['feat'].append(t_feat) + time_results['rec'].append(t_rec) + time_results['loc'].append(t_loc) + time_results['ref'].append(t_ref) + time_results['total'].append(t_total) + + poses[scene + '/' + fn] = (curr_frame.qvec, curr_frame.tvec) + q_err, t_err = curr_frame.compute_pose_error() + if q_err <= 5 and t_err <= 0.05: + err_ths_cnt[0] = err_ths_cnt[0] + 1 + if q_err <= 2 and t_err <= 0.25: + err_ths_cnt[1] = err_ths_cnt[1] + 1 + if q_err <= 5 and t_err <= 0.5: + err_ths_cnt[2] = err_ths_cnt[2] + 1 + if q_err <= 10 and t_err <= 5: + err_ths_cnt[3] = err_ths_cnt[3] + 1 + + if success: + success_cases.append(scene + '/' + fn) + print_text = 'qname: {:s} localization success {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( + scene + '/' + fn, len(success_cases), n_total, q_err, t_err, err_ths_cnt[0], + err_ths_cnt[1], + err_ths_cnt[2], + err_ths_cnt[3], + n_total, + t_feat, t_rec, t_loc, t_ref, t_total + ) + else: + failed_cases.append(scene + '/' + fn) + print_text = 'qname: {:s} localization fail {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( + scene + '/' + fn, len(failed_cases), n_total, q_err, t_err, err_ths_cnt[0], + err_ths_cnt[1], + err_ths_cnt[2], + err_ths_cnt[3], + n_total, t_feat, t_rec, t_loc, t_ref, t_total) + print(print_text) diff --git a/imcui/third_party/pram/localization/loc_by_rec_online.py b/imcui/third_party/pram/localization/loc_by_rec_online.py new file mode 100644 index 0000000000000000000000000000000000000000..58afed6eb439b23b4a0bc7daf45d50098bcc4fc2 --- /dev/null +++ b/imcui/third_party/pram/localization/loc_by_rec_online.py @@ -0,0 +1,225 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> loc_by_rec +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 08/02/2024 15:26 +==================================================''' +import torch +import pycolmap +from localization.multimap3d import MultiMap3D +from localization.frame import Frame +import yaml, cv2, time +import numpy as np +import os.path as osp +import threading +from recognition.vis_seg import vis_seg_point, generate_color_dic +from tools.common import resize_img +from localization.viewer import Viewer +from localization.tracker 
import Tracker +from localization.utils import read_query_info +from tools.common import puttext_with_background + + +def loc_by_rec_online(rec_model, config, local_feat, img_transforms=None): + seg_color = generate_color_dic(n_seg=2000) + dataset_path = config['dataset_path'] + show = config['localization']['show'] + if show: + cv2.namedWindow('img', cv2.WINDOW_NORMAL) + + locMap = MultiMap3D(config=config, save_dir=None) + if config['dataset'][0] in ['Aachen']: + viewer_config = {'scene': 'outdoor', + 'image_size_indoor': 4, + 'image_line_width_indoor': 8, } + elif config['dataset'][0] in ['C']: + viewer_config = {'scene': 'outdoor'} + elif config['dataset'][0] in ['12Scenes', '7Scenes']: + viewer_config = {'scene': 'indoor', } + else: + viewer_config = {'scene': 'outdoor', + 'image_size_indoor': 0.4, + 'image_line_width_indoor': 2, } + # start viewer + mViewer = Viewer(locMap=locMap, seg_color=seg_color, config=viewer_config) + mViewer.refinement = locMap.do_refinement + # locMap.viewer = mViewer + viewer_thread = threading.Thread(target=mViewer.run) + viewer_thread.start() + + # start tracker + mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config) + + dataset_name = config['dataset'][0] + all_scene_query_info = {} + with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f: + scene_config = yaml.load(f, Loader=yaml.Loader) + + # multiple scenes in a single dataset + err_ths_cnt = [0, 0, 0, 0] + + show_time = -1 + scenes = scene_config['scenes'] + n_total = 0 + for scene in scenes: + if len(config['localization']['loc_scene_name']) > 0: + if scene not in config['localization']['loc_scene_name']: + continue + + query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path']) + query_info = read_query_info(query_fn=query_path) + all_scene_query_info[dataset_name + '/' + scene] = query_info + image_path = osp.join(dataset_path, dataset_name, scene) + for fn in sorted(query_info.keys()): + # for fn in sorted(query_info.keys())[880:][::5]: # darwinRGB-loc-outdoor-aligned + # for fn in sorted(query_info.keys())[3161:][::5]: # darwinRGB-loc-indoor-aligned + # for fn in sorted(query_info.keys())[2840:][::5]: # darwinRGB-loc-indoor-aligned + + # for fn in sorted(query_info.keys())[2100:][::5]: # darwinRGB-loc-outdoor + # for fn in sorted(query_info.keys())[4360:][::5]: # darwinRGB-loc-indoor + # for fn in sorted(query_info.keys())[1380:]: # Cam-Church + # for fn in sorted(query_info.keys())[::5]: #ACUED-test2 + # for fn in sorted(query_info.keys())[1260:]: # jesus aligned + # for fn in sorted(query_info.keys())[1260:]: # jesus aligned + # for fn in sorted(query_info.keys())[4850:]: + img = cv2.imread(osp.join(image_path, fn)) # BGR + + camera_model, width, height, params = all_scene_query_info[dataset_name + '/' + scene][fn] + # camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params) + camera = pycolmap.Camera(model=camera_model, width=int(width), height=int(height), params=params) + curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=dataset_name + '/' + scene) + gt_sub_map = locMap.sub_maps[curr_frame.scene_name] + if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys(): + curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec'] + curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec'] + + with torch.no_grad(): + if config['image_dim'] == 1: + img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img_cuda = 
torch.from_numpy(img_gray / 255)[None].cuda().float() + else: + img_cuda = torch.from_numpy(img / 255).permute(2, 0, 1).cuda().float() + if img_transforms is not None: + img_cuda = img_transforms(img_cuda)[None] + else: + img_cuda = img_cuda[None] + + t_start = time.time() + encoder_out = local_feat.extract_local_global(data={'image': img_cuda}, + config={'min_keypoints': 128, + 'max_keypoints': config['eval_max_keypoints'], + } + ) + t_feat = time.time() - t_start + # global_descriptors_cuda = encoder_out['global_descriptors'] + scores_cuda = encoder_out['scores'][0][None] + kpts_cuda = encoder_out['keypoints'][0][None] + descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1) + + curr_frame.add_keypoints(keypoints=np.hstack([kpts_cuda[0].cpu().numpy(), + scores_cuda[0].cpu().numpy().reshape(-1, 1)]), + descriptors=descriptors_cuda[0].cpu().numpy()) + curr_frame.time_feat = t_feat + + t_start = time.time() + _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'], + semi_descs=encoder_out['mid_features'], + kpts=kpts_cuda[0], + norm_desc=config['norm_desc']) + rec_out = rec_model({'scores': scores_cuda, + 'seg_descriptors': seg_descriptors[None].permute(0, 2, 1), + 'keypoints': kpts_cuda, + 'image': img_cuda}) + t_rec = time.time() - t_start + curr_frame.time_rec = t_rec + + pred = { + 'scores': scores_cuda, + 'keypoints': kpts_cuda, + 'descriptors': descriptors_cuda, + # 'global_descriptors': global_descriptors_cuda, + 'image_size': np.array([img.shape[1], img.shape[0]])[None], + } + + pred = {**pred, **rec_out} + pred_seg = torch.max(pred['prediction'], dim=2)[1] # [B, N, C] + + pred_seg = pred_seg[0].cpu().numpy() + kpts = kpts_cuda[0].cpu().numpy() + segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C] + curr_frame.add_segmentations(segmentations=segmentations, + filtering_threshold=config['localization']['pre_filtering_th']) + + img_pred_seg = vis_seg_point(img=img, kpts=curr_frame.keypoints, + segs=curr_frame.seg_ids + 1, seg_color=seg_color, radius=9) + show_text = 'kpts: {:d}'.format(kpts.shape[0]) + img_pred_seg = cv2.putText(img=img_pred_seg, + text=show_text, + org=(50, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + curr_frame.image_rec = img_pred_seg + + if show: + img_text = puttext_with_background(image=img, text='Press C - continue | S - pause | Q - exit', + org=(30, 50), + bg_color=(255, 255, 255), + text_color=(0, 0, 255), + fontScale=1, thickness=2) + cv2.imshow('img', img_text) + key = cv2.waitKey(show_time) + if key == ord('q'): + exit(0) + elif key == ord('s'): + show_time = -1 + elif key == ord('c'): + show_time = 1 + + # Step1: do tracker first + success = not mTracker.lost and mViewer.tracking + if success: + success = mTracker.run(frame=curr_frame) + if success: + mViewer.update(curr_frame=curr_frame) + + if not success: + # success = locMap.run(q_frame=curr_frame, q_segs=segmentations) + success = locMap.run(q_frame=curr_frame) + if success: + mViewer.update(curr_frame=curr_frame) + + if success: + curr_frame.update_point3ds() + if mViewer.tracking: + mTracker.lost = False + mTracker.last_frame = curr_frame + + time.sleep(50 / 1000) + locMap.do_refinement = mViewer.refinement + + n_total = n_total + 1 + q_err, t_err = curr_frame.compute_pose_error() + if q_err <= 5 and t_err <= 0.05: + err_ths_cnt[0] = err_ths_cnt[0] + 1 + if q_err <= 2 and t_err <= 0.25: + err_ths_cnt[1] = err_ths_cnt[1] + 1 + if q_err <= 5 and t_err <= 0.5: + err_ths_cnt[2] = 
err_ths_cnt[2] + 1 + if q_err <= 10 and t_err <= 5: + err_ths_cnt[3] = err_ths_cnt[3] + 1 + time_total = curr_frame.time_feat + curr_frame.time_rec + curr_frame.time_loc + curr_frame.time_ref + print_text = 'qname: {:s} localization {:b}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( + scene + '/' + fn, success, q_err, t_err, + err_ths_cnt[0], + err_ths_cnt[1], + err_ths_cnt[2], + err_ths_cnt[3], + n_total, + curr_frame.time_feat, curr_frame.time_rec, curr_frame.time_loc, curr_frame.time_ref, time_total + ) + print(print_text) + + mViewer.terminate() + viewer_thread.join() diff --git a/imcui/third_party/pram/localization/localizer.py b/imcui/third_party/pram/localization/localizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0777b9cc6d7f70aa8c3699f360684cd24054a488 --- /dev/null +++ b/imcui/third_party/pram/localization/localizer.py @@ -0,0 +1,217 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> hloc +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 16:45 +==================================================''' + +import os +import os.path as osp +from tqdm import tqdm +import argparse +import time +import logging +import h5py +import numpy as np +from pathlib import Path +from colmap_utils.read_write_model import read_model +from colmap_utils.parsers import parse_image_lists_with_intrinsics +# localization +from localization.match_features_batch import confs +from localization.base_model import dynamic_load +from localization import matchers +from localization.utils import compute_pose_error, read_gt_pose, read_retrieval_results +from localization.pose_estimator import pose_estimator_hloc, pose_estimator_iterative + + +def run(args): + if args.gt_pose_fn is not None: + gt_poses = read_gt_pose(path=args.gt_pose_fn) + else: + gt_poses = {} + retrievals = read_retrieval_results(args.retrieval) + + save_root = args.save_root # path to save + os.makedirs(save_root, exist_ok=True) + matcher_name = args.matcher_method # matching method + print('matcher: ', confs[args.matcher_method]['model']['name']) + Model = dynamic_load(matchers, confs[args.matcher_method]['model']['name']) + matcher = Model(confs[args.matcher_method]['model']).eval().cuda() + + local_feat_name = args.features.as_posix().split("/")[-1].split(".")[0] # name of local features + save_fn = '{:s}_{:s}'.format(local_feat_name, matcher_name) + if args.use_hloc: + save_fn = 'hloc_' + save_fn + save_fn = osp.join(save_root, save_fn) + + queries = parse_image_lists_with_intrinsics(args.queries) + _, db_images, points3D = read_model(str(args.reference_sfm), '.bin') + db_name_to_id = {image.name: i for i, image in db_images.items()} + feature_file = h5py.File(args.features, 'r') + + tag = '' + if args.do_covisible_opt: + tag = tag + "_o" + str(int(args.obs_thresh)) + 'op' + str(int(args.covisibility_frame)) + tag = tag + "th" + str(int(args.opt_thresh)) + if args.iters > 0: + tag = tag + "i" + str(int(args.iters)) + + log_fn = save_fn + tag + vis_dir = save_fn + tag + results = save_fn + tag + + full_log_fn = log_fn + '_full.log' + loc_log_fn = log_fn + '_loc.npy' + results = Path(results + '.txt') + vis_dir = Path(vis_dir) + if vis_dir is not None: + Path(vis_dir).mkdir(exist_ok=True) + print("save_fn: ", log_fn) + + logging.info('Starting localization...') + poses = {} + failed_cases = [] + n_total = 0 + n_failed = 0 + full_log_info = '' + loc_results = {} + + error_ths = ((0.25, 2), 
(0.5, 5), (5, 10)) + success = [0, 0, 0] + total_loc_time = [] + + for qname, qinfo in tqdm(queries): + kpq = feature_file[qname]['keypoints'].__array__() + n_total += 1 + time_start = time.time() + + if qname in retrievals.keys(): + cans = retrievals[qname] + db_ids = [db_name_to_id[v] for v in cans] + else: + cans = [] + db_ids = [] + time_coarse = time.time() + + if args.use_hloc: + output = pose_estimator_hloc(qname=qname, qinfo=qinfo, db_ids=db_ids, db_images=db_images, + points3D=points3D, + feature_file=feature_file, + thresh=args.ransac_thresh, + image_dir=args.image_dir, + matcher=matcher, + log_info='', + query_img_prefix='', + db_img_prefix='') + else: # should be faster and more accurate than hloc + t_start = time.time() + output = pose_estimator_iterative(qname=qname, + qinfo=qinfo, + matcher=matcher, + db_ids=db_ids, + db_images=db_images, + points3D=points3D, + feature_file=feature_file, + thresh=args.ransac_thresh, + image_dir=args.image_dir, + do_covisibility_opt=args.do_covisible_opt, + covisibility_frame=args.covisibility_frame, + log_info='', + inlier_th=args.inlier_thresh, + obs_th=args.obs_thresh, + opt_th=args.opt_thresh, + gt_qvec=gt_poses[qname]['qvec'] if qname in gt_poses.keys() else None, + gt_tvec=gt_poses[qname]['tvec'] if qname in gt_poses.keys() else None, + query_img_prefix='', + db_img_prefix='database', + ) + time_full = time.time() + + qvec = output['qvec'] + tvec = output['tvec'] + loc_time = time_full - time_start + total_loc_time.append(loc_time) + + poses[qname] = (qvec, tvec) + print_text = "All {:d}/{:d} failed cases, time[cs/fn]: {:.2f}/{:.2f}".format( + n_failed, n_total, + time_coarse - time_start, + time_full - time_coarse, + ) + + if qname in gt_poses.keys(): + gt_qvec = gt_poses[qname]['qvec'] + gt_tvec = gt_poses[qname]['tvec'] + + q_error, t_error = compute_pose_error(pred_qcw=qvec, pred_tcw=tvec, gt_qcw=gt_qvec, gt_tcw=gt_tvec) + + for error_idx, th in enumerate(error_ths): + if t_error <= th[0] and q_error <= th[1]: + success[error_idx] += 1 + print_text += ( + ', q_error:{:.2f} t_error:{:.2f} {:d}/{:d}/{:d}/{:d}, time: {:.2f}, {:d}pts'.format(q_error, t_error, + success[0], + success[1], + success[2], n_total, + loc_time, + kpq.shape[0])) + if output['num_inliers'] == 0: + failed_cases.append(qname) + + loc_results[qname] = { + 'keypoints_query': output['keypoints_query'], + 'points3D_ids': output['points3D_ids'], + } + full_log_info = full_log_info + output['log_info'] + full_log_info += (print_text + "\n") + print(print_text) + + logs_path = f'{results}.failed' + with open(logs_path, 'w') as f: + for v in failed_cases: + print(v) + f.write(v + "\n") + + logging.info(f'Localized {len(poses)} / {len(queries)} images.') + logging.info(f'Writing poses to {results}...') + # logging.info(f'Mean loc time: {np.mean(total_loc_time)}...') + print('Mean loc time: {:.2f}...'.format(np.mean(total_loc_time))) + with open(results, 'w') as f: + for q in poses: + qvec, tvec = poses[q] + qvec = ' '.join(map(str, qvec)) + tvec = ' '.join(map(str, tvec)) + name = q + f.write(f'{name} {qvec} {tvec}\n') + + with open(full_log_fn, 'w') as f: + f.write(full_log_info) + + np.save(loc_log_fn, loc_results) + print('Save logs to ', loc_log_fn) + logging.info('Done!') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--image_dir', type=str, required=True) + parser.add_argument('--dataset', type=str, required=True) + parser.add_argument('--reference_sfm', type=Path, required=True) + parser.add_argument('--queries', 
type=Path, required=True) + parser.add_argument('--features', type=Path, required=True) + parser.add_argument('--ransac_thresh', type=float, default=12) + parser.add_argument('--covisibility_frame', type=int, default=50) + parser.add_argument('--do_covisible_opt', action='store_true') + parser.add_argument('--use_hloc', action='store_true') + parser.add_argument('--matcher_method', type=str, default="NNM") + parser.add_argument('--inlier_thresh', type=int, default=50) + parser.add_argument('--obs_thresh', type=float, default=3) + parser.add_argument('--opt_thresh', type=float, default=12) + parser.add_argument('--save_root', type=str, required=True) + parser.add_argument('--retrieval', type=Path, default=None) + parser.add_argument('--gt_pose_fn', type=str, default=None) + + args = parser.parse_args() + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + run(args=args) diff --git a/imcui/third_party/pram/localization/match_features.py b/imcui/third_party/pram/localization/match_features.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1b4edccff67db24d97fadb47024eb09c026ce8 --- /dev/null +++ b/imcui/third_party/pram/localization/match_features.py @@ -0,0 +1,156 @@ +import argparse +import torch +from pathlib import Path +import h5py +import logging +from tqdm import tqdm +import pprint + +import localization.matchers as matchers +from localization.base_model import dynamic_load +from colmap_utils.parsers import names_to_pair + +confs = { + 'gm': { + 'output': 'gm', + 'model': { + 'name': 'gm', + 'weight_path': 'weights/imp_gm.900.pth', + 'sinkhorn_iterations': 20, + }, + }, + 'gml': { + 'output': 'gml', + 'model': { + 'name': 'gml', + 'weight_path': 'weights/imp_gml.920.pth', + 'sinkhorn_iterations': 20, + }, + }, + + 'adagml': { + 'output': 'adagml', + 'model': { + 'name': 'adagml', + 'weight_path': 'weights/imp_adagml.80.pth', + 'sinkhorn_iterations': 20, + }, + }, + + 'superglue': { + 'output': 'superglue', + 'model': { + 'name': 'superglue', + 'weights': 'outdoor', + 'sinkhorn_iterations': 20, + 'weight_path': 'weights/superglue_outdoor.pth', + }, + }, + 'NNM': { + 'output': 'NNM', + 'model': { + 'name': 'nearest_neighbor', + 'do_mutual_check': True, + 'distance_threshold': None, + }, + }, +} + + +@torch.no_grad() +def main(conf, pairs, features, export_dir, exhaustive=False): + logging.info('Matching local features with configuration:' + f'\n{pprint.pformat(conf)}') + + feature_path = Path(export_dir, features + '.h5') + assert feature_path.exists(), feature_path + feature_file = h5py.File(str(feature_path), 'r') + pairs_name = pairs.stem + if not exhaustive: + assert pairs.exists(), pairs + with open(pairs, 'r') as f: + pair_list = f.read().rstrip('\n').split('\n') + elif exhaustive: + logging.info(f'Writing exhaustive match pairs to {pairs}.') + assert not pairs.exists(), pairs + + # get the list of images from the feature file + images = [] + feature_file.visititems( + lambda name, obj: images.append(obj.parent.name.strip('/')) + if isinstance(obj, h5py.Dataset) else None) + images = list(set(images)) + + pair_list = [' '.join((images[i], images[j])) + for i in range(len(images)) for j in range(i)] + with open(str(pairs), 'w') as f: + f.write('\n'.join(pair_list)) + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + Model = dynamic_load(matchers, conf['model']['name']) + model = Model(conf['model']).eval().to(device) + + match_name = f'{features}-{conf["output"]}-{pairs_name}' + match_path = Path(export_dir, match_name + '.h5') + + match_file = 
h5py.File(str(match_path), 'a') + + matched = set() + for pair in tqdm(pair_list, smoothing=.1): + name0, name1 = pair.split(' ') + pair = names_to_pair(name0, name1) + + # Avoid to recompute duplicates to save time + if len({(name0, name1), (name1, name0)} & matched) \ + or pair in match_file: + continue + + data = {} + feats0, feats1 = feature_file[name0], feature_file[name1] + for k in feats1.keys(): + # data[k + '0'] = feats0[k].__array__() + if k == 'descriptors': + data[k + '0'] = feats0[k][()].transpose() # [N D] + else: + data[k + '0'] = feats0[k][()] + for k in feats1.keys(): + # data[k + '1'] = feats1[k].__array__() + # data[k + '1'] = feats1[k][()].transpose() # [N D] + if k == 'descriptors': + data[k + '1'] = feats1[k][()].transpose() # [N D] + else: + data[k + '1'] = feats1[k][()] + data = {k: torch.from_numpy(v)[None].float().to(device) + for k, v in data.items()} + + # some matchers might expect an image but only use its size + data['image0'] = torch.empty((1, 1,) + tuple(feats0['image_size'])[::-1]) + data['image1'] = torch.empty((1, 1,) + tuple(feats1['image_size'])[::-1]) + + pred = model(data) + grp = match_file.create_group(pair) + matches = pred['matches0'][0].cpu().short().numpy() + grp.create_dataset('matches0', data=matches) + + if 'matching_scores0' in pred: + scores = pred['matching_scores0'][0].cpu().half().numpy() + grp.create_dataset('matching_scores0', data=scores) + + matched |= {(name0, name1), (name1, name0)} + + match_file.close() + logging.info('Finished exporting matches.') + + return match_path + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--export_dir', type=Path, required=True) + parser.add_argument('--features', type=str, required=True) + parser.add_argument('--pairs', type=Path, required=True) + parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys())) + parser.add_argument('--exhaustive', action='store_true') + args = parser.parse_args() + main(confs[args.conf], args.pairs, args.features, args.export_dir, + exhaustive=args.exhaustive) diff --git a/imcui/third_party/pram/localization/match_features_batch.py b/imcui/third_party/pram/localization/match_features_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0dc9d4a1e4288892c365616e45304a19e93c3e --- /dev/null +++ b/imcui/third_party/pram/localization/match_features_batch.py @@ -0,0 +1,242 @@ +import argparse +import torch +from pathlib import Path +import h5py +import logging +from tqdm import tqdm +import pprint +from queue import Queue +from threading import Thread +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import localization.matchers as matchers +from localization.base_model import dynamic_load +from colmap_utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval + +confs = { + 'gm': { + 'output': 'gm', + 'model': { + 'name': 'gm', + 'weight_path': 'weights/imp_gm.900.pth', + 'sinkhorn_iterations': 20, + }, + }, + 'gml': { + 'output': 'gml', + 'model': { + 'name': 'gml', + 'weight_path': 'weights/imp_gml.920.pth', + 'sinkhorn_iterations': 20, + }, + }, + + 'adagml': { + 'output': 'adagml', + 'model': { + 'name': 'adagml', + 'weight_path': 'weights/imp_adagml.80.pth', + 'sinkhorn_iterations': 20, + }, + }, + + 'superglue': { + 'output': 'superglue', + 'model': { + 'name': 'superglue', + 'weights': 'outdoor', + 'sinkhorn_iterations': 20, + 'weight_path': 'weights/superglue_outdoor.pth', + }, + }, + 'NNM': { + 'output': 'NNM', + 'model': { 
+ 'name': 'nearest_neighbor', + 'do_mutual_check': True, + 'distance_threshold': None, + }, + }, +} + + +class WorkQueue: + def __init__(self, work_fn, num_threads=1): + self.queue = Queue(num_threads) + self.threads = [ + Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads) + ] + for thread in self.threads: + thread.start() + + def join(self): + for thread in self.threads: + self.queue.put(None) + for thread in self.threads: + thread.join() + + def thread_fn(self, work_fn): + item = self.queue.get() + while item is not None: + work_fn(item) + item = self.queue.get() + + def put(self, data): + self.queue.put(data) + + +class FeaturePairsDataset(torch.utils.data.Dataset): + def __init__(self, pairs, feature_path_q, feature_path_r): + self.pairs = pairs + self.feature_path_q = feature_path_q + self.feature_path_r = feature_path_r + + def __getitem__(self, idx): + name0, name1 = self.pairs[idx] + data = {} + with h5py.File(self.feature_path_q, "r") as fd: + grp = fd[name0] + for k, v in grp.items(): + data[k + "0"] = torch.from_numpy(v.__array__()).float() + if k == 'descriptors': + data[k + '0'] = data[k + '0'].t() + # some matchers might expect an image but only use its size + data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + with h5py.File(self.feature_path_r, "r") as fd: + grp = fd[name1] + for k, v in grp.items(): + data[k + "1"] = torch.from_numpy(v.__array__()).float() + if k == 'descriptors': + data[k + '1'] = data[k + '1'].t() + data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1]) + return data + + def __len__(self): + return len(self.pairs) + + +def writer_fn(inp, match_path): + pair, pred = inp + with h5py.File(str(match_path), "a", libver="latest") as fd: + if pair in fd: + del fd[pair] + grp = fd.create_group(pair) + matches = pred["matches0"][0].cpu().short().numpy() + grp.create_dataset("matches0", data=matches) + if "matching_scores0" in pred: + scores = pred["matching_scores0"][0].cpu().half().numpy() + grp.create_dataset("matching_scores0", data=scores) + + +def main( + conf: Dict, + pairs: Path, + features: Union[Path, str], + export_dir: Optional[Path] = None, + matches: Optional[Path] = None, + features_ref: Optional[Path] = None, + overwrite: bool = False, +) -> Path: + if isinstance(features, Path) or Path(features).exists(): + features_q = features + if matches is None: + raise ValueError( + "Either provide both features and matches as Path" " or both as names." + ) + else: + if export_dir is None: + raise ValueError( + "Provide an export_dir if features is not" f" a file path: {features}." 
+ ) + features_q = Path(export_dir, features + ".h5") + if matches is None: + matches = Path(export_dir, f'{features}-{conf["output"]}-{pairs.stem}.h5') + + if features_ref is None: + features_ref = features_q + match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite) + + return matches + + +def find_unique_new_pairs(pairs_all: List[Tuple[str]], match_path: Path = None): + """Avoid to recompute duplicates to save time.""" + pairs = set() + for i, j in pairs_all: + if (j, i) not in pairs: + pairs.add((i, j)) + pairs = list(pairs) + if match_path is not None and match_path.exists(): + with h5py.File(str(match_path), "r", libver="latest") as fd: + pairs_filtered = [] + for i, j in pairs: + if ( + names_to_pair(i, j) in fd + or names_to_pair(j, i) in fd + or names_to_pair_old(i, j) in fd + or names_to_pair_old(j, i) in fd + ): + continue + pairs_filtered.append((i, j)) + return pairs_filtered + return pairs + + +@torch.no_grad() +def match_from_paths( + conf: Dict, + pairs_path: Path, + match_path: Path, + feature_path_q: Path, + feature_path_ref: Path, + overwrite: bool = False, +) -> Path: + logging.info( + "Matching local features with configuration:" f"\n{pprint.pformat(conf)}" + ) + + if not feature_path_q.exists(): + raise FileNotFoundError(f"Query feature file {feature_path_q}.") + if not feature_path_ref.exists(): + raise FileNotFoundError(f"Reference feature file {feature_path_ref}.") + match_path.parent.mkdir(exist_ok=True, parents=True) + + assert pairs_path.exists(), pairs_path + pairs = parse_retrieval(pairs_path) + pairs = [(q, r) for q, rs in pairs.items() for r in rs] + pairs = find_unique_new_pairs(pairs, None if overwrite else match_path) + if len(pairs) == 0: + logging.info("Skipping the matching.") + return + + device = "cuda" if torch.cuda.is_available() else "cpu" + Model = dynamic_load(matchers, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(device) + + dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref) + loader = torch.utils.data.DataLoader( + dataset, num_workers=4, batch_size=1, shuffle=False, pin_memory=True + ) + writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5) + + for idx, data in enumerate(tqdm(loader, smoothing=0.1)): + data = { + k: v if k.startswith("image") else v.to(device, non_blocking=True) + for k, v in data.items() + } + pred = model(data) + pair = names_to_pair(*pairs[idx]) + writer_queue.put((pair, pred)) + writer_queue.join() + logging.info("Finished exporting matches.") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--export_dir', type=Path, required=True) + parser.add_argument('--features', type=str, required=True) + parser.add_argument('--pairs', type=Path, required=True) + parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys())) + args = parser.parse_args() + main(confs[args.conf], args.pairs, args.features, args.export_dir) diff --git a/imcui/third_party/pram/localization/matchers/__init__.py b/imcui/third_party/pram/localization/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7edac76f912b1e5ebb0401b6cc7a5d3c64ce963a --- /dev/null +++ b/imcui/third_party/pram/localization/matchers/__init__.py @@ -0,0 +1,3 @@ +def get_matcher(matcher): + mod = __import__(f'{__name__}.{matcher}', fromlist=['']) + return getattr(mod, 'Model') diff --git a/imcui/third_party/pram/localization/matchers/adagml.py b/imcui/third_party/pram/localization/matchers/adagml.py new file mode 
100644 index 0000000000000000000000000000000000000000..31a4bd2aa74bef934543b79567f148f5b8b7b092 --- /dev/null +++ b/imcui/third_party/pram/localization/matchers/adagml.py @@ -0,0 +1,41 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> adagml +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 11/02/2024 14:34 +==================================================''' +import torch +from localization.base_model import BaseModel +from nets.adagml import AdaGML as GMatcher + + +class AdaGML(BaseModel): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + 'weight_path': None, + } + + required_inputs = [ + 'image0', 'keypoints0', 'scores0', 'descriptors0', + 'image1', 'keypoints1', 'scores1', 'descriptors1', + ] + + def _init(self, conf): + self.net = GMatcher(config=conf).eval() + state_dict = torch.load(conf['weight_path'], map_location='cpu')['model'] + self.net.load_state_dict(state_dict, strict=True) + + def _forward(self, data): + with torch.no_grad(): + return self.net(data) diff --git a/imcui/third_party/pram/localization/matchers/gm.py b/imcui/third_party/pram/localization/matchers/gm.py new file mode 100644 index 0000000000000000000000000000000000000000..2484cdb521d28a8cc0b5be7148919cd46bc67b32 --- /dev/null +++ b/imcui/third_party/pram/localization/matchers/gm.py @@ -0,0 +1,44 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File r2d2 -> gm +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 25/05/2023 10:09 +==================================================''' +import torch +from localization.base_model import BaseModel +from nets.gm import GM as GMatcher + + +class GM(BaseModel): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 
9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + + 'ac_fn': 'relu', + 'norm_fn': 'bn', + 'weight_path': None, + } + + required_inputs = [ + 'image0', 'keypoints0', 'scores0', 'descriptors0', + 'image1', 'keypoints1', 'scores1', 'descriptors1', + ] + + def _init(self, conf): + self.net = GMatcher(config=conf).eval() + state_dict = torch.load(conf['weight_path'], map_location='cpu')['model'] + self.net.load_state_dict(state_dict, strict=True) + + def _forward(self, data): + with torch.no_grad(): + return self.net(data) diff --git a/imcui/third_party/pram/localization/matchers/gml.py b/imcui/third_party/pram/localization/matchers/gml.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9acdeaf3c7bd9670c1f7c49e2bbf709f1e8b4a --- /dev/null +++ b/imcui/third_party/pram/localization/matchers/gml.py @@ -0,0 +1,45 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File localizer -> gml +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 15/01/2024 11:01 +==================================================''' +import torch +from localization.base_model import BaseModel +from nets.gml import GML as GMatcher + + +class GML(BaseModel): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + + 'ac_fn': 'relu', + 'norm_fn': 'bn', + 'weight_path': None, + } + + required_inputs = [ + 'image0', 'keypoints0', 'scores0', 'descriptors0', + 'image1', 'keypoints1', 'scores1', 'descriptors1', + ] + + def _init(self, conf): + self.net = GMatcher(config=conf).eval() + state_dict = torch.load(conf['weight_path'], map_location='cpu')['model'] + self.net.load_state_dict(state_dict, strict=True) + + def _forward(self, data): + with torch.no_grad(): + # print(data['keypoints0'].shape, data['descriptors0'].shape, data['image0'].shape) + return self.net(data) diff --git a/imcui/third_party/pram/localization/matchers/nearest_neighbor.py b/imcui/third_party/pram/localization/matchers/nearest_neighbor.py new file mode 100644 index 0000000000000000000000000000000000000000..42b8078747535a269dab6131b4f20c0857c36c03 --- /dev/null +++ b/imcui/third_party/pram/localization/matchers/nearest_neighbor.py @@ -0,0 +1,56 @@ +import torch +from localization.base_model import BaseModel + + +def find_nn(sim, ratio_thresh, distance_thresh): + sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True) + dist_nn = 2 * (1 - sim_nn) + mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device) + if ratio_thresh: + mask = mask & (dist_nn[..., 0] <= (ratio_thresh ** 2) * dist_nn[..., 1]) + if distance_thresh: + mask = mask & (dist_nn[..., 0] <= distance_thresh ** 2) + matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1)) + scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0)) + return matches, scores + + +def mutual_check(m0, m1): + inds0 = torch.arange(m0.shape[-1], device=m0.device) + loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0))) + ok = (m0 > -1) & (inds0 == loop) + m0_new = torch.where(ok, m0, m0.new_tensor(-1)) + return m0_new + + +class NearestNeighbor(BaseModel): + default_conf = { + 
'ratio_threshold': None, + 'distance_threshold': None, + 'do_mutual_check': True, + } + required_inputs = ['descriptors0', 'descriptors1'] + + def _init(self, conf): + pass + + def _forward(self, data): + sim = torch.einsum( + 'bdn,bdm->bnm', data['descriptors0'], data['descriptors1']) + matches0, scores0 = find_nn( + sim, self.conf['ratio_threshold'], self.conf['distance_threshold']) + # matches1, scores1 = find_nn( + # sim.transpose(1, 2), self.conf['ratio_threshold'], + # self.conf['distance_threshold']) + if self.conf['do_mutual_check']: + # print("with mutual check") + matches1, scores1 = find_nn( + sim.transpose(1, 2), self.conf['ratio_threshold'], + self.conf['distance_threshold']) + matches0 = mutual_check(matches0, matches1) + # else: + # print("no mutual check") + return { + 'matches0': matches0, + 'matching_scores0': scores0, + } diff --git a/imcui/third_party/pram/localization/multimap3d.py b/imcui/third_party/pram/localization/multimap3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6100b4f4bfeb1d3f8bc94598723979e830bf4172 --- /dev/null +++ b/imcui/third_party/pram/localization/multimap3d.py @@ -0,0 +1,379 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> multimap3d +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 13:47 +==================================================''' +import numpy as np +import os +import os.path as osp +import time +import cv2 +import torch +import yaml +from copy import deepcopy +from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches +from localization.base_model import dynamic_load +import localization.matchers as matchers +from localization.match_features_batch import confs as matcher_confs +from nets.gm import GM +from tools.common import resize_img +from localization.singlemap3d import SingleMap3D +from localization.frame import Frame + + +class MultiMap3D: + def __init__(self, config, viewer=None, save_dir=None): + self.config = config + self.save_dir = save_dir + + self.scenes = [] + self.sid_scene_name = [] + self.sub_maps = {} + self.scene_name_start_sid = {} + + self.loc_config = config['localization'] + self.save_dir = save_dir + if self.save_dir is not None: + os.makedirs(self.save_dir, exist_ok=True) + + self.matching_method = config['localization']['matching_method'] + device = 'cuda' if torch.cuda.is_available() else 'cpu' + Model = dynamic_load(matchers, self.matching_method) + self.matcher = Model(matcher_confs[self.matching_method]['model']).eval().to(device) + + self.initialize_map(config=config) + self.loc_config = config['localization'] + + self.viewer = viewer + + # options + self.do_refinement = self.loc_config['do_refinement'] + self.refinement_method = self.loc_config['refinement_method'] + self.semantic_matching = self.loc_config['semantic_matching'] + self.do_pre_filtering = self.loc_config['pre_filtering_th'] > 0 + self.pre_filtering_th = self.loc_config['pre_filtering_th'] + + def initialize_map(self, config): + n_class = 0 + datasets = config['dataset'] + + for name in datasets: + config_path = osp.join(config['config_path'], '{:s}.yaml'.format(name)) + dataset_name = name + + with open(config_path, 'r') as f: + scene_config = yaml.load(f, Loader=yaml.Loader) + + scenes = scene_config['scenes'] + for sid, scene in enumerate(scenes): + self.scenes.append(name + '/' + scene) + + new_config = deepcopy(config) + new_config['dataset_path'] = osp.join(config['dataset_path'], dataset_name, scene) + 
new_config['landmark_path'] = osp.join(config['landmark_path'], dataset_name, scene) + new_config['n_cluster'] = scene_config[scene]['n_cluster'] + new_config['cluster_mode'] = scene_config[scene]['cluster_mode'] + new_config['cluster_method'] = scene_config[scene]['cluster_method'] + new_config['gt_pose_path'] = scene_config[scene]['gt_pose_path'] + new_config['image_path_prefix'] = scene_config[scene]['image_path_prefix'] + sub_map = SingleMap3D(config=new_config, + matcher=self.matcher, + with_compress=config['localization']['with_compress'], + start_sid=n_class) + self.sub_maps[dataset_name + '/' + scene] = sub_map + + n_scene_class = scene_config[scene]['n_cluster'] + self.sid_scene_name = self.sid_scene_name + [dataset_name + '/' + scene for ni in range(n_scene_class)] + self.scene_name_start_sid[dataset_name + '/' + scene] = n_class + n_class = n_class + n_scene_class + + # break + print('Load {} sub_maps from {} datasets'.format(len(self.sub_maps), len(datasets))) + + def run(self, q_frame: Frame): + show = self.loc_config['show'] + seg_color = generate_color_dic(n_seg=2000) + if show: + cv2.namedWindow('loc', cv2.WINDOW_NORMAL) + + q_loc_segs = self.process_segmentations(segs=torch.from_numpy(q_frame.segmentations), + topk=self.loc_config['seg_k']) + q_pred_segs_top1 = q_frame.seg_ids # initial results + + q_scene_name = q_frame.scene_name + q_name = q_frame.name + q_full_name = osp.join(q_scene_name, q_name) + + q_loc_sids = {} + for v in q_loc_segs: + q_loc_sids[v[0]] = (v[1], v[2]) + query_sids = list(q_loc_sids.keys()) + + for i, sid in enumerate(query_sids): + t_start = time.time() + q_kpt_ids = q_loc_sids[sid][0] + print(q_scene_name, q_name, sid) + + sid = sid - 1 # start from 0, confused! + + pred_scene_name = self.sid_scene_name[sid] + start_seg_id = self.scene_name_start_sid[pred_scene_name] + pred_sid_in_sub_scene = sid - self.scene_name_start_sid[pred_scene_name] + pred_sub_map = self.sub_maps[pred_scene_name] + pred_image_path_prefix = pred_sub_map.image_path_prefix + + print('pred/gt scene: {:s}, {:s}, sid: {:d}'.format(pred_scene_name, q_scene_name, pred_sid_in_sub_scene)) + print('{:s}/{:s}, pred: {:s}, sid: {:d}, order: {:d}'.format(q_scene_name, q_name, pred_scene_name, sid, + i)) + + if (q_kpt_ids.shape[0] >= self.loc_config['min_kpts'] + and self.semantic_matching + and pred_sub_map.check_semantic_consistency(q_frame=q_frame, + sid=pred_sid_in_sub_scene, + overlap_ratio=0.5)): + semantic_matching = True + else: + q_kpt_ids = np.arange(q_frame.keypoints.shape[0]) + semantic_matching = False + print_text = f'Semantic matching - {semantic_matching}! 
Query kpts {q_kpt_ids.shape[0]} for {i}th seg {sid}' + print(print_text) + ret = pred_sub_map.localize_with_ref_frame(q_frame=q_frame, + q_kpt_ids=q_kpt_ids, + sid=pred_sid_in_sub_scene, + semantic_matching=semantic_matching) + + q_frame.time_loc = q_frame.time_loc + time.time() - t_start # accumulate tracking time + + if show: + reference_frame = pred_sub_map.reference_frames[ret['reference_frame_id']] + ref_img = cv2.imread(osp.join(self.config['dataset_path'], pred_scene_name, pred_image_path_prefix, + reference_frame.name)) + q_img_seg = vis_seg_point(img=q_frame.image, kpts=q_frame.keypoints[q_kpt_ids, :2], + segs=q_frame.seg_ids[q_kpt_ids] + 1, + seg_color=seg_color) + matched_points3D_ids = ret['matched_point3D_ids'] + ref_sids = np.array([pred_sub_map.point3Ds[v].seg_id for v in matched_points3D_ids]) + \ + self.scene_name_start_sid[pred_scene_name] + 1 # start from 1 as bg is 0 + ref_img_seg = vis_seg_point(img=ref_img, kpts=ret['matched_ref_keypoints'], segs=ref_sids, + seg_color=seg_color) + q_matched_kpts = ret['matched_keypoints'] + ref_matched_kpts = ret['matched_ref_keypoints'] + img_loc_matching = plot_matches(img1=q_img_seg, img2=ref_img_seg, + pts1=q_matched_kpts, pts2=ref_matched_kpts, + inliers=np.array([True for i in range(q_matched_kpts.shape[0])]), + radius=9, line_thickness=3 + ) + + q_frame.image_matching_tmp = img_loc_matching + q_frame.reference_frame_name_tmp = osp.join(self.config['dataset_path'], + pred_scene_name, + pred_image_path_prefix, + reference_frame.name) + # ret['image_matching'] = img_loc_matching + # ret['reference_frame_name'] = osp.join(self.config['dataset_path'], + # pred_scene_name, + # pred_image_path_prefix, + # reference_frame.name) + q_ref_img_matching = np.hstack([resize_img(q_img_seg, nh=512), + resize_img(ref_img_seg, nh=512), + resize_img(img_loc_matching, nh=512)]) + + ret['order'] = i + ret['matched_scene_name'] = pred_scene_name + if not ret['success']: + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + print_text = f'Localization failed with {num_matches}/{q_kpt_ids.shape[0]} matches and {num_inliers} inliers, order {i}' + print(print_text) + + if show: + show_text = 'FAIL! 
order: {:d}/{:d}-{:d}/{:d}'.format(i, len(q_loc_segs), + num_matches, + q_kpt_ids.shape[0]) + q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'], + radius=9 + 2, thickness=2) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + q_frame.image_inlier_tmp = q_img_inlier + q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)]) + cv2.imshow('loc', q_img_loc) + key = cv2.waitKey(self.loc_config['show_time']) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + continue + + if show: + q_err, t_err = q_frame.compute_pose_error() + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + show_text = 'order: {:d}/{:d}, k/m/i: {:d}/{:d}/{:d}'.format( + i, len(q_loc_segs), q_kpt_ids.shape[0], num_matches, num_inliers) + q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'], + radius=9 + 2, thickness=2) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + q_frame.image_inlier_tmp = q_img_inlier + + q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)]) + + cv2.imshow('loc', q_img_loc) + key = cv2.waitKey(self.loc_config['show_time']) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + + success = self.verify_and_update(q_frame=q_frame, ret=ret) + + if not success: + continue + else: + break + + if q_frame.tracking_status is None: + print('Failed to find a proper reference frame.') + return False + + # do refinement + if not self.do_refinement: + return True + else: + t_start = time.time() + pred_sub_map = self.sub_maps[q_frame.matched_scene_name] + if q_frame.tracking_status is True and np.sum(q_frame.matched_inliers) >= 64: + ret = pred_sub_map.refine_pose(q_frame=q_frame, refinement_method=self.loc_config['refinement_method']) + else: + ret = pred_sub_map.refine_pose(q_frame=q_frame, + refinement_method='matching') # do not trust the pose for projection + + q_frame.time_ref = time.time() - t_start + + inlier_mask = np.array(ret['inliers']) + + q_frame.qvec = ret['qvec'] + q_frame.tvec = ret['tvec'] + q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask] + q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask] + q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask] + q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask] + q_frame.matched_sids = ret['matched_sids'][inlier_mask] + q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask] + + q_frame.refinement_reference_frame_ids = ret['refinement_reference_frame_ids'] + q_frame.reference_frame_id = ret['reference_frame_id'] + + q_err, t_err = q_frame.compute_pose_error() + ref_full_name = q_frame.matched_scene_name + '/' + pred_sub_map.reference_frames[ + q_frame.reference_frame_id].name + print_text = 'Localization of {:s} success with inliers {:d}/{:d} with ref_name: {:s}, order: {:d}, q_err: {:.2f}, t_err: {:.2f}'.format( + q_full_name, ret['num_inliers'], len(ret['inliers']), 
ref_full_name, q_frame.matched_order, q_err, + t_err) + print(print_text) + + if show: + q_err, t_err = q_frame.compute_pose_error() + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + show_text = 'Ref:{:d}/{:d},r_err:{:.2f}/t_err:{:.2f}'.format(num_matches, num_inliers, q_err, + t_err) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 130), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + q_frame.image_inlier = q_img_inlier + + return True + + def verify_and_update(self, q_frame: Frame, ret: dict): + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + if q_frame.matched_keypoints is None or np.sum(q_frame.matched_inliers) < num_inliers: + self.update_query_frame(q_frame=q_frame, ret=ret) + + q_err, t_err = q_frame.compute_pose_error(pred_qvec=ret['qvec'], pred_tvec=ret['tvec']) + + if num_inliers < self.loc_config['min_inliers']: + print_text = 'Failed due to insufficient {:d} inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format( + ret['num_inliers'], ret['order'], q_err, t_err) + print(print_text) + q_frame.tracking_status = False + return False + else: + print_text = 'Succeed! Find {}/{} 2D-3D inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format( + num_inliers, num_matches, ret['order'], q_err, t_err) + print(print_text) + q_frame.tracking_status = True + return True + + def update_query_frame(self, q_frame, ret): + q_frame.matched_scene_name = ret['matched_scene_name'] + q_frame.reference_frame_id = ret['reference_frame_id'] + q_frame.qvec = ret['qvec'] + q_frame.tvec = ret['tvec'] + + inlier_mask = np.array(ret['inliers']) + q_frame.matched_keypoints = ret['matched_keypoints'] + q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'] + q_frame.matched_xyzs = ret['matched_xyzs'] + q_frame.matched_point3D_ids = ret['matched_point3D_ids'] + q_frame.matched_sids = ret['matched_sids'] + q_frame.matched_inliers = np.array(ret['inliers']) + q_frame.matched_order = ret['order'] + + if q_frame.image_inlier_tmp is not None: + q_frame.image_inlier = deepcopy(q_frame.image_inlier_tmp) + if q_frame.image_matching_tmp is not None: + q_frame.image_matching = deepcopy(q_frame.image_matching_tmp) + if q_frame.reference_frame_name_tmp is not None: + q_frame.reference_frame_name = q_frame.reference_frame_name_tmp + + # inlier_mask = np.array(ret['inliers']) + # q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask] + # q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask] + # q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask] + # q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask] + # q_frame.matched_sids = ret['matched_sids'][inlier_mask] + # q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask] + + # print('update_query_frame: ', q_frame.matched_keypoint_ids.shape, q_frame.matched_keypoints.shape, + # q_frame.matched_xyzs.shape, q_frame.matched_xyzs.shape, np.sum(q_frame.matched_inliers)) + + def process_segmentations(self, segs, topk=10): + pred_values, pred_ids = torch.topk(segs, k=segs.shape[-1], largest=True, dim=-1) # [N, C] + pred_values = pred_values.numpy() + pred_ids = pred_ids.numpy() + + out = [] + used_sids = [] + for k in range(segs.shape[-1]): + values_k = pred_values[:, k] + ids_k = pred_ids[:, k] + uids = np.unique(ids_k) + + out_k = [] + for sid in uids: + if sid == 0: + continue + if sid in used_sids: + continue + used_sids.append(sid) + ids = np.where(ids_k == 
sid)[0] + score = np.mean(values_k[ids]) + # score = np.median(values_k[ids]) + # score = 100 - k + # out_k.append((ids.shape[0], sid - 1, ids, score)) + out_k.append((ids.shape[0], sid, ids, score)) + + out_k = sorted(out_k, key=lambda item: item[0], reverse=True) + for v in out_k: + out.append((v[1], v[2], v[3])) # [sid, ids, score] + if len(out) >= topk: + return out + return out diff --git a/imcui/third_party/pram/localization/point3d.py b/imcui/third_party/pram/localization/point3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6e1babf427759c5f588f44023e9e1bf2648a073b --- /dev/null +++ b/imcui/third_party/pram/localization/point3d.py @@ -0,0 +1,21 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> point3d +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 10:13 +==================================================''' +import numpy as np + + +class Point3D: + def __init__(self, id: int, xyz: np.ndarray, error: float, refframe_id: int, seg_id: int = None, + descriptor: np.ndarray = None, rgb: np.ndarray = None, frame_ids: np.ndarray = None): + self.id = id + self.xyz = xyz + self.rgb = rgb + self.error = error + self.seg_id = seg_id + self.refframe_id = refframe_id + self.frame_ids = frame_ids + self.descriptor = descriptor diff --git a/imcui/third_party/pram/localization/pose_estimator.py b/imcui/third_party/pram/localization/pose_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..5d28d6001d38cfd5f6f6135c611293ab5e83cf0a --- /dev/null +++ b/imcui/third_party/pram/localization/pose_estimator.py @@ -0,0 +1,612 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> pose_estimation +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 08/02/2024 11:01 +==================================================''' +import torch +import numpy as np +import pycolmap +import cv2 +import os +import time +import os.path as osp +from collections import defaultdict + + +def get_covisibility_frames(frame_id, all_images, points3D, covisibility_frame=50): + observed = all_images[frame_id].point3D_ids + covis = defaultdict(int) + for pid in observed: + if pid == -1: + continue + for img_id in points3D[pid].image_ids: + if img_id != frame_id: + covis[img_id] += 1 + + print('Find {:d} connected frames'.format(len(covis.keys()))) + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + + if len(covis_ids) <= covisibility_frame: + sel_covis_ids = covis_ids[np.argsort(-covis_num)] + else: + ind_top = np.argpartition(covis_num, -covisibility_frame) + ind_top = ind_top[-covisibility_frame:] # unsorted top k + ind_top = ind_top[np.argsort(-covis_num[ind_top])] + sel_covis_ids = [covis_ids[i] for i in ind_top] + + print('Retain {:d} valid connected frames'.format(len(sel_covis_ids))) + return sel_covis_ids + + +def feature_matching(query_data, db_data, matcher): + db_3D_ids = db_data['db_3D_ids'] + if db_3D_ids is None: + with torch.no_grad(): + match_data = { + 'keypoints0': torch.from_numpy(query_data['keypoints'])[None].float().cuda(), + 'scores0': torch.from_numpy(query_data['scores'])[None].float().cuda(), + 'descriptors0': torch.from_numpy(query_data['descriptors'])[None].float().cuda(), + 'image0': torch.empty((1, 1,) + tuple(query_data['image_size'])[::-1]), + + 'keypoints1': torch.from_numpy(db_data['keypoints'])[None].float().cuda(), + 'scores1': torch.from_numpy(db_data['scores'])[None].float().cuda(), + 
'descriptors1': torch.from_numpy(db_data['descriptors'])[None].float().cuda(), # [B, N, D] + 'image1': torch.empty((1, 1,) + tuple(db_data['image_size'])[::-1]), + } + matches = matcher(match_data)['matches0'][0].cpu().numpy() + del match_data + else: + masks = (db_3D_ids != -1) + valid_ids = [i for i in range(masks.shape[0]) if masks[i]] + if len(valid_ids) == 0: + return np.zeros(shape=(query_data['keypoints'].shape[0],), dtype=int) - 1 + with torch.no_grad(): + match_data = { + 'keypoints0': torch.from_numpy(query_data['keypoints'])[None].float().cuda(), + 'scores0': torch.from_numpy(query_data['scores'])[None].float().cuda(), + 'descriptors0': torch.from_numpy(query_data['descriptors'])[None].float().cuda(), + 'image0': torch.empty((1, 1,) + tuple(query_data['image_size'])[::-1]), + + 'keypoints1': torch.from_numpy(db_data['keypoints'])[masks][None].float().cuda(), + 'scores1': torch.from_numpy(db_data['scores'])[masks][None].float().cuda(), + 'descriptors1': torch.from_numpy(db_data['descriptors'][masks])[None].float().cuda(), + 'image1': torch.empty((1, 1,) + tuple(db_data['image_size'])[::-1]), + } + matches = matcher(match_data)['matches0'][0].cpu().numpy() + del match_data + + for i in range(matches.shape[0]): + if matches[i] >= 0: + matches[i] = valid_ids[matches[i]] + + return matches + + +def find_2D_3D_matches(query_data, db_id, points3D, feature_file, db_images, matcher, obs_th=0): + kpq = query_data['keypoints'] + db_name = db_images[db_id].name + kpdb = feature_file[db_name]['keypoints'][()] + desc_db = feature_file[db_name]["descriptors"][()] + desc_db = desc_db.transpose() + + # print('db_desc: ', desc_db.shape, query_data['descriptors'].shape) + + points3D_ids = db_images[db_id].point3D_ids + matches = feature_matching(query_data=query_data, + db_data={ + 'keypoints': kpdb, + 'scores': feature_file[db_name]['scores'][()], + 'descriptors': desc_db, + 'db_3D_ids': points3D_ids, + 'image_size': feature_file[db_name]['image_size'][()] + }, + matcher=matcher) + mkpdb = [] + mp3d_ids = [] + q_ids = [] + mkpq = [] + mp3d = [] + valid_matches = [] + for idx in range(matches.shape[0]): + if matches[idx] == -1: + continue + if points3D_ids[matches[idx]] == -1: + continue + id_3D = points3D_ids[matches[idx]] + + # reject 3d points without enough observations + if len(points3D[id_3D].image_ids) < obs_th: + continue + mp3d.append(points3D[id_3D].xyz) + mp3d_ids.append(id_3D) + + mkpq.append(kpq[idx]) + mkpdb.append(kpdb[matches[idx]]) + q_ids.append(idx) + valid_matches.append(matches[idx]) + + mp3d = np.array(mp3d, float).reshape(-1, 3) + mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5 + return mp3d, mkpq, mp3d_ids, q_ids + + +# hfnet, cvpr 2019 +def pose_estimator_hloc(qname, qinfo, db_ids, db_images, points3D, + feature_file, + thresh, + image_dir, + matcher, + log_info=None, + query_img_prefix='', + db_img_prefix=''): + kpq = feature_file[qname]['keypoints'][()] + score_q = feature_file[qname]['scores'][()] + desc_q = feature_file[qname]['descriptors'][()] + desc_q = desc_q.transpose() + imgsize_q = feature_file[qname]['image_size'][()] + query_data = { + 'keypoints': kpq, + 'scores': score_q, + 'descriptors': desc_q, + 'image_size': imgsize_q, + } + + camera_model, width, height, params = qinfo + cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params) + cfg = { + 'model': camera_model, + 'width': width, + 'height': height, + 'params': params, + } + all_mkpts = [] + all_mp3ds = [] + all_points3D_ids = [] + best_db_id = db_ids[0] + best_db_name = 
db_images[best_db_id].name + + t_start = time.time() + + for cluster_idx, db_id in enumerate(db_ids): + mp3d, mkpq, mp3d_ids, q_ids = find_2D_3D_matches( + query_data=query_data, + db_id=db_id, + points3D=points3D, + feature_file=feature_file, + db_images=db_images, + matcher=matcher, + obs_th=3) + if mp3d.shape[0] > 0: + all_mkpts.append(mkpq) + all_mp3ds.append(mp3d) + all_points3D_ids = all_points3D_ids + mp3d_ids + + if len(all_mkpts) == 0: + print_text = 'Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, best_db_name) + print(print_text) + if log_info is not None: + log_info = log_info + print_text + '\n' + + qvec = db_images[best_db_id].qvec + tvec = db_images[best_db_id].tvec + + return { + 'qvec': qvec, + 'tvec': tvec, + 'log_info': log_info, + 'qname': qname, + 'dbname': best_db_name, + 'num_inliers': 0, + 'order': -1, + 'keypoints_query': np.array([]), + 'points3D_ids': [], + 'time': time.time() - t_start, + } + + all_mkpts = np.vstack(all_mkpts) + all_mp3ds = np.vstack(all_mp3ds) + + ret = pycolmap.absolute_pose_estimation(all_mkpts, all_mp3ds, cam, + estimation_options={ + "ransac": {"max_error": thresh}}, + refinement_options={}, + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + success = ret['success'] + + if success: + print_text = 'qname: {:s} localization success with {:d}/{:d} inliers'.format(qname, ret['num_inliers'], + all_mp3ds.shape[0]) + print(print_text) + if log_info is not None: + log_info = log_info + print_text + '\n' + + qvec = ret['qvec'] + tvec = ret['tvec'] + ret['cfg'] = cfg + num_inliers = ret['num_inliers'] + inliers = ret['inliers'] + return { + 'qvec': qvec, + 'tvec': tvec, + 'log_info': log_info, + 'qname': qname, + 'dbname': best_db_name, + 'num_inliers': num_inliers, + 'order': -1, + 'keypoints_query': np.array([all_mkpts[i] for i in range(len(inliers)) if inliers[i]]), + 'points3D_ids': [all_points3D_ids[i] for i in range(len(inliers)) if inliers[i]], + 'time': time.time() - t_start, + } + else: + print_text = 'Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, best_db_name) + print(print_text) + if log_info is not None: + log_info = log_info + print_text + '\n' + + qvec = db_images[best_db_id].qvec + tvec = db_images[best_db_id].tvec + + return { + 'qvec': qvec, + 'tvec': tvec, + 'log_info': log_info, + 'qname': qname, + 'dbname': best_db_name, + 'num_inliers': 0, + 'order': -1, + 'keypoints_query': np.array([]), + 'points3D_ids': [], + 'time': time.time() - t_start, + } + + +def pose_refinement(query_data, + query_cam, feature_file, db_frame_id, db_images, points3D, matcher, + covisibility_frame=50, + obs_th=3, + opt_th=12, + qvec=None, + tvec=None, + log_info='', + **kwargs, + ): + db_ids = get_covisibility_frames(frame_id=db_frame_id, all_images=db_images, points3D=points3D, + covisibility_frame=covisibility_frame) + + mp3d = [] + mkpq = [] + mkpdb = [] + all_3D_ids = [] + all_score_q = [] + kpq = query_data['keypoints'] + for i, db_id in enumerate(db_ids): + db_name = db_images[db_id].name + kpdb = feature_file[db_name]['keypoints'][()] + scores_db = feature_file[db_name]['scores'][()] + imgsize_db = feature_file[db_name]['image_size'][()] + desc_db = feature_file[db_name]["descriptors"][()] + desc_db = desc_db.transpose() + + points3D_ids = db_images[db_id].point3D_ids + if points3D_ids.size == 0: + print("No 3D points in this db image: ", 
db_name) + continue + + matches = feature_matching(query_data=query_data, + db_data={'keypoints': kpdb, + 'scores': scores_db, + 'descriptors': desc_db, + 'image_size': imgsize_db, + 'db_3D_ids': points3D_ids, + }, + matcher=matcher, + ) + valid = np.where(matches > -1)[0] + valid = valid[points3D_ids[matches[valid]] != -1] + inliers = [] + for idx in valid: + id_3D = points3D_ids[matches[idx]] + if len(points3D[id_3D].image_ids) < obs_th: + continue + + inliers.append(True) + + mp3d.append(points3D[id_3D].xyz) + mkpq.append(kpq[idx]) + mkpdb.append(kpdb[matches[idx]]) + all_3D_ids.append(id_3D) + + mp3d = np.array(mp3d, float).reshape(-1, 3) + mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5 + print_text = 'Get {:d} covisible frames with {:d} matches from cluster optimization'.format(len(db_ids), + mp3d.shape[0]) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + + # cam = pycolmap.Camera(model=cfg['model'], params=cfg['params']) + ret = pycolmap.absolute_pose_estimation(mkpq, mp3d, + query_cam, + estimation_options={ + "ransac": {"max_error": opt_th}}, + refinement_options={}, + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + + if not ret['success']: + ret['mkpq'] = mkpq + ret['3D_ids'] = all_3D_ids + ret['db_ids'] = db_ids + ret['score_q'] = all_score_q + ret['log_info'] = log_info + ret['qvec'] = qvec + ret['tvec'] = tvec + ret['inliers'] = [False for i in range(mkpq.shape[0])] + ret['num_inliers'] = 0 + ret['keypoints_query'] = np.array([]) + ret['points3D_ids'] = [] + return ret + + ret_inliers = ret['inliers'] + loc_keypoints_query = np.array([mkpq[i] for i in range(len(ret_inliers)) if ret_inliers[i]]) + loc_points3D_ids = [all_3D_ids[i] for i in range(len(ret_inliers)) if ret_inliers[i]] + + ret['mkpq'] = mkpq + ret['3D_ids'] = all_3D_ids + ret['db_ids'] = db_ids + ret['log_info'] = log_info + ret['keypoints_query'] = loc_keypoints_query + ret['points3D_ids'] = loc_points3D_ids + + return ret + + +# proposed in efficient large-scale localization by global instance recognition, cvpr 2022 +def pose_estimator_iterative(qname, qinfo, db_ids, db_images, points3D, feature_file, thresh, image_dir, + matcher, + inlier_th=50, + log_info=None, + do_covisibility_opt=False, + covisibility_frame=50, + vis_dir=None, + obs_th=0, + opt_th=12, + gt_qvec=None, + gt_tvec=None, + query_img_prefix='', + db_img_prefix='', + ): + print("qname: ", qname) + db_name_to_id = {image.name: i for i, image in db_images.items()} + # q_img = cv2.imread(osp.join(image_dir, query_img_prefix, qname)) + + kpq = feature_file[qname]['keypoints'][()] + score_q = feature_file[qname]['scores'][()] + imgsize_q = feature_file[qname]['image_size'][()] + desc_q = feature_file[qname]['descriptors'][()] + desc_q = desc_q.transpose() # [N D] + query_data = { + 'keypoints': kpq, + 'scores': score_q, + 'descriptors': desc_q, + 'image_size': imgsize_q, + } + camera_model, width, height, params = qinfo + + best_results = { + 'tvec': None, + 'qvec': None, + 'num_inliers': 0, + 'single_num_inliers': 0, + 'db_id': -1, + 'order': -1, + 'qname': qname, + 'optimize': False, + 'dbname': db_images[db_ids[0]].name, + "ret_source": "", + "inliers": [], + 'keypoints_query': np.array([]), + 'points3D_ids': [], + } + + cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params) + + for cluster_idx, db_id in enumerate(db_ids): + db_name = 
db_images[db_id].name + mp3d, mkpq, mp3d_ids, q_ids = find_2D_3D_matches( + query_data=query_data, + db_id=db_id, + points3D=points3D, + feature_file=feature_file, + db_images=db_images, + matcher=matcher, + obs_th=obs_th) + + if mp3d.shape[0] < 8: + print_text = "qname: {:s} dbname: {:s}({:d}/{:d}) failed because of insufficient 3d points {:d}".format( + qname, + db_name, + cluster_idx + 1, + len(db_ids), + mp3d.shape[0]) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + continue + + ret = pycolmap.absolute_pose_estimation(mkpq, mp3d, cam, + estimation_options={ + "ransac": {"max_error": thresh}}, + refinement_options={}, + ) + + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + + if not ret["success"]: + print_text = "qname: {:s} dbname: {:s} ({:d}/{:d}) failed after matching".format(qname, db_name, + cluster_idx + 1, + len(db_ids)) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + continue + + inliers = ret['inliers'] + num_inliers = ret['num_inliers'] + inlier_p3d_ids = [mp3d_ids[i] for i in range(len(inliers)) if inliers[i]] + inlier_mkpq = [mkpq[i] for i in range(len(inliers)) if inliers[i]] + loc_keypoints_query = np.array(inlier_mkpq) + loc_points3D_ids = inlier_p3d_ids + + if ret['num_inliers'] > best_results['num_inliers']: + best_results['qvec'] = ret['qvec'] + best_results['tvec'] = ret['tvec'] + best_results['inlier'] = ret['inliers'] + best_results['num_inliers'] = ret['num_inliers'] + best_results['dbname'] = db_name + best_results['order'] = cluster_idx + 1 + best_results['keypoints_query'] = loc_keypoints_query + best_results['points3D_ids'] = loc_points3D_ids + + if ret['num_inliers'] < inlier_th: + print_text = "qname: {:s} dbname: {:s} ({:d}/{:d}) failed insufficient {:d} inliers".format(qname, + db_name, + cluster_idx + 1, + len(db_ids), + num_inliers, + ) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + continue + + print_text = "qname: {:s} dbname: {:s} ({:d}/{:d}) initialization succeed with {:d} inliers".format( + qname, + db_name, + cluster_idx + 1, + len(db_ids), + ret["num_inliers"] + ) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + + if do_covisibility_opt: + ret = pose_refinement(qname=qname, + query_cam=cam, + feature_file=feature_file, + db_frame_id=db_id, + db_images=db_images, + points3D=points3D, + thresh=thresh, + covisibility_frame=covisibility_frame, + matcher=matcher, + obs_th=obs_th, + opt_th=opt_th, + qvec=ret['qvec'], + tvec=ret['tvec'], + log_info='', + image_dir=image_dir, + vis_dir=vis_dir, + gt_qvec=gt_qvec, + gt_tvec=gt_tvec, + ) + + loc_keypoints_query = ret['keypoints_query'] + loc_points3D_ids = ret['points3D_ids'] + + log_info = log_info + ret['log_info'] + print_text = 'Find {:d} inliers after optimization'.format(ret['num_inliers']) + print(print_text) + if log_info is not None: + log_info += (print_text + "\n") + + # localization succeed + qvec = ret['qvec'] + tvec = ret['tvec'] + num_inliers = ret['num_inliers'] + best_results['keypoints_query'] = loc_keypoints_query + best_results['points3D_ids'] = loc_points3D_ids + + best_results['qvec'] = qvec + best_results['tvec'] = tvec + best_results['num_inliers'] = num_inliers + best_results['log_info'] = log_info + + return best_results + + if best_results['num_inliers'] >= 10: # 20 for aachen + qvec = 
best_results['qvec'] + tvec = best_results['tvec'] + best_dbname = best_results['dbname'] + + best_results['keypoints_query'] = loc_keypoints_query + best_results['points3D_ids'] = loc_points3D_ids + + if do_covisibility_opt: + ret = pose_refinement(qname=qname, + query_cam=cam, + feature_file=feature_file, + db_frame_id=db_name_to_id[best_dbname], + db_images=db_images, + points3D=points3D, + thresh=thresh, + covisibility_frame=covisibility_frame, + matcher=matcher, + obs_th=obs_th, + opt_th=opt_th, + qvec=qvec, + tvec=tvec, + log_info='', + image_dir=image_dir, + vis_dir=vis_dir, + gt_qvec=gt_qvec, + gt_tvec=gt_tvec, + ) + + # localization succeed + qvec = ret['qvec'] + tvec = ret['tvec'] + num_inliers = ret['num_inliers'] + best_results['keypoints_query'] = loc_keypoints_query + best_results['points3D_ids'] = loc_points3D_ids + + best_results['qvec'] = qvec + best_results['tvec'] = tvec + best_results['num_inliers'] = num_inliers + best_results['log_info'] = log_info + + return best_results + + closest = db_images[db_ids[0][0]] + print_text = 'Localize {:s} failed, but use the pose of {:s} as approximation'.format(qname, closest.name) + print(print_text) + if log_info is not None: + log_info += (print_text + '\n') + + best_results['qvec'] = closest.qvec + best_results['tvec'] = closest.tvec + best_results['num_inliers'] = -1 + best_results['log_info'] = log_info + + return best_results diff --git a/imcui/third_party/pram/localization/refframe.py b/imcui/third_party/pram/localization/refframe.py new file mode 100644 index 0000000000000000000000000000000000000000..b7eeafd44557ffdfda5829dab00dd5df125148b4 --- /dev/null +++ b/imcui/third_party/pram/localization/refframe.py @@ -0,0 +1,147 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> refframe +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 10:06 +==================================================''' +import numpy as np +from localization.camera import Camera +from colmap_utils.camera_intrinsics import intrinsics_from_camera +from colmap_utils.read_write_model import qvec2rotmat + + +class RefFrame: + def __init__(self, camera: Camera, id: int, qvec: np.ndarray, tvec: np.ndarray, + point3D_ids: np.ndarray = None, keypoints: np.ndarray = None, + name: str = None, scene_name: str = None): + self.camera = camera + self.id = id + self.qvec = qvec + self.tvec = tvec + self.name = name + self.scene_name = scene_name + self.width = camera.width + self.height = camera.height + self.image_size = np.array([self.height, self.width]) + + self.point3D_ids = point3D_ids + self.keypoints = keypoints + self.descriptors = None + self.keypoint_segs = None + self.xyzs = None + + def get_keypoints_by_sid(self, sid: int): + mask = (self.keypoint_segs == sid) + return { + 'point3D_ids': self.point3D_ids[mask], + 'keypoints': self.keypoints[mask][:, :2], + 'descriptors': self.descriptors[mask], + 'scores': self.keypoints[mask][:, 2], + 'xyzs': self.xyzs[mask], + 'camera': self.camera, + } + + valid_p3d_ids = [] + valid_kpts = [] + valid_descs = [] + valid_scores = [] + valid_xyzs = [] + for i, v in enumerate(self.point3D_ids): + if v in point3Ds.keys(): + p3d = point3Ds[v] + if p3d.seg_id == sid: + valid_kpts.append(self.keypoints[i]) + valid_p3d_ids.append(v) + valid_xyzs.append(p3d.xyz) + valid_descs.append(p3d.descriptor) + valid_scores.append(p3d.error) + return { + 'point3D_ids': np.array(valid_p3d_ids), + 'keypoints': np.array(valid_kpts), + 'descriptors': np.array(valid_descs), + 
'scores': np.array(valid_scores), + 'xyzs': np.array(valid_xyzs), + } + + def get_keypoints(self): + return { + 'point3D_ids': self.point3D_ids, + 'keypoints': self.keypoints[:, :2], + 'descriptors': self.descriptors, + 'scores': self.keypoints[:, 2], + 'xyzs': self.xyzs, + 'camera': self.camera, + } + + valid_p3d_ids = [] + valid_kpts = [] + valid_descs = [] + valid_scores = [] + valid_xyzs = [] + for i, v in enumerate(self.point3D_ids): + if v in point3Ds.keys(): + p3d = point3Ds[v] + valid_kpts.append(self.keypoints[i]) + valid_p3d_ids.append(v) + valid_xyzs.append(p3d.xyz) + valid_descs.append(p3d.descriptor) + valid_scores.append(p3d.error) + return { + 'points3D_ids': np.array(valid_p3d_ids), + 'keypoints': np.array(valid_kpts), + 'descriptors': np.array(valid_descs), + 'scores': 1 / np.clip(np.array(valid_scores) * 5, a_min=1., a_max=20.), + 'xyzs': np.array(valid_xyzs), + 'camera': self.camera, + } + + def associate_keypoints_with_point3Ds(self, point3Ds: dict): + xyzs = [] + descs = [] + scores = [] + p3d_ids = [] + kpt_sids = [] + for i, v in enumerate(self.point3D_ids): + if v in point3Ds.keys(): + p3d = point3Ds[v] + p3d_ids.append(v) + xyzs.append(p3d.xyz) + descs.append(p3d.descriptor) + scores.append(p3d.error) + + kpt_sids.append(p3d.seg_id) + + xyzs = np.array(xyzs) + if xyzs.shape[0] == 0: + return False + + descs = np.array(descs) + scores = 1 / np.clip(np.array(scores) * 5, a_min=1., a_max=20.) + p3d_ids = np.array(p3d_ids) + uvs = self.project(xyzs=xyzs) + self.keypoints = np.hstack([uvs, scores.reshape(-1, 1)]) + self.descriptors = descs + self.point3D_ids = p3d_ids + self.xyzs = xyzs + self.keypoint_segs = np.array(kpt_sids) + + return True + + def project(self, xyzs): + ''' + :param xyzs: [N, 3] + :return: + ''' + K = intrinsics_from_camera(camera_model=self.camera.model, params=self.camera.params) # [3, 3] + Rcw = qvec2rotmat(self.qvec) + tcw = self.tvec.reshape(3, 1) + Tcw = np.eye(4, dtype=float) + Tcw[:3, :3] = Rcw + Tcw[:3, 3:] = tcw + xyzs_homo = np.hstack([xyzs, np.ones(shape=(xyzs.shape[0], 1))]) # [N 4] + + xyzs_cam = Tcw @ xyzs_homo.transpose() # [4, N] + uvs = K @ xyzs_cam[:3, :] # [3, N] + uvs[:2, :] = uvs[:2, :] / uvs[2, :] + return uvs[:2, :].transpose() diff --git a/imcui/third_party/pram/localization/singlemap3d.py b/imcui/third_party/pram/localization/singlemap3d.py new file mode 100644 index 0000000000000000000000000000000000000000..77fc0ef2c78321044bb8f8f2952ccb278ea28d8f --- /dev/null +++ b/imcui/third_party/pram/localization/singlemap3d.py @@ -0,0 +1,532 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> map3d +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 04/03/2024 10:25 +==================================================''' +import numpy as np +from collections import defaultdict +import os.path as osp +import pycolmap +import logging +import time + +import torch + +from localization.refframe import RefFrame +from localization.frame import Frame +from localization.point3d import Point3D +from colmap_utils.read_write_model import qvec2rotmat, read_model, read_compressed_model +from localization.utils import read_gt_pose + + +class SingleMap3D: + def __init__(self, config, matcher, with_compress=False, start_sid: int = 0): + self.config = config + self.matcher = matcher + self.image_path_prefix = self.config['image_path_prefix'] + self.start_sid = start_sid # for a dataset with multiple scenes + if not with_compress: + cameras, images, p3ds = read_model( + 
path=osp.join(config['landmark_path'], 'model'), ext='.bin') + p3d_descs = np.load(osp.join(config['landmark_path'], 'point3D_desc.npy'), + allow_pickle=True)[()] + else: + cameras, images, p3ds = read_compressed_model( + path=osp.join(config['landmark_path'], 'compress_model_{:s}'.format(config['cluster_method'])), + ext='.bin') + p3d_descs = np.load(osp.join(config['landmark_path'], 'compress_model_{:s}/point3D_desc.npy'.format( + config['cluster_method'])), allow_pickle=True)[()] + + print('Load {} cameras {} images {} 3D points'.format(len(cameras), len(images), len(p3d_descs))) + + seg_data = np.load( + osp.join(config['landmark_path'], 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(config['n_cluster'], + config['cluster_mode'], + config['cluster_method'])), + allow_pickle=True)[()] + + p3d_id = seg_data['id'] + seg_id = seg_data['label'] + p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + seg_p3d = {} + for k in p3d_seg.keys(): + sid = p3d_seg[k] + if sid in seg_p3d.keys(): + seg_p3d[sid].append(k) + else: + seg_p3d[sid] = [k] + + print('Load {} segments and {} 3d points'.format(len(seg_p3d.keys()), len(p3d_seg.keys()))) + seg_vrf = np.load( + osp.join(config['landmark_path'], 'point3D_vrf_n{:d}_{:s}_{:s}.npy'.format(config['n_cluster'], + config['cluster_mode'], + config['cluster_method'])), + allow_pickle=True)[()] + + # construct 3D map + self.initialize_point3Ds(p3ds=p3ds, p3d_descs=p3d_descs, p3d_seg=p3d_seg) + self.initialize_ref_frames(cameras=cameras, images=images) + + all_vrf_frame_ids = [] + self.seg_ref_frame_ids = {} + for sid in seg_vrf.keys(): + self.seg_ref_frame_ids[sid] = [] + for vi in seg_vrf[sid].keys(): + vrf_frame_id = seg_vrf[sid][vi]['image_id'] + self.seg_ref_frame_ids[sid].append(vrf_frame_id) + if with_compress and vrf_frame_id in self.reference_frames.keys(): + self.reference_frames[vrf_frame_id].point3D_ids = seg_vrf[sid][vi]['original_points3d'] + + all_vrf_frame_ids.extend(self.seg_ref_frame_ids[sid]) + + if with_compress: + all_ref_ids = list(self.reference_frames.keys()) + for fid in all_ref_ids: + valid = self.reference_frames[fid].associate_keypoints_with_point3Ds(point3Ds=self.point3Ds) + if not valid: + del self.reference_frames[fid] + + all_vrf_frame_ids = np.unique(all_vrf_frame_ids) + all_vrf_frame_ids = [v for v in all_vrf_frame_ids if v in self.reference_frames.keys()] + self.build_covisibility_graph(frame_ids=all_vrf_frame_ids, n_frame=config['localization'][ + 'covisibility_frame']) # build covisible frames for vrf frames only + + logging.info( + f'Construct {len(self.reference_frames.keys())} ref frames and {len(self.point3Ds.keys())} 3d points') + + self.gt_poses = {} + if config['gt_pose_path'] is not None: + gt_pose_path = osp.join(config['dataset_path'], config['gt_pose_path']) + self.read_gt_pose(path=gt_pose_path) + + def read_gt_pose(self, path, prefix=''): + self.gt_poses = read_gt_pose(path=path) + print('Load {} gt poses'.format(len(self.gt_poses.keys()))) + + def initialize_point3Ds(self, p3ds, p3d_descs, p3d_seg): + self.point3Ds = {} + for id in p3ds.keys(): + if id not in p3d_seg.keys(): + continue + self.point3Ds[id] = Point3D(id=id, xyz=p3ds[id].xyz, error=p3ds[id].error, + refframe_id=-1, rgb=p3ds[id].rgb, + descriptor=p3d_descs[id], seg_id=p3d_seg[id], + frame_ids=p3ds[id].image_ids) + + def initialize_ref_frames(self, cameras, images): + self.reference_frames = {} + for id in images.keys(): + im = images[id] + cam = cameras[im.camera_id] + self.reference_frames[id] = RefFrame(camera=cam, id=id, 
qvec=im.qvec, tvec=im.tvec, + point3D_ids=im.point3D_ids, + keypoints=im.xys, name=im.name) + + def localize_with_ref_frame(self, q_frame: Frame, q_kpt_ids: np.ndarray, sid, semantic_matching=False): + ref_frame_id = self.seg_ref_frame_ids[sid][0] + ref_frame = self.reference_frames[ref_frame_id] + if semantic_matching and sid > 0: + ref_data = ref_frame.get_keypoints_by_sid(sid=sid) + else: + ref_data = ref_frame.get_keypoints() + + q_descs = q_frame.descriptors[q_kpt_ids] + q_kpts = q_frame.keypoints[q_kpt_ids, :2] + q_scores = q_frame.keypoints[q_kpt_ids, 2] + + xyzs = ref_data['xyzs'] + point3D_ids = ref_data['point3D_ids'] + ref_sids = np.array([self.point3Ds[v].seg_id for v in point3D_ids]) + with torch.no_grad(): + indices0 = self.matcher({ + 'descriptors0': torch.from_numpy(q_descs)[None].cuda().float(), + 'keypoints0': torch.from_numpy(q_kpts)[None].cuda().float(), + 'scores0': torch.from_numpy(q_scores)[None].cuda().float(), + 'image_shape0': (1, 3, q_frame.camera.width, q_frame.camera.height), + + 'descriptors1': torch.from_numpy(ref_data['descriptors'])[None].cuda().float(), + 'keypoints1': torch.from_numpy(ref_data['keypoints'])[None].cuda().float(), + 'scores1': torch.from_numpy(ref_data['scores'])[None].cuda().float(), + 'image_shape1': (1, 3, ref_frame.camera.width, ref_frame.camera.height), + } + )['matches0'][0].cpu().numpy() + + valid = indices0 >= 0 + mkpts = q_kpts[valid] + mkpt_ids = q_kpt_ids[valid] + mxyzs = xyzs[indices0[valid]] + mpoint3D_ids = point3D_ids[indices0[valid]] + matched_sids = ref_sids[indices0[valid]] + matched_ref_keypoints = ref_data['keypoints'][indices0[valid]] + + # print('mkpts: ', mkpts.shape, mxyzs.shape, np.sum(indices0 >= 0)) + # cfg = q_frame.camera._asdict() + # q_cam = pycolmap.Camera(model=q_frame.camera.model, ) + # config = {"estimation": {"ransac": {"max_error": ransac_thresh}}, **(config or {})} + ret = pycolmap.absolute_pose_estimation(mkpts + 0.5, + mxyzs, + q_frame.camera, + estimation_options={ + "ransac": {"max_error": self.config['localization']['threshold']}}, + refinement_options={}, + # max_error_px=self.config['localization']['threshold'] + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + ret['matched_keypoints'] = mkpts + ret['matched_keypoint_ids'] = mkpt_ids + ret['matched_xyzs'] = mxyzs + ret['reference_frame_id'] = ref_frame_id + ret['matched_point3D_ids'] = mpoint3D_ids + ret['matched_sids'] = matched_sids + ret['matched_ref_keypoints'] = matched_ref_keypoints + + if not ret['success']: + ret['num_inliers'] = 0 + ret['inliers'] = np.zeros(shape=(mkpts.shape[0],), dtype=bool) + return ret + + def match(self, query_data, ref_data): + q_descs = query_data['descriptors'] + q_kpts = query_data['keypoints'] + q_scores = query_data['scores'] + xyzs = ref_data['xyzs'] + points3D_ids = ref_data['point3D_ids'] + with torch.no_grad(): + indices0 = self.matcher({ + 'descriptors0': torch.from_numpy(q_descs)[None].cuda().float(), + 'keypoints0': torch.from_numpy(q_kpts)[None].cuda().float(), + 'scores0': torch.from_numpy(q_scores)[None].cuda().float(), + 'image_shape0': (1, 3, query_data['camera'].width, query_data['camera'].height), + + 'descriptors1': torch.from_numpy(ref_data['descriptors'])[None].cuda().float(), + 'keypoints1': torch.from_numpy(ref_data['keypoints'])[None].cuda().float(), + 'scores1': torch.from_numpy(ref_data['scores'])[None].cuda().float(), + 'image_shape1': 
(1, 3, ref_data['camera'].width, ref_data['camera'].height), + } + )['matches0'][0].cpu().numpy() + + valid = indices0 >= 0 + mkpts = q_kpts[valid] + mkpt_ids = np.where(valid)[0] + mxyzs = xyzs[indices0[valid]] + mpoints3D_ids = points3D_ids[indices0[valid]] + + return { + 'matched_keypoints': mkpts, + 'matched_xyzs': mxyzs, + 'matched_point3D_ids': mpoints3D_ids, + 'matched_keypoint_ids': mkpt_ids, + } + + def build_covisibility_graph(self, frame_ids: list = None, n_frame: int = 20): + def find_covisible_frames(frame_id): + observed = self.reference_frames[frame_id].point3D_ids + covis = defaultdict(int) + for pid in observed: + if pid == -1: + continue + if pid not in self.point3Ds.keys(): + continue + for img_id in self.point3Ds[pid].frame_ids: + covis[img_id] += 1 + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + + if len(covis_ids) <= n_frame: + sel_covis_ids = covis_ids[np.argsort(-covis_num)] + else: + ind_top = np.argpartition(covis_num, -n_frame) + ind_top = ind_top[-n_frame:] # unsorted top k + ind_top = ind_top[np.argsort(-covis_num[ind_top])] + sel_covis_ids = [covis_ids[i] for i in ind_top] + + return sel_covis_ids + + if frame_ids is None: + frame_ids = list(self.referece_frames.keys()) + + self.covisible_graph = defaultdict() + for frame_id in frame_ids: + self.covisible_graph[frame_id] = find_covisible_frames(frame_id=frame_id) + + def refine_pose(self, q_frame: Frame, refinement_method='matching'): + if refinement_method == 'matching': + return self.refine_pose_by_matching(q_frame=q_frame) + elif refinement_method == 'projection': + return self.refine_pose_by_projection(q_frame=q_frame) + else: + raise NotImplementedError + + def refine_pose_by_matching(self, q_frame): + ref_frame_id = q_frame.reference_frame_id + db_ids = self.covisible_graph[ref_frame_id] + print('Find {} covisible frames'.format(len(db_ids))) + loc_success = q_frame.tracking_status + if loc_success and ref_frame_id in db_ids: + init_kpts = q_frame.matched_keypoints + init_kpt_ids = q_frame.matched_keypoint_ids + init_point3D_ids = q_frame.matched_point3D_ids + init_xyzs = np.array([self.point3Ds[v].xyz for v in init_point3D_ids]).reshape(-1, 3) + list(db_ids).remove(ref_frame_id) + else: + init_kpts = None + init_xyzs = None + init_point3D_ids = None + + matched_xyzs = [] + matched_kpts = [] + matched_point3D_ids = [] + matched_kpt_ids = [] + for idx, frame_id in enumerate(db_ids): + ref_data = self.reference_frames[frame_id].get_keypoints() + match_out = self.match(query_data={ + 'keypoints': q_frame.keypoints[:, :2], + 'scores': q_frame.keypoints[:, 2], + 'descriptors': q_frame.descriptors, + 'camera': q_frame.camera, }, + ref_data=ref_data) + if match_out['matched_keypoints'].shape[0] > 0: + matched_kpts.append(match_out['matched_keypoints']) + matched_xyzs.append(match_out['matched_xyzs']) + matched_point3D_ids.append(match_out['matched_point3D_ids']) + matched_kpt_ids.append(match_out['matched_keypoint_ids']) + if len(matched_kpts) > 1: + matched_kpts = np.vstack(matched_kpts) + matched_xyzs = np.vstack(matched_xyzs).reshape(-1, 3) + matched_point3D_ids = np.hstack(matched_point3D_ids) + matched_kpt_ids = np.hstack(matched_kpt_ids) + else: + matched_kpts = matched_kpts[0] + matched_xyzs = matched_xyzs[0] + matched_point3D_ids = matched_point3D_ids[0] + matched_kpt_ids = matched_kpt_ids[0] + if init_kpts is not None and init_kpts.shape[0] > 0: + matched_kpts = np.vstack([matched_kpts, init_kpts]) + matched_xyzs = np.vstack([matched_xyzs, 
init_xyzs]) + matched_point3D_ids = np.hstack([matched_point3D_ids, init_point3D_ids]) + matched_kpt_ids = np.hstack([matched_kpt_ids, init_kpt_ids]) + + matched_sids = np.array([self.point3Ds[v].seg_id for v in matched_point3D_ids]) + + print_text = 'Refinement by matching. Get {:d} covisible frames with {:d} matches for optimization'.format( + len(db_ids), matched_xyzs.shape[0]) + print(print_text) + + t_start = time.time() + ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5, + matched_xyzs, + q_frame.camera, + estimation_options={ + 'ransac': { + 'max_error': self.config['localization']['threshold'], + 'min_num_trials': 1000, + 'max_num_trials': 10000, + 'confidence': 0.995, + }}, + refinement_options={}, + # max_error_px=self.config['localization']['threshold'], + # min_num_trials=1000, max_num_trials=10000, confidence=0.995) + ) + print('Time of RANSAC: {:.2f}s'.format(time.time() - t_start)) + + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + + ret['matched_keypoints'] = matched_kpts + ret['matched_keypoint_ids'] = matched_kpt_ids + ret['matched_xyzs'] = matched_xyzs + ret['matched_point3D_ids'] = matched_point3D_ids + ret['matched_sids'] = matched_sids + + if ret['success']: + inlier_mask = np.array(ret['inliers']) + best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=matched_point3D_ids[inlier_mask], + candidate_frame_ids=self.covisible_graph.keys()) + else: + best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=matched_point3D_ids, + candidate_frame_ids=self.covisible_graph.keys()) + + ret['refinement_reference_frame_ids'] = best_reference_frame_ids[:self.config['localization'][ + 'covisibility_frame']] + ret['reference_frame_id'] = best_reference_frame_ids[0] + + return ret + + @torch.no_grad() + def refine_pose_by_projection(self, q_frame): + q_Rcw = qvec2rotmat(q_frame.qvec) + q_tcw = q_frame.tvec + q_Tcw = np.eye(4, dtype=float) # [4 4] + q_Tcw[:3, :3] = q_Rcw + q_Tcw[:3, 3] = q_tcw + cam = q_frame.camera + imw = cam.width + imh = cam.height + K = q_frame.get_intrinsics() # [3, 3] + reference_frame_id = q_frame.reference_frame_id + covis_frame_ids = self.covisible_graph[reference_frame_id] + if reference_frame_id not in covis_frame_ids: + covis_frame_ids.append(reference_frame_id) + all_point3D_ids = [] + + for frame_id in covis_frame_ids: + all_point3D_ids.extend(list(self.reference_frames[frame_id].point3D_ids)) + + all_point3D_ids = np.unique(all_point3D_ids) + all_xyzs = [] + all_descs = [] + all_sids = [] + for pid in all_point3D_ids: + all_xyzs.append(self.point3Ds[pid].xyz) + all_descs.append(self.point3Ds[pid].descriptor) + all_sids.append(self.point3Ds[pid].seg_id) + + all_xyzs = np.array(all_xyzs) # [N 3] + all_descs = np.array(all_descs) # [N 3] + all_point3D_ids = np.array(all_point3D_ids) + all_sids = np.array(all_sids) + + # move to gpu (distortion is not included) + # proj_uv = pycolmap.camera.img_from_cam( + # np.array([1, 1, 1]).reshape(1, 3), + # ) + all_xyzs_cuda = torch.from_numpy(all_xyzs).cuda() + ones = torch.ones(size=(all_xyzs_cuda.shape[0], 1), dtype=all_xyzs_cuda.dtype).cuda() + all_xyzs_cuda_homo = torch.cat([all_xyzs_cuda, ones], dim=1) # [N 4] + K_cuda = torch.from_numpy(K).cuda() + proj_uvs = K_cuda @ (torch.from_numpy(q_Tcw).cuda() @ all_xyzs_cuda_homo.t())[:3, :] # [3, N] + proj_uvs[0] /= proj_uvs[2] + proj_uvs[1] /= proj_uvs[2] + mask = 
(proj_uvs[2] > 0) * (proj_uvs[2] < 100) * (proj_uvs[0] >= 0) * (proj_uvs[0] < imw) * ( + proj_uvs[1] >= 0) * (proj_uvs[1] < imh) + + proj_uvs = proj_uvs[:, mask] + + print('Projection: out of range {:d}/{:d}'.format(all_xyzs_cuda.shape[0], proj_uvs.shape[1])) + + mxyzs = all_xyzs[mask.cpu().numpy()] + mpoint3D_ids = all_point3D_ids[mask.cpu().numpy()] + msids = all_sids[mask.cpu().numpy()] + + q_kpts_cuda = torch.from_numpy(q_frame.keypoints[:, :2]).cuda() + proj_error = q_kpts_cuda[..., None] - proj_uvs[:2][None] + proj_error = torch.sqrt(torch.sum(proj_error ** 2, dim=1)) # [M N] + out_of_range_mask = (proj_error >= 2 * self.config['localization']['threshold']) + + q_descs_cuda = torch.from_numpy(q_frame.descriptors).cuda().float() # [M D] + all_descs_cuda = torch.from_numpy(all_descs).cuda().float()[mask] # [N D] + desc_dist = torch.sqrt(2 - 2 * q_descs_cuda @ all_descs_cuda.t() + 1e-6) + desc_dist[out_of_range_mask] = desc_dist[out_of_range_mask] + 100 + dists, ids = torch.topk(desc_dist, k=2, largest=False, dim=1) + # apply nn ratio + ratios = dists[:, 0] / dists[:, 1] # smaller, better + ratio_mask = (ratios <= 0.995) * (dists[:, 0] < 100) + ratio_mask = ratio_mask.cpu().numpy() + ids = ids.cpu().numpy()[ratio_mask, 0] + + ratio_num = torch.sum(ratios <= 0.995) + proj_num = torch.sum(dists[:, 0] < 100) + + print('Projection: after ratio {:d}/{:d}, ratio {:d}, proj {:d}'.format(q_kpts_cuda.shape[0], + np.sum(ratio_mask), + ratio_num, proj_num)) + + mkpts = q_frame.keypoints[ratio_mask] + mkpt_ids = np.where(ratio_mask)[0] + mxyzs = mxyzs[ids] + mpoint3D_ids = mpoint3D_ids[ids] + msids = msids[ids] + print('projection: ', mkpts.shape, mkpt_ids.shape, mxyzs.shape, mpoint3D_ids.shape, msids.shape) + + t_start = time.time() + ret = pycolmap.absolute_pose_estimation(mkpts[:, :2] + 0.5, mxyzs, q_frame.camera, + estimation_options={ + "ransac": {"max_error": self.config['localization']['threshold']}}, + refinement_options={}, + # max_error_px=self.config['localization']['threshold'] + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + # inlier_mask = np.ones(shape=(mkpts.shape[0],), dtype=bool).tolist() + # ret = pycolmap.pose_refinement(q_frame.tvec, q_frame.qvec, mkpts[:, :2] + 0.5, mxyzs, inlier_mask, cfg) + # ret['num_inliers'] = np.sum(inlier_mask).astype(int) + # ret['inliers'] = np.array(inlier_mask) + + print_text = 'Refinement by projection. 
Get {:d} inliers of {:d} matches for optimization'.format( + ret['num_inliers'], mxyzs.shape[0]) + print(print_text) + print('Time of RANSAC: {:.2f}s'.format(time.time() - t_start)) + + ret['matched_keypoints'] = mkpts + ret['matched_xyzs'] = mxyzs + ret['matched_point3D_ids'] = mpoint3D_ids + ret['matched_sids'] = msids + ret['matched_keypoint_ids'] = mkpt_ids + + if ret['success']: + inlier_mask = np.array(ret['inliers']) + best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=mpoint3D_ids[inlier_mask], + candidate_frame_ids=self.covisible_graph.keys()) + else: + best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=mpoint3D_ids, + candidate_frame_ids=self.covisible_graph.keys()) + + ret['refinement_reference_frame_ids'] = best_reference_frame_ids[:self.config['localization'][ + 'covisibility_frame']] + ret['reference_frame_id'] = best_reference_frame_ids[0] + + if not ret['success']: + ret['num_inliers'] = 0 + ret['inliers'] = np.zeros(shape=(mkpts.shape[0],), dtype=bool) + + return ret + + def find_reference_frames(self, matched_point3D_ids, candidate_frame_ids=None): + covis_frames = defaultdict(int) + for pid in matched_point3D_ids: + for im_id in self.point3Ds[pid].frame_ids: + if candidate_frame_ids is not None and im_id in candidate_frame_ids: + covis_frames[im_id] += 1 + + covis_ids = np.array(list(covis_frames.keys())) + covis_num = np.array([covis_frames[i] for i in covis_ids]) + sorted_idxes = np.argsort(covis_num)[::-1] # larger to small + sorted_frame_ids = covis_ids[sorted_idxes] + return sorted_frame_ids + + def check_semantic_consistency(self, q_frame: Frame, sid, overlap_ratio=0.5): + ref_frame_id = self.seg_ref_frame_ids[sid][0] + ref_frame = self.reference_frames[ref_frame_id] + + q_sids = q_frame.seg_ids + ref_sids = np.array([self.point3Ds[v].seg_id for v in ref_frame.point3D_ids]) + self.start_sid + overlap_sids = np.intersect1d(q_sids, ref_sids) + + overlap_num1 = 0 + overlap_num2 = 0 + for sid in overlap_sids: + overlap_num1 += np.sum(q_sids == sid) + overlap_num2 += np.sum(ref_sids == sid) + + ratio1 = overlap_num1 / q_sids.shape[0] + ratio2 = overlap_num2 / ref_sids.shape[0] + + # print('semantic_check: ', overlap_sids, overlap_num1, ratio1, overlap_num2, ratio2) + + return min(ratio1, ratio2) >= overlap_ratio diff --git a/imcui/third_party/pram/localization/tracker.py b/imcui/third_party/pram/localization/tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..a401fea82c2372cfdf301ab2d2fb34981facf4fe --- /dev/null +++ b/imcui/third_party/pram/localization/tracker.py @@ -0,0 +1,338 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> tracker +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/02/2024 16:58 +==================================================''' +import time +import cv2 +import numpy as np +import torch +import pycolmap +from localization.frame import Frame +from localization.base_model import dynamic_load +import localization.matchers as matchers +from localization.match_features_batch import confs as matcher_confs +from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches +from tools.common import resize_img + + +class Tracker: + def __init__(self, locMap, matcher, config): + self.locMap = locMap + self.matcher = matcher + self.config = config + self.loc_config = config['localization'] + + self.lost = True + + self.curr_frame = None + self.last_frame = None + + device = 'cuda' if 
torch.cuda.is_available() else 'cpu' + Model = dynamic_load(matchers, 'nearest_neighbor') + self.nn_matcher = Model(matcher_confs['NNM']['model']).eval().to(device) + + def run(self, frame: Frame): + print('Start tracking...') + show = self.config['localization']['show'] + self.curr_frame = frame + ref_img = self.last_frame.image + curr_img = self.curr_frame.image + q_kpts = frame.keypoints + + t_start = time.time() + ret = self.track_last_frame(curr_frame=self.curr_frame, last_frame=self.last_frame) + self.curr_frame.time_loc = self.curr_frame.time_loc + time.time() - t_start + + if show: + curr_matched_kpts = ret['matched_keypoints'] + ref_matched_kpts = ret['matched_ref_keypoints'] + img_loc_matching = plot_matches(img1=curr_img, img2=ref_img, + pts1=curr_matched_kpts, + pts2=ref_matched_kpts, + inliers=np.array([True for i in range(curr_matched_kpts.shape[0])]), + radius=9, line_thickness=3) + self.curr_frame.image_matching = img_loc_matching + + q_ref_img_matching = resize_img(img_loc_matching, nh=512) + + if not ret['success']: + show_text = 'Tracking FAILED!' + img_inlier = vis_inlier(img=curr_img, kpts=curr_matched_kpts, + inliers=[False for i in range(curr_matched_kpts.shape[0])], radius=9 + 2, + thickness=2) + q_img_inlier = cv2.putText(img=img_inlier, text=show_text, org=(30, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + + q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)]) + + cv2.imshow('loc', q_img_loc) + key = cv2.waitKey(self.loc_config['show_time']) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + return False + + ret['matched_scene_name'] = self.last_frame.scene_name + success = self.verify_and_update(q_frame=self.curr_frame, ret=ret) + + if not success: + return False + + if ret['num_inliers'] < 256: + # refinement is necessary for tracking last frame + t_start = time.time() + ret = self.locMap.sub_maps[self.last_frame.matched_scene_name].refine_pose(self.curr_frame, + refinement_method= + self.loc_config[ + 'refinement_method']) + self.curr_frame.time_ref = self.curr_frame.time_ref + time.time() - t_start + ret['matched_scene_name'] = self.last_frame.scene_name + success = self.verify_and_update(q_frame=self.curr_frame, ret=ret) + + if show: + q_err, t_err = self.curr_frame.compute_pose_error() + num_matches = ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + show_text = 'Tracking, k/m/i: {:d}/{:d}/{:d}'.format(q_kpts.shape[0], num_matches, num_inliers) + q_img_inlier = vis_inlier(img=curr_img, kpts=ret['matched_keypoints'], inliers=ret['inliers'], + radius=9 + 2, thickness=2) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err) + q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA) + self.curr_frame.image_inlier = q_img_inlier + + q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)]) + + cv2.imshow('loc', q_img_loc) + key = cv2.waitKey(self.loc_config['show_time']) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + + self.lost = success + return success + + def verify_and_update(self, q_frame: Frame, ret: dict): + num_matches = 
ret['matched_keypoints'].shape[0] + num_inliers = ret['num_inliers'] + + q_frame.qvec = ret['qvec'] + q_frame.tvec = ret['tvec'] + + q_err, t_err = q_frame.compute_pose_error() + + if num_inliers < self.loc_config['min_inliers']: + print_text = 'Failed due to insufficient {:d} inliers, q_err: {:.2f}, t_err: {:.2f}'.format( + ret['num_inliers'], q_err, t_err) + print(print_text) + q_frame.tracking_status = False + q_frame.clear_localization_track() + return False + else: + print_text = 'Succeed! Find {}/{} 2D-3D inliers,q_err: {:.2f}, t_err: {:.2f}'.format( + num_inliers, num_matches, q_err, t_err) + print(print_text) + q_frame.tracking_status = True + + self.update_current_frame(curr_frame=q_frame, ret=ret) + return True + + def update_current_frame(self, curr_frame: Frame, ret: dict): + curr_frame.qvec = ret['qvec'] + curr_frame.tvec = ret['tvec'] + + curr_frame.matched_scene_name = ret['matched_scene_name'] + curr_frame.reference_frame_id = ret['reference_frame_id'] + inliers = np.array(ret['inliers']) + + curr_frame.matched_keypoints = ret['matched_keypoints'][inliers] + curr_frame.matched_xyzs = ret['matched_xyzs'][inliers] + curr_frame.matched_point3D_ids = ret['matched_point3D_ids'][inliers] + curr_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inliers] + curr_frame.matched_sids = ret['matched_sids'][inliers] + + def track_last_frame(self, curr_frame: Frame, last_frame: Frame): + curr_kpts = curr_frame.keypoints[:, :2] + curr_scores = curr_frame.keypoints[:, 2] + curr_descs = curr_frame.descriptors + curr_kpt_ids = np.arange(curr_kpts.shape[0]) + + last_kpts = last_frame.keypoints[:, :2] + last_scores = last_frame.keypoints[:, 2] + last_descs = last_frame.descriptors + last_xyzs = last_frame.xyzs + last_point3D_ids = last_frame.point3D_ids + last_sids = last_frame.seg_ids + + # ''' + indices = self.matcher({ + 'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(), + 'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(), + 'scores0': torch.from_numpy(curr_scores)[None].cuda().float(), + 'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height), + + 'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(), + 'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(), + 'scores1': torch.from_numpy(last_scores)[None].cuda().float(), + 'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height), + })['matches0'][0].cpu().numpy() + ''' + + indices = self.nn_matcher({ + 'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None], + 'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None], + })['matches0'][0].cpu().numpy() + ''' + + valid = (indices >= 0) + + matched_point3D_ids = last_point3D_ids[indices[valid]] + point3D_mask = (matched_point3D_ids >= 0) + matched_point3D_ids = matched_point3D_ids[point3D_mask] + matched_sids = last_sids[indices[valid]][point3D_mask] + + matched_kpts = curr_kpts[valid][point3D_mask] + matched_kpt_ids = curr_kpt_ids[valid][point3D_mask] + matched_xyzs = last_xyzs[indices[valid]][point3D_mask] + matched_last_kpts = last_kpts[indices[valid]][point3D_mask] + + print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0], + last_kpts.shape[0])) + + # print('tracking: ', matched_kpts.shape, matched_xyzs.shape) + ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5, matched_xyzs, + curr_frame.camera, + estimation_options={ + "ransac": {"max_error": self.config['localization']['threshold']}}, + 
refinement_options={}, + # max_error_px=self.config['localization']['threshold'] + ) + if ret is None: + ret = {'success': False, } + else: + ret['success'] = True + ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]] + ret['tvec'] = ret['cam_from_world'].translation + + ret['matched_keypoints'] = matched_kpts + ret['matched_keypoint_ids'] = matched_kpt_ids + ret['matched_ref_keypoints'] = matched_last_kpts + ret['matched_xyzs'] = matched_xyzs + ret['matched_point3D_ids'] = matched_point3D_ids + ret['matched_sids'] = matched_sids + ret['reference_frame_id'] = last_frame.reference_frame_id + ret['matched_scene_name'] = last_frame.matched_scene_name + return ret + + def track_last_frame_fast(self, curr_frame: Frame, last_frame: Frame): + curr_kpts = curr_frame.keypoints[:, :2] + curr_scores = curr_frame.keypoints[:, 2] + curr_descs = curr_frame.descriptors + curr_kpt_ids = np.arange(curr_kpts.shape[0]) + + last_point3D_ids = last_frame.point3D_ids + point3D_mask = (last_point3D_ids >= 0) + last_kpts = last_frame.keypoints[:, :2][point3D_mask] + last_scores = last_frame.keypoints[:, 2][point3D_mask] + last_descs = last_frame.descriptors[point3D_mask] + last_xyzs = last_frame.xyzs[point3D_mask] + last_sids = last_frame.seg_ids[point3D_mask] + + minx = np.min(last_kpts[:, 0]) + maxx = np.max(last_kpts[:, 0]) + miny = np.min(last_kpts[:, 1]) + maxy = np.max(last_kpts[:, 1]) + curr_mask = (curr_kpts[:, 0] >= minx) * (curr_kpts[:, 0] <= maxx) * (curr_kpts[:, 1] >= miny) * ( + curr_kpts[:, 1] <= maxy) + + curr_kpts = curr_kpts[curr_mask] + curr_scores = curr_scores[curr_mask] + curr_descs = curr_descs[curr_mask] + curr_kpt_ids = curr_kpt_ids[curr_mask] + # ''' + indices = self.matcher({ + 'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(), + 'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(), + 'scores0': torch.from_numpy(curr_scores)[None].cuda().float(), + 'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height), + + 'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(), + 'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(), + 'scores1': torch.from_numpy(last_scores)[None].cuda().float(), + 'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height), + })['matches0'][0].cpu().numpy() + ''' + + indices = self.nn_matcher({ + 'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None], + 'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None], + })['matches0'][0].cpu().numpy() + ''' + + valid = (indices >= 0) + + matched_point3D_ids = last_point3D_ids[indices[valid]] + matched_sids = last_sids[indices[valid]] + + matched_kpts = curr_kpts[valid] + matched_kpt_ids = curr_kpt_ids[valid] + matched_xyzs = last_xyzs[indices[valid]] + matched_last_kpts = last_kpts[indices[valid]] + + print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0], + last_kpts.shape[0])) + + # print('tracking: ', matched_kpts.shape, matched_xyzs.shape) + ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5, matched_xyzs, + curr_frame.camera._asdict(), + max_error_px=self.config['localization']['threshold']) + + ret['matched_keypoints'] = matched_kpts + ret['matched_keypoint_ids'] = matched_kpt_ids + ret['matched_ref_keypoints'] = matched_last_kpts + ret['matched_xyzs'] = matched_xyzs + ret['matched_point3D_ids'] = matched_point3D_ids + ret['matched_sids'] = matched_sids + ret['reference_frame_id'] = last_frame.reference_frame_id + 
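+        # Explanatory note (comment added in editing, not in the original code):
+        # this fast variant still calls the older pycolmap signature
+        # (camera._asdict() plus max_error_px=...), whereas track_last_frame()
+        # above uses estimation_options and converts the returned Rigid3d pose
+        # with quat[[3, 0, 1, 2]], i.e. from pycolmap's (x, y, z, w) ordering to
+        # COLMAP's (w, x, y, z) qvec. The reference_frame_id above and the
+        # matched_scene_name below are simply inherited from the last frame so
+        # that a later refinement step keeps using the same sub-map and
+        # reference keyframe.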
ret['matched_scene_name'] = last_frame.matched_scene_name + return ret + + @torch.no_grad() + def match_frame(self, frame: Frame, reference_frame: Frame): + print('match: ', frame.keypoints.shape, reference_frame.keypoints.shape) + matches = self.matcher({ + 'descriptors0': torch.from_numpy(frame.descriptors)[None].cuda().float(), + 'keypoints0': torch.from_numpy(frame.keypoints[:, :2])[None].cuda().float(), + 'scores0': torch.from_numpy(frame.keypoints[:, 2])[None].cuda().float(), + 'image_shape0': (1, 3, frame.image_size[0], frame.image_size[1]), + + # 'descriptors0': torch.from_numpy(reference_frame.descriptors)[None].cuda().float(), + # 'keypoints0': torch.from_numpy(reference_frame.keypoints[:, :2])[None].cuda().float(), + # 'scores0': torch.from_numpy(reference_frame.keypoints[:, 2])[None].cuda().float(), + # 'image_shape0': (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]), + + 'descriptors1': torch.from_numpy(reference_frame.descriptors)[None].cuda().float(), + 'keypoints1': torch.from_numpy(reference_frame.keypoints[:, :2])[None].cuda().float(), + 'scores1': torch.from_numpy(reference_frame.keypoints[:, 2])[None].cuda().float(), + 'image_shape1': (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]), + + })['matches0'][0].cpu().numpy() + + ids1 = np.arange(matches.shape[0]) + ids2 = matches + ids1 = ids1[matches >= 0] + ids2 = ids2[matches >= 0] + + mask_p3ds = reference_frame.points3d_mask[ids2] + ids1 = ids1[mask_p3ds] + ids2 = ids2[mask_p3ds] + + return ids1, ids2 diff --git a/imcui/third_party/pram/localization/triangulation.py b/imcui/third_party/pram/localization/triangulation.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b885ec4be9c328353af9c0b0aaf136d694556a --- /dev/null +++ b/imcui/third_party/pram/localization/triangulation.py @@ -0,0 +1,317 @@ +# code is from hloc https://github.com/cvg/Hierarchical-Localization/blob/master/hloc/triangulation.py +import argparse +import contextlib +import io +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +import pycolmap +from tqdm import tqdm + +from colmap_utils.database import COLMAPDatabase +from colmap_utils.geometry import compute_epipolar_errors +from colmap_utils.io import get_keypoints, get_matches +from colmap_utils.parsers import parse_retrieval +import logging + + +class OutputCapture: + def __init__(self, verbose: bool): + self.verbose = verbose + + def __enter__(self): + if not self.verbose: + self.capture = contextlib.redirect_stdout(io.StringIO()) + self.out = self.capture.__enter__() + + def __exit__(self, exc_type, *args): + if not self.verbose: + self.capture.__exit__(exc_type, *args) + if exc_type is not None: + # logger.error("Failed with output:\n%s", self.out.getvalue()) + logging.error("Failed with output:\n%s", self.out.getvalue()) + sys.stdout.flush() + + +def create_db_from_model( + reconstruction: pycolmap.Reconstruction, database_path: Path +) -> Dict[str, int]: + if database_path.exists(): + # logger.warning("The database already exists, deleting it.") + logging.warning("The database already exists, deleting it.") + database_path.unlink() + + db = COLMAPDatabase.connect(database_path) + db.create_tables() + + for i, camera in reconstruction.cameras.items(): + db.add_camera( + camera.model.value, + camera.width, + camera.height, + camera.params, + camera_id=i, + prior_focal_length=True, + ) + + for i, image in reconstruction.images.items(): + db.add_image(image.name, image.camera_id, 
image_id=i) + + db.commit() + db.close() + return {image.name: i for i, image in reconstruction.images.items()} + + +def import_features( + image_ids: Dict[str, int], database_path: Path, features_path: Path +): + # logger.info("Importing features into the database...") + logging.info("Importing features into the database...") + db = COLMAPDatabase.connect(database_path) + + for image_name, image_id in tqdm(image_ids.items()): + keypoints = get_keypoints(features_path, image_name) + keypoints += 0.5 # COLMAP origin + db.add_keypoints(image_id, keypoints) + + db.commit() + db.close() + + +def import_matches( + image_ids: Dict[str, int], + database_path: Path, + pairs_path: Path, + matches_path: Path, + min_match_score: Optional[float] = None, + skip_geometric_verification: bool = False, +): + # logger.info("Importing matches into the database...") + logging.info("Importing matches into the database...") + + with open(str(pairs_path), "r") as f: + pairs = [p.split() for p in f.readlines()] + + db = COLMAPDatabase.connect(database_path) + + matched = set() + for name0, name1 in tqdm(pairs): + id0, id1 = image_ids[name0], image_ids[name1] + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matches, scores = get_matches(matches_path, name0, name1) + if min_match_score: + matches = matches[scores > min_match_score] + db.add_matches(id0, id1, matches) + matched |= {(id0, id1), (id1, id0)} + + if skip_geometric_verification: + db.add_two_view_geometry(id0, id1, matches) + + db.commit() + db.close() + + +def estimation_and_geometric_verification( + database_path: Path, pairs_path: Path, verbose: bool = False +): + # logger.info("Performing geometric verification of the matches...") + logging.info("Performing geometric verification of the matches...") + with OutputCapture(verbose): + with pycolmap.ostream(): + pycolmap.verify_matches( + database_path, + pairs_path, + options=dict(ransac=dict(max_num_trials=20000, min_inlier_ratio=0.1)), + ) + + +def geometric_verification( + image_ids: Dict[str, int], + reference: pycolmap.Reconstruction, + database_path: Path, + features_path: Path, + pairs_path: Path, + matches_path: Path, + max_error: float = 4.0, +): + # logger.info("Performing geometric verification of the matches...") + logging.info("Performing geometric verification of the matches...") + + pairs = parse_retrieval(pairs_path) + db = COLMAPDatabase.connect(database_path) + + inlier_ratios = [] + matched = set() + for name0 in tqdm(pairs): + id0 = image_ids[name0] + image0 = reference.images[id0] + cam0 = reference.cameras[image0.camera_id] + kps0, noise0 = get_keypoints(features_path, name0, return_uncertainty=True) + noise0 = 1.0 if noise0 is None else noise0 + if len(kps0) > 0: + kps0 = np.stack(cam0.cam_from_img(kps0)) + else: + kps0 = np.zeros((0, 2)) + + for name1 in pairs[name0]: + id1 = image_ids[name1] + image1 = reference.images[id1] + cam1 = reference.cameras[image1.camera_id] + kps1, noise1 = get_keypoints(features_path, name1, return_uncertainty=True) + noise1 = 1.0 if noise1 is None else noise1 + if len(kps1) > 0: + kps1 = np.stack(cam1.cam_from_img(kps1)) + else: + kps1 = np.zeros((0, 2)) + + matches = get_matches(matches_path, name0, name1)[0] + + if len({(id0, id1), (id1, id0)} & matched) > 0: + continue + matched |= {(id0, id1), (id1, id0)} + + if matches.shape[0] == 0: + db.add_two_view_geometry(id0, id1, matches) + continue + + cam1_from_cam0 = image1.cam_from_world * image0.cam_from_world.inverse() + errors0, errors1 = compute_epipolar_errors( + cam1_from_cam0, 
kps0[matches[:, 0]], kps1[matches[:, 1]] + ) + valid_matches = np.logical_and( + errors0 <= cam0.cam_from_img_threshold(noise0 * max_error), + errors1 <= cam1.cam_from_img_threshold(noise1 * max_error), + ) + # TODO: We could also add E to the database, but we need + # to reverse the transformations if id0 > id1 in utils/database.py. + db.add_two_view_geometry(id0, id1, matches[valid_matches, :]) + inlier_ratios.append(np.mean(valid_matches)) + # logger.info( + logging.info( + "mean/med/min/max valid matches %.2f/%.2f/%.2f/%.2f%%.", + np.mean(inlier_ratios) * 100, + np.median(inlier_ratios) * 100, + np.min(inlier_ratios) * 100, + np.max(inlier_ratios) * 100, + ) + + db.commit() + db.close() + + +def run_triangulation( + model_path: Path, + database_path: Path, + image_dir: Path, + reference_model: pycolmap.Reconstruction, + verbose: bool = False, + options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + model_path.mkdir(parents=True, exist_ok=True) + # logger.info("Running 3D triangulation...") + logging.info("Running 3D triangulation...") + if options is None: + options = {} + with OutputCapture(verbose): + with pycolmap.ostream(): + reconstruction = pycolmap.triangulate_points( + reference_model, database_path, image_dir, model_path, options=options + ) + return reconstruction + + +def main( + sfm_dir: Path, + reference_sfm_model: Path, + image_dir: Path, + pairs: Path, + features: Path, + matches: Path, + skip_geometric_verification: bool = False, + estimate_two_view_geometries: bool = False, + min_match_score: Optional[float] = None, + verbose: bool = False, + mapper_options: Optional[Dict[str, Any]] = None, +) -> pycolmap.Reconstruction: + assert reference_sfm_model.exists(), reference_sfm_model + assert features.exists(), features + assert pairs.exists(), pairs + assert matches.exists(), matches + + sfm_dir.mkdir(parents=True, exist_ok=True) + database = sfm_dir / "database.db" + reference = pycolmap.Reconstruction(reference_sfm_model) + + image_ids = create_db_from_model(reference, database) + import_features(image_ids, database, features) + import_matches( + image_ids, + database, + pairs, + matches, + min_match_score, + skip_geometric_verification, + ) + if not skip_geometric_verification: + if estimate_two_view_geometries: + estimation_and_geometric_verification(database, pairs, verbose) + else: + geometric_verification( + image_ids, reference, database, features, pairs, matches + ) + reconstruction = run_triangulation( + sfm_dir, database, image_dir, reference, verbose, mapper_options + ) + # logger.info( + logging.info( + "Finished the triangulation with statistics:\n%s", reconstruction.summary() + ) + stats = reconstruction.summary() + with open(sfm_dir / 'statics.txt', 'w') as f: + f.write(stats + '\n') + + # logging.info(f'Statistics:\n{pprint.pformat(stats)}') + return reconstruction + + +def parse_option_args(args: List[str], default_options) -> Dict[str, Any]: + options = {} + for arg in args: + idx = arg.find("=") + if idx == -1: + raise ValueError("Options format: key1=value1 key2=value2 etc.") + key, value = arg[:idx], arg[idx + 1:] + if not hasattr(default_options, key): + raise ValueError( + f'Unknown option "{key}", allowed options and default values' + f" for {default_options.summary()}" + ) + value = eval(value) + target_type = type(getattr(default_options, key)) + if not isinstance(value, target_type): + raise ValueError( + f'Incorrect type for option "{key}":' f" {type(value)} vs {target_type}" + ) + options[key] = value + return options + + 
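+# Usage sketch (added for illustration, not part of the original hloc code): a
+# never-called helper showing how main() could be driven from Python rather than
+# the CLI entry point below. The helper name and every path are hypothetical
+# placeholders for an hloc-style output layout.
+def _example_triangulation_run() -> pycolmap.Reconstruction:
+    # Hypothetical directories holding retrieval pairs, features and matches.
+    outputs = Path("outputs/example")
+    return main(
+        sfm_dir=outputs / "sfm",
+        reference_sfm_model=outputs / "sfm_reference",
+        image_dir=Path("datasets/example/images"),
+        pairs=outputs / "pairs.txt",
+        features=outputs / "features.h5",
+        matches=outputs / "matches.h5",
+        skip_geometric_verification=False,
+        verbose=False,
+    )
+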
+if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--sfm_dir", type=Path, required=True) + parser.add_argument("--reference_sfm_model", type=Path, required=True) + parser.add_argument("--image_dir", type=Path, required=True) + + parser.add_argument("--pairs", type=Path, required=True) + parser.add_argument("--features", type=Path, required=True) + parser.add_argument("--matches", type=Path, required=True) + + parser.add_argument("--skip_geometric_verification", action="store_true") + parser.add_argument("--min_match_score", type=float) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args().__dict__ + + main(**args) diff --git a/imcui/third_party/pram/localization/utils.py b/imcui/third_party/pram/localization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5861afceba6bed7518921145505b01caf66954 --- /dev/null +++ b/imcui/third_party/pram/localization/utils.py @@ -0,0 +1,83 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> utils +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 15:27 +==================================================''' +import numpy as np +from colmap_utils.read_write_model import qvec2rotmat + + +def read_query_info(query_fn: str, name_prefix='') -> dict: + results = {} + with open(query_fn, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split() + name, camera_model, width, height = l[:4] + params = np.array(l[4:], float) + info = (camera_model, int(width), int(height), params) + results[name_prefix + name] = info + print('Load {} query images'.format(len(results.keys()))) + return results + + +def quaternion_angular_error(q1, q2): + """ + angular error between two quaternions + :param q1: (4, ) + :param q2: (4, ) + :return: + """ + d = abs(np.dot(q1, q2)) + d = min(1.0, max(-1.0, d)) + theta = 2 * np.arccos(d) * 180 / np.pi + return theta + + +def compute_pose_error(pred_qcw, pred_tcw, gt_qcw, gt_tcw): + pred_Rcw = qvec2rotmat(qvec=pred_qcw) + pred_tcw = np.array(pred_tcw, float).reshape(3, 1) + pred_twc = -pred_Rcw.transpose() @ pred_tcw + + gt_Rcw = qvec2rotmat(gt_qcw) + gt_tcw = np.array(gt_tcw, float).reshape(3, 1) + gt_twc = -gt_Rcw.transpose() @ gt_tcw + + t_error_xyz = pred_twc - gt_twc + t_error = np.sqrt(np.sum(t_error_xyz ** 2)) + + q_error = quaternion_angular_error(q1=pred_qcw, q2=gt_qcw) + + return q_error, t_error + + +def read_retrieval_results(path): + output = {} + with open(path, "r") as f: + lines = f.readlines() + for p in lines: + p = p.strip("\n").split(" ") + + if p[1] == "no_match": + continue + if p[0] in output.keys(): + output[p[0]].append(p[1]) + else: + output[p[0]] = [p[1]] + return output + + +def read_gt_pose(path): + gt_poses = {} + with open(path, 'r') as f: + lines = f.readlines() + for l in lines: + l = l.strip().split(' ') + gt_poses[l[0]] = { + 'qvec': np.array([float(v) for v in l[1:5]], float), + 'tvec': np.array([float(v) for v in l[5:]], float), + } + + return gt_poses diff --git a/imcui/third_party/pram/localization/viewer.py b/imcui/third_party/pram/localization/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..33899f60ab362e240b7b0e6736a157a7aa041d31 --- /dev/null +++ b/imcui/third_party/pram/localization/viewer.py @@ -0,0 +1,548 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> viewer +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 05/03/2024 16:50 
+==================================================''' +import cv2 +import numpy as np +import pypangolin as pangolin +from OpenGL.GL import * +import time +import threading +from colmap_utils.read_write_model import qvec2rotmat +from tools.common import resize_image_with_padding +from localization.frame import Frame + + +class Viewer: + default_config = { + 'image_size_indoor': 0.1, + 'image_line_width_indoor': 1, + + 'image_size_outdoor': 1, + 'image_line_width_outdoor': 3, + + 'point_size_indoor': 1, + 'point_size_outdoor': 1, + + 'image_width': 640, + 'image_height': 480, + + 'viewpoint_x': 0, + 'viewpoint_y': -1, + 'viewpoint_z': -5, + 'viewpoint_F': 512, + + 'scene': 'indoor', + } + + def __init__(self, locMap, seg_color, config={}): + self.config = {**self.default_config, **config} + self.viewpoint_x = self.config['viewpoint_x'] + self.viewpoint_y = self.config['viewpoint_y'] + self.viewpoint_z = self.config['viewpoint_z'] + self.viewpoint_F = self.config['viewpoint_F'] + self.img_width = self.config['image_width'] + self.img_height = self.config['image_height'] + + if self.config['scene'] == 'indoor': + self.image_size = self.config['image_size_indoor'] + self.image_line_width = self.config['image_line_width_indoor'] + self.point_size = self.config['point_size_indoor'] + + else: + self.image_size = self.config['image_size_outdoor'] + self.image_line_width = self.config['image_line_width_outdoor'] + self.point_size = self.config['point_size_outdoor'] + self.viewpoint_z = -150 + + self.locMap = locMap + self.seg_colors = seg_color + + # current camera pose + self.frame = None + self.Tcw = np.eye(4, dtype=float) + self.Twc = np.linalg.inv(self.Tcw) + self.gt_Tcw = None + self.gt_Twc = None + + self.scene = None + self.current_vrf_id = None + self.reference_frame_ids = None + self.subMap = None + self.seg_point_clouds = None + self.point_clouds = None + + self.start_seg_id = 1 + self.stop = False + + self.refinement = False + self.tracking = False + + # time + self.time_feat = np.NAN + self.time_rec = np.NAN + self.time_loc = np.NAN + self.time_ref = np.NAN + + # image + self.image_rec = None + + def draw_3d_points_white(self): + if self.point_clouds is None: + return + + point_size = self.point_size * 0.5 + glColor4f(0.9, 0.95, 1.0, 0.6) + glPointSize(point_size) + pangolin.glDrawPoints(self.point_clouds) + + def draw_seg_3d_points(self): + if self.seg_point_clouds is None: + return + for sid in self.seg_point_clouds.keys(): + xyzs = self.seg_point_clouds[sid] + point_size = self.point_size * 0.5 + bgr = self.seg_colors[sid + self.start_seg_id + 1] + glColor3f(bgr[2] / 255, bgr[1] / 255, bgr[0] / 255) + glPointSize(point_size) + pangolin.glDrawPoints(xyzs) + + def draw_ref_3d_points(self, use_seg_color=False): + if self.reference_frame_ids is None: + return + + ref_point3D_ids = [] + for fid in self.reference_frame_ids: + pids = self.subMap.reference_frames[fid].point3D_ids + ref_point3D_ids.extend(list(pids)) + + ref_point3D_ids = np.unique(ref_point3D_ids).tolist() + + point_size = self.point_size * 5 + glPointSize(point_size) + glBegin(GL_POINTS) + + for pid in ref_point3D_ids: + if pid not in self.subMap.point3Ds.keys(): + continue + xyz = self.subMap.point3Ds[pid].xyz + rgb = self.subMap.point3Ds[pid].rgb + sid = self.subMap.point3Ds[pid].seg_id + if use_seg_color: + bgr = self.seg_colors[sid + self.start_seg_id + 1] + glColor3f(bgr[2] / 255, bgr[1] / 255, bgr[0] / 255) + else: + glColor3f(rgb[0] / 255, rgb[1] / 255, rgb[2] / 255) + + glVertex3f(xyz[0], xyz[1], xyz[2]) + + 
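+        # Explanatory note (comment added in editing): the reference points are
+        # drawn in legacy OpenGL immediate mode; every vertex submitted between
+        # glBegin(GL_POINTS) above and the glEnd() below uses the most recently
+        # set glColor3f, which is either the stored RGB of the 3D point or its
+        # segment colour when use_seg_color is enabled.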
glEnd() + + def draw_vrf_frames(self): + if self.subMap is None: + return + w = self.image_size * 1.0 + image_line_width = self.image_line_width * 1.0 + h = w * 0.75 + z = w * 0.6 + for sid in self.subMap.seg_ref_frame_ids.keys(): + frame_id = self.subMap.seg_ref_frame_ids[sid][0] + qvec = self.subMap.reference_frames[frame_id].qvec + tcw = self.subMap.reference_frames[frame_id].tvec + + Rcw = qvec2rotmat(qvec) + + twc = -Rcw.T @ tcw + Rwc = Rcw.T + + Twc = np.column_stack((Rwc, twc)) + Twc = np.vstack((Twc, (0, 0, 0, 1))) + + glPushMatrix() + + glMultMatrixf(Twc.T) + + glLineWidth(image_line_width) + glColor3f(1, 0, 0) + glBegin(GL_LINES) + glVertex3f(0, 0, 0) + glVertex3f(w, h, z) + glVertex3f(0, 0, 0) + glVertex3f(w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, h, z) + + glVertex3f(w, h, z) + glVertex3f(w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(-w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(w, h, z) + + glVertex3f(-w, -h, z) + glVertex3f(w, -h, z) + glEnd() + + glPopMatrix() + + def draw_current_vrf_frame(self): + if self.current_vrf_id is None: + return + qvec = self.subMap.reference_frames[self.current_vrf_id].qvec + tcw = self.subMap.reference_frames[self.current_vrf_id].tvec + Rcw = qvec2rotmat(qvec) + twc = -Rcw.T @ tcw + Rwc = Rcw.T + Twc = np.column_stack((Rwc, twc)) + Twc = np.vstack((Twc, (0, 0, 0, 1))) + + camera_line_width = self.image_line_width * 2 + w = self.image_size * 2 + h = w * 0.75 + z = w * 0.6 + + glPushMatrix() + + glMultMatrixf(Twc.T) # note the .T + + glLineWidth(camera_line_width) + glColor3f(1, 0, 0) + glBegin(GL_LINES) + glVertex3f(0, 0, 0) + glVertex3f(w, h, z) + glVertex3f(0, 0, 0) + glVertex3f(w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, h, z) + + glVertex3f(w, h, z) + glVertex3f(w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(-w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(w, h, z) + + glVertex3f(-w, -h, z) + glVertex3f(w, -h, z) + glEnd() + + glPopMatrix() + + def draw_current_frame(self, Tcw, color=(0, 1.0, 0)): + Twc = np.linalg.inv(Tcw) + + camera_line_width = self.image_line_width * 2 + w = self.image_size * 2 + h = w * 0.75 + z = w * 0.6 + + glPushMatrix() + + glMultMatrixf(Twc.T) # not the .T + + glLineWidth(camera_line_width) + glColor3f(color[0], color[1], color[2]) + glBegin(GL_LINES) + glVertex3f(0, 0, 0) + glVertex3f(w, h, z) + glVertex3f(0, 0, 0) + glVertex3f(w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, h, z) + + glVertex3f(w, h, z) + glVertex3f(w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(-w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(w, h, z) + + glVertex3f(-w, -h, z) + glVertex3f(w, -h, z) + glEnd() + + glPopMatrix() + + def draw_ref_frames(self): + if self.reference_frame_ids is None: + return + w = self.image_size * 1.5 + image_line_width = self.image_line_width * 1.5 + h = w * 0.75 + z = w * 0.6 + for fid in self.reference_frame_ids: + qvec = self.subMap.reference_frames[fid].qvec + tcw = self.subMap.reference_frames[fid].tvec + Rcw = qvec2rotmat(qvec) + + twc = -Rcw.T @ tcw + Rwc = Rcw.T + + Twc = np.column_stack((Rwc, twc)) + Twc = np.vstack((Twc, (0, 0, 0, 1))) + + glPushMatrix() + + glMultMatrixf(Twc.T) + + glLineWidth(image_line_width) + glColor3f(100 / 255, 140 / 255, 17 / 255) + glBegin(GL_LINES) + glVertex3f(0, 0, 0) + glVertex3f(w, h, z) + glVertex3f(0, 0, 0) + glVertex3f(w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, -h, z) + glVertex3f(0, 0, 0) + glVertex3f(-w, h, 
z) + + glVertex3f(w, h, z) + glVertex3f(w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(-w, -h, z) + + glVertex3f(-w, h, z) + glVertex3f(w, h, z) + + glVertex3f(-w, -h, z) + glVertex3f(w, -h, z) + glEnd() + + glPopMatrix() + + def terminate(self): + lock = threading.Lock() + lock.acquire() + self.stop = True + lock.release() + + def update_point_clouds(self): + # for fast drawing + seg_point_clouds = {} + point_clouds = [] + for pid in self.subMap.point3Ds.keys(): + sid = self.subMap.point3Ds[pid].seg_id + xyz = self.subMap.point3Ds[pid].xyz + if sid in seg_point_clouds.keys(): + seg_point_clouds[sid].append(xyz.reshape(3, 1)) + else: + seg_point_clouds[sid] = [xyz.reshape(3, 1)] + + point_clouds.append(xyz.reshape(3, 1)) + + self.seg_point_clouds = seg_point_clouds + self.point_clouds = point_clouds + + def update(self, curr_frame: Frame): + lock = threading.Lock() + lock.acquire() + + # self.frame = curr_frame + self.current_vrf_id = curr_frame.reference_frame_id + self.reference_frame_ids = [self.current_vrf_id] + + # self.reference_frame_ids = curr_frame.refinement_reference_frame_ids + # if self.reference_frame_ids is None: + # self.reference_frame_ids = [self.current_vrf_id] + self.subMap = self.locMap.sub_maps[curr_frame.matched_scene_name] + self.start_seg_id = self.locMap.scene_name_start_sid[curr_frame.matched_scene_name] + + if self.scene is None or self.scene != curr_frame.matched_scene_name: + self.scene = curr_frame.matched_scene_name + self.update_point_clouds() + + if curr_frame.qvec is not None: + Rcw = qvec2rotmat(curr_frame.qvec) + Tcw = np.column_stack((Rcw, curr_frame.tvec)) + self.Tcw = np.vstack((Tcw, (0, 0, 0, 1))) + Rwc = Rcw.T + twc = -Rcw.T @ curr_frame.tvec + Twc = np.column_stack((Rwc, twc)) + self.Twc = np.vstack((Twc, (0, 0, 0, 1))) + + if curr_frame.gt_qvec is not None: + gt_Rcw = qvec2rotmat(curr_frame.gt_qvec) + gt_Tcw = np.column_stack((gt_Rcw, curr_frame.gt_tvec)) + self.gt_Tcw = np.vstack((gt_Tcw, (0, 0, 0, 1))) + gt_Rwc = gt_Rcw.T + gt_twc = -gt_Rcw.T @ curr_frame.gt_tvec + gt_Twc = np.column_stack((gt_Rwc, gt_twc)) + self.gt_Twc = np.vstack((gt_Twc, (0, 0, 0, 1))) + else: + self.gt_Tcw = None + self.gt_Twc = None + + # update time + self.time_feat = curr_frame.time_feat + self.time_rec = curr_frame.time_rec + self.time_loc = curr_frame.time_loc + self.time_ref = curr_frame.time_ref + + # update image + image_rec_inlier = np.hstack([curr_frame.image_rec, curr_frame.image_inlier]) + image_rec_inlier = resize_image_with_padding(image=image_rec_inlier, nw=self.img_width * 2, nh=self.img_height) + image_matching = resize_image_with_padding(image=curr_frame.image_matching, nw=self.img_width * 2, + nh=self.img_height) + image_rec_matching_inliers = resize_image_with_padding(image=np.vstack([image_rec_inlier, image_matching]), + nw=self.img_width * 2, nh=self.img_height * 2) + + self.image_rec = cv2.cvtColor(image_rec_matching_inliers, cv2.COLOR_BGR2RGB) + lock.release() + + def run(self): + pangolin.CreateWindowAndBind("Map reviewer", 640, 480) + glEnable(GL_DEPTH_TEST) + glEnable(GL_BLEND) + glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) + + pangolin.CreatePanel("menu").SetBounds(pangolin.Attach(0), + pangolin.Attach(1), + pangolin.Attach(0), + # pangolin.Attach.Pix(-175), + pangolin.Attach.Pix(175), + # pangolin.Attach(1) + ) + + menu = pangolin.Var("menu") + menu.Tracking = (False, pangolin.VarMeta(toggle=True)) + menu.FollowCamera = (True, pangolin.VarMeta(toggle=True)) + menu.ShowPoints = (True, pangolin.VarMeta(toggle=True)) + menu.ShowSegs = 
(False, pangolin.VarMeta(toggle=True)) + menu.ShowRefSegs = (True, pangolin.VarMeta(toggle=True)) + menu.ShowRefPoints = (False, pangolin.VarMeta(toggle=True)) + menu.ShowVRFFrame = (True, pangolin.VarMeta(toggle=True)) + menu.ShowAllVRFs = (False, pangolin.VarMeta(toggle=True)) + menu.ShowRefFrames = (False, pangolin.VarMeta(toggle=True)) + + menu.Refinement = (self.refinement, pangolin.VarMeta(toggle=True)) + + menu.featTime = 'NaN' + menu.recTime = 'NaN' + menu.locTime = 'NaN' + menu.refTime = 'NaN' + menu.totalTime = 'NaN' + + pm = pangolin.ProjectionMatrix(640, 480, self.viewpoint_F, self.viewpoint_F, 320, 240, 0.1, + 10000) + + # /camera position,viewpoint position,axis direction + mv = pangolin.ModelViewLookAt(self.viewpoint_x, + self.viewpoint_y, + self.viewpoint_z, + 0, 0, 0, + # 0.0, -1.0, 0.0, + pangolin.AxisZ, + ) + + s_cam = pangolin.OpenGlRenderState(pm, mv) + # Attach bottom, Attach top, Attach left, Attach right, + scale = 0.42 + d_img_rec = pangolin.Display('image_rec').SetBounds(pangolin.Attach(1 - scale), + pangolin.Attach(1), + pangolin.Attach( + 1 - 0.3), + pangolin.Attach(1), + self.img_width / self.img_height + ) # .SetLock(0, 1) + + handler = pangolin.Handler3D(s_cam) + + d_cam = pangolin.Display('3D').SetBounds( + pangolin.Attach(0), # bottom + pangolin.Attach(1), # top + pangolin.Attach.Pix(175), # left + # pangolin.Attach.Pix(0), # left + pangolin.Attach(1), # right + -640 / 480, # aspect + ).SetHandler(handler) + + d_img_rec_texture = pangolin.GlTexture(self.img_width * 2, self.img_height * 2, GL_RGB, False, 0, GL_RGB, + GL_UNSIGNED_BYTE) + while not pangolin.ShouldQuit() and not self.stop: + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) + + # glClearColor(1.0, 1.0, 1.0, 1.0) + glClearColor(0.0, 0.0, 0.0, 1.0) + + d_cam.Activate(s_cam) + if menu.FollowCamera: + s_cam.Follow(pangolin.OpenGlMatrix(self.Twc.astype(np.float32)), follow=True) + + # pangolin.glDrawColouredCube() + if menu.ShowPoints: + self.draw_3d_points_white() + + if menu.ShowRefPoints: + self.draw_ref_3d_points(use_seg_color=False) + if menu.ShowRefSegs: + self.draw_ref_3d_points(use_seg_color=True) + + if menu.ShowSegs: + self.draw_seg_3d_points() + + if menu.ShowAllVRFs: + self.draw_vrf_frames() + + if menu.ShowRefFrames: + self.draw_ref_frames() + + if menu.ShowVRFFrame: + self.draw_current_vrf_frame() + + if menu.Refinement: + self.refinement = True + else: + self.refinement = False + + if menu.Tracking: + self.tracking = True + else: + self.tracking = False + + self.draw_current_frame(Tcw=self.Tcw) + + if self.gt_Tcw is not None: # draw gt pose with color (0, 0, 1.0) + self.draw_current_frame(Tcw=self.gt_Tcw, color=(0., 0., 1.0)) + + d_img_rec.Activate() + glColor4f(1, 1, 1, 1) + + if self.image_rec is not None: + d_img_rec_texture.Upload(self.image_rec, GL_RGB, GL_UNSIGNED_BYTE) + d_img_rec_texture.RenderToViewportFlipY() + + time_total = 0 + if self.time_feat != np.NAN: + menu.featTime = '{:.2f}s'.format(self.time_feat) + time_total = time_total + self.time_feat + if self.time_rec != np.NAN: + menu.recTime = '{:.2f}s'.format(self.time_rec) + time_total = time_total + self.time_rec + if self.time_loc != np.NAN: + menu.locTime = '{:.2f}s'.format(self.time_loc) + time_total = time_total + self.time_loc + if self.time_ref != np.NAN: + menu.refTime = '{:.2f}s'.format(self.time_ref) + time_total = time_total + self.time_ref + menu.totalTime = '{:.2f}s'.format(time_total) + + time.sleep(50 / 1000) + + pangolin.FinishFrame() diff --git a/imcui/third_party/pram/main.py 
b/imcui/third_party/pram/main.py new file mode 100644 index 0000000000000000000000000000000000000000..0f32b1e9087dcf7edd152911cf09bef93f0555d5 --- /dev/null +++ b/imcui/third_party/pram/main.py @@ -0,0 +1,228 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> train +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:26 +==================================================''' +import argparse +import os +import os.path as osp +import torch +import torchvision.transforms.transforms as tvt +import yaml +import torch.utils.data as Data +import torch.multiprocessing as mp +import torch.distributed as dist + +from nets.segnet import SegNet +from nets.segnetvit import SegNetViT +from dataset.utils import collect_batch +from dataset.get_dataset import compose_datasets +from tools.common import torch_set_gpu +from trainer import Trainer + +from nets.sfd2 import ResNet4x, DescriptorCompressor +from nets.superpoint import SuperPoint + +torch.set_grad_enabled(True) + +parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--config', type=str, required=True, help='config of specifications') +parser.add_argument('--landmark_path', type=str, default=None, help='path of landmarks') + + +def load_feat_network(config): + if config['feature'] == 'spp': + net = SuperPoint(config={ + 'weight_path': '/scratches/flyer_2/fx221/Research/Code/third_weights/superpoint_v1.pth', + }).eval() + elif config['feature'] == 'resnet4x': + net = ResNet4x(inputdim=3, outdim=128) + net.load_state_dict( + torch.load('weights/sfd2_20230511_210205_resnet4x.79.pth', map_location='cpu')['state_dict'], + strict=True) + net.eval() + else: + print('Please input correct feature {:s}'.format(config['feature'])) + net = None + + if config['feat_dim'] != 128: + desc_compressor = DescriptorCompressor(inputdim=128, outdim=config['feat_dim']).eval() + if config['feat_dim'] == 64: + desc_compressor.load_state_dict( + torch.load('weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O64.pth', + map_location='cpu'), + strict=True) + elif config['feat_dim'] == 32: + desc_compressor.load_state_dict( + torch.load('weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O32.pth', + map_location='cpu'), + strict=True) + else: + desc_compressor = None + else: + desc_compressor = None + return net, desc_compressor + + +def get_model(config): + desc_dim = 256 if config['feature'] == 'spp' else 128 + if config['use_mid_feature']: + desc_dim = 256 + model_config = { + 'network': { + 'descriptor_dim': desc_dim, + 'n_layers': config['layers'], + 'ac_fn': config['ac_fn'], + 'norm_fn': config['norm_fn'], + 'n_class': config['n_class'], + 'output_dim': config['output_dim'], + 'with_cls': config['with_cls'], + 'with_sc': config['with_sc'], + 'with_score': config['with_score'], + } + } + + if config['network'] == 'segnet': + model = SegNet(model_config.get('network', {})) + config['with_cls'] = False + elif config['network'] == 'segnetvit': + model = SegNetViT(model_config.get('network', {})) + config['with_cls'] = False + else: + raise 'ERROR! 
{:s} model does not exist'.format(config['network']) + + if config['local_rank'] == 0: + if config['weight_path'] is not None: + state_dict = torch.load(osp.join(config['save_path'], config['weight_path']), map_location='cpu')['model'] + model.load_state_dict(state_dict, strict=True) + print('Load weight from {:s}'.format(osp.join(config['save_path'], config['weight_path']))) + + if config['resume_path'] is not None and not config['eval']: # only for training + model.load_state_dict( + torch.load(osp.join(config['save_path'], config['resume_path']), map_location='cpu')['model'], + strict=True) + print('Load resume weight from {:s}'.format(osp.join(config['save_path'], config['resume_path']))) + + return model + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def train_DDP(rank, world_size, model, config, train_set, test_set, feat_model, img_transforms): + print('In train_DDP..., rank: ', rank) + torch.cuda.set_device(rank) + + device = torch.device(f'cuda:{rank}') + if feat_model is not None: + feat_model.to(device) + model.to(device) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + setup(rank=rank, world_size=world_size) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, + shuffle=True, + rank=rank, + num_replicas=world_size, + drop_last=True, # important? + ) + train_loader = torch.utils.data.DataLoader(train_set, + batch_size=config['batch_size'] // world_size, + num_workers=config['workers'] // world_size, + # num_workers=1, + pin_memory=True, + # persistent_workers=True, + shuffle=False, # must be False + drop_last=True, + collate_fn=collect_batch, + prefetch_factor=4, + sampler=train_sampler) + config['local_rank'] = rank + + if rank == 0: + test_set = test_set + else: + test_set = None + + trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_set, + config=config, img_transforms=img_transforms) + trainer.train() + + +if __name__ == '__main__': + args = parser.parse_args() + with open(args.config, 'rt') as f: + config = yaml.load(f, Loader=yaml.Loader) + torch_set_gpu(gpus=config['gpu']) + if config['local_rank'] == 0: + print(config) + + if config['feature'] == 'spp': + img_transforms = None + else: + img_transforms = [] + img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) + img_transforms = tvt.Compose(img_transforms) + feat_model, desc_compressor = load_feat_network(config=config) + + dataset = config['dataset'] + if config['eval'] or config['loc']: + if not config['online']: + from localization.loc_by_rec_eval import loc_by_rec_eval + + test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=1) + config['n_class'] = test_set.n_class + + model = get_model(config=config) + loc_by_rec_eval(rec_model=model.cuda().eval(), + loader=test_set, + local_feat=feat_model.cuda().eval(), + config=config, img_transforms=img_transforms) + else: + from localization.loc_by_rec_online import loc_by_rec_online + + model = get_model(config=config) + loc_by_rec_online(rec_model=model.cuda().eval(), + local_feat=feat_model.cuda().eval(), + config=config, img_transforms=img_transforms) + exit(0) + + train_set = compose_datasets(datasets=dataset, config=config, train=True, sample_ratio=None) + if 
config['do_eval']: + test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=None) + else: + test_set = None + config['n_class'] = train_set.n_class + model = get_model(config=config) + + if not config['with_dist'] or len(config['gpu']) == 1: + config['with_dist'] = False + model = model.cuda() + train_loader = Data.DataLoader(dataset=train_set, + shuffle=True, + batch_size=config['batch_size'], + drop_last=True, + collate_fn=collect_batch, + num_workers=config['workers']) + if test_set is not None: + test_loader = Data.DataLoader(dataset=test_set, + shuffle=False, + batch_size=1, + drop_last=False, + collate_fn=collect_batch, + num_workers=4) + else: + test_loader = None + trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_loader, + config=config, img_transforms=img_transforms) + trainer.train() + else: + mp.spawn(train_DDP, nprocs=len(config['gpu']), + args=(len(config['gpu']), model, config, train_set, test_set, feat_model, img_transforms), + join=True) diff --git a/imcui/third_party/pram/nets/adagml.py b/imcui/third_party/pram/nets/adagml.py new file mode 100644 index 0000000000000000000000000000000000000000..c6980334a8980a105dc91d4586b3a342fb4e648e --- /dev/null +++ b/imcui/third_party/pram/nets/adagml.py @@ -0,0 +1,536 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> adagml +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 11/02/2024 14:29 +==================================================''' +import torch +from torch import nn +import torch.nn.functional as F +from typing import Callable +import time +import numpy as np + +torch.backends.cudnn.deterministic = True + +eps = 1e-8 + + +def arange_like(x, dim: int): + return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 + + +def dual_softmax(M, dustbin): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1) + return torch.exp(score) + + +def sinkhorn(M, r, c, iteration): + p = torch.softmax(M, dim=-1) + u = torch.ones_like(r) + v = torch.ones_like(c) + for _ in range(iteration): + u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps) + v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps) + p = p * u.unsqueeze(-1) * v.unsqueeze(-2) + return p + + +def sink_algorithm(M, dustbin, iteration): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + r = torch.ones([M.shape[0], M.shape[1] - 1], device='cuda') + r = torch.cat([r, torch.ones([M.shape[0], 1], device='cuda') * M.shape[1]], dim=-1) + c = torch.ones([M.shape[0], M.shape[2] - 1], device='cuda') + c = torch.cat([c, torch.ones([M.shape[0], 1], device='cuda') * M.shape[2]], dim=-1) + p = sinkhorn(M, r, c, iteration) + return p + + +def normalize_keypoints(kpts, image_shape): + """ Normalize keypoints locations based on image image_shape""" + _, _, height, width = image_shape + one = kpts.new_tensor(1) + size = torch.stack([one * width, one * height])[None] + center = size / 2 + scaling = size.max(1, keepdim=True).values * 0.7 + return (kpts - center[:, None, :]) / scaling[:, None, :] + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb( + freqs: 
torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, + gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ encode position vector """ + projected = self.Wr(x) + cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + + def __init__(self): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(3, 32), + nn.LayerNorm(32, elementwise_affine=True), + nn.GELU(), + nn.Linear(32, 64), + nn.LayerNorm(64, elementwise_affine=True), + nn.GELU(), + nn.Linear(64, 128), + nn.LayerNorm(128, elementwise_affine=True), + nn.GELU(), + nn.Linear(128, 256), + ) + + def forward(self, kpts, scores): + inputs = [kpts, scores.unsqueeze(2)] # [B, N, 2] + [B, N, 1] + return self.encoder(torch.cat(inputs, dim=-1)) + + +class PoolingLayer(nn.Module): + def __init__(self, hidden_dim: int, score_dim: int = 2): + super().__init__() + + self.score_enc = nn.Sequential( + nn.Linear(score_dim, hidden_dim), + nn.LayerNorm(hidden_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(hidden_dim, hidden_dim), + ) + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.predict = nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), + nn.LayerNorm(hidden_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(hidden_dim, 1), + ) + + def forward(self, x, score): + score_ = self.score_enc(score) + x_ = self.proj(x) + confidence = self.predict(torch.cat([x_, score_], -1)) + confidence = torch.sigmoid(confidence) + + return confidence + + +class Attention(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + s = q.shape[-1] ** -0.5 + attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1) + return torch.einsum('...ij,...jd->...id', attn, v), torch.mean(torch.mean(attn, dim=1), dim=1) + + +class SelfMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + + assert feat_dim % num_heads == 0 + self.head_dim = feat_dim // num_heads + self.qkv = nn.Linear(feat_dim, hidden_dim * 3) + self.attn = Attention() + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim) + ) + + def forward_(self, x, encoding=None): + qkv = self.qkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + if encoding is not None: + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + attn, attn_score = self.attn(q, k, v) + message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2)) + return x + self.mlp(torch.cat([x, message], -1)), attn_score + + def forward(self, x0, x1, encoding0=None, encoding1=None): + x0_, att_score00 = self.forward_(x=x0, encoding=encoding0) + x1_, att_score11 = self.forward_(x=x1, 
encoding=encoding1) + return x0_, x1_, att_score00, att_score11 + + +class CrossMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + assert hidden_dim % num_heads == 0 + dim_head = hidden_dim // num_heads + self.scale = dim_head ** -0.5 + self.to_qk = nn.Linear(feat_dim, hidden_dim) + self.to_v = nn.Linear(feat_dim, hidden_dim) + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim), + ) + + def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): + return func(x0), func(x1) + + def forward(self, x0, x1): + qk0 = self.to_qk(x0) + qk1 = self.to_qk(x1) + v0 = self.to_v(x0) + v1 = self.to_v(x1) + + qk0, qk1, v0, v1 = map( + lambda t: t.unflatten(-1, (self.num_heads, -1)).transpose(1, 2), + (qk0, qk1, v0, v1)) + + qk0, qk1 = qk0 * self.scale ** 0.5, qk1 * self.scale ** 0.5 + sim = torch.einsum('b h i d, b h j d -> b h i j', qk0, qk1) + attn01 = F.softmax(sim, dim=-1) + attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) + m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1) + m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0) + + m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), + m0, m1) + m0, m1 = self.map_(self.proj, m0, m1) + x0 = x0 + self.mlp(torch.cat([x0, m0], -1)) + x1 = x1 + self.mlp(torch.cat([x1, m1], -1)) + return x0, x1, torch.mean(torch.mean(attn10, dim=1), dim=1), torch.mean(torch.mean(attn01, dim=1), dim=1) + + +class AdaGML(nn.Module): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 
9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': True, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + 'min_confidence': 0.9, + + 'classification_background_weight': 0.05, + 'pretrained': True, + } + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + self.n_layers = self.config['n_layers'] + self.first_layer_pooling = 0 + self.n_min_tokens = self.config['n_min_tokens'] + self.min_confidence = self.config['min_confidence'] + self.classification_background_weight = self.config['classification_background_weight'] + + self.with_sinkhorn = self.config['with_sinkhorn'] + self.match_threshold = self.config['match_threshold'] + self.sinkhorn_iterations = self.config['sinkhorn_iterations'] + + self.input_proj = nn.Linear(self.config['descriptor_dim'], self.config['hidden_dim']) + + self.self_attn = nn.ModuleList( + [SelfMultiHeadAttention(feat_dim=self.config['hidden_dim'], + hidden_dim=self.config['hidden_dim'], + num_heads=4) for _ in range(self.n_layers)] + ) + self.cross_attn = nn.ModuleList( + [CrossMultiHeadAttention(feat_dim=self.config['hidden_dim'], + hidden_dim=self.config['hidden_dim'], + num_heads=4) for _ in range(self.n_layers)] + ) + + head_dim = self.config['hidden_dim'] // 4 + self.poseenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim) + self.out_proj = nn.ModuleList( + [nn.Linear(self.config['hidden_dim'], self.config['hidden_dim']) for _ in range(self.n_layers)] + ) + + bin_score = torch.nn.Parameter(torch.tensor(1.)) + self.register_parameter('bin_score', bin_score) + + self.pooling = nn.ModuleList( + [PoolingLayer(score_dim=2, hidden_dim=self.config['hidden_dim']) for _ in range(self.n_layers)] + ) + # self.pretrained = config['pretrained'] + # if self.pretrained: + # bin_score.requires_grad = False + # for m in [self.input_proj, self.out_proj, self.poseenc, self.self_attn, self.cross_attn]: + # for p in m.parameters(): + # p.requires_grad = False + + def forward(self, data, mode=0): + if not self.training: + if mode == 0: + return self.produce_matches(data=data) + else: + return self.run(data=data) + return self.forward_train(data=data) + + def forward_train(self, data: dict, p=0.2, **kwargs): + pass + + def produce_matches(self, data: dict, p: float = 0.2, **kwargs): + desc0, desc1 = data['descriptors0'], data['descriptors1'] + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + scores0, scores1 = data['scores0'], data['scores1'] + + # Keypoint normalization. 
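+        # Explanatory note (comment added in editing): normalize_keypoints()
+        # above shifts keypoints by the image centre and divides by
+        # 0.7 * max(width, height), bringing the coordinates fed to the Fourier
+        # positional encoding to a roughly unit scale. Pre-computed
+        # 'norm_keypoints*' entries are used when present; otherwise the
+        # width/height come from the image tensors or the 'image_shape*' entries.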
+ if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys(): + norm_kpts0 = data['norm_keypoints0'] + norm_kpts1 = data['norm_keypoints1'] + elif 'image0' in data.keys() and 'image1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape) + norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape) + elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0']) + norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1']) + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + desc0 = desc0.detach() # [B, N, D] + desc1 = desc1.detach() + + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + enc0 = self.poseenc(norm_kpts0) + enc1 = self.poseenc(norm_kpts1) + + nI = self.config['n_layers'] + nB = desc0.shape[0] + m = desc0.shape[1] + n = desc1.shape[1] + dev = desc0.device + + ind0 = torch.arange(0, m, device=dev)[None] + ind1 = torch.arange(0, n, device=dev)[None] + + do_pooling = True + + for ni in range(nI): + desc0, desc1, att_score00, att_score11 = self.self_attn[ni](desc0, desc1, enc0, enc1) + desc0, desc1, att_score01, att_score10 = self.cross_attn[ni](desc0, desc1) + + att_score0 = torch.cat([att_score00.unsqueeze(-1), att_score01.unsqueeze(-1)], dim=-1) + att_score1 = torch.cat([att_score11.unsqueeze(-1), att_score10.unsqueeze(-1)], dim=-1) + + conf0 = self.pooling[ni](desc0, att_score0).squeeze(-1) + conf1 = self.pooling[ni](desc1, att_score1).squeeze(-1) + + if do_pooling and ni >= 1: + if desc0.shape[1] >= self.n_min_tokens: + mask0 = conf0 > self.confidence_threshold(layer_index=ni) + ind0 = ind0[mask0][None] + desc0 = desc0[mask0][None] + enc0 = enc0[:, :, mask0][:, None] + + if desc1.shape[1] >= self.n_min_tokens: + mask1 = conf1 > self.confidence_threshold(layer_index=ni) + ind1 = ind1[mask1][None] + desc1 = desc1[mask1][None] + enc1 = enc1[:, :, mask1][:, None] + + # print('pooling: ', ni, desc0.shape, desc1.shape) + # print('ni: {:d}: pooling: {:.4f}'.format(ni, time.time() - t_start)) + # t_start = time.time() + if self.check_if_stop(confidences0=conf0, confidences1=conf1, layer_index=ni, num_points=m + n): + # print('ni:{:d}: checking: {:.4f}'.format(ni, time.time() - t_start)) + break + + if ni == nI: ni = nI - 1 + d = desc0.shape[-1] + mdesc0 = self.out_proj[ni](desc0) / d ** .25 + mdesc1 = self.out_proj[ni](desc1) / d ** .25 + + dist = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1) + score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations) + indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p) + valid = indices0 > -1 + m_indices0 = torch.where(valid)[1] + m_indices1 = indices0[valid] + + mind0 = ind0[0, m_indices0] + mind1 = ind1[0, m_indices1] + + indices0_full = torch.full((nB, m), -1, device=dev, dtype=indices0.dtype) + indices0_full[:, mind0] = mind1 + + mscores0_full = torch.zeros((nB, m), device=dev) + mscores0_full[:, ind0] = mscores0 + + indices0 = indices0_full + mscores0 = mscores0_full + + output = { + 'matches0': indices0, # use -1 for invalid match + # 'matches1': indices1, # use -1 for invalid match + 'matching_scores0': mscores0, + } + + return output + + def run(self, data, p=0.2): + desc0 = data['desc1'] + # print('desc0: ', torch.sum(desc0 ** 2, dim=-1)) + # desc0 = torch.nn.functional.normalize(desc0, dim=-1) + desc0 = desc0.detach() + + desc1 = data['desc2'] + # desc1 = torch.nn.functional.normalize(desc1, dim=-1) + 
desc1 = desc1.detach() + + kpts0 = data['x1'][:, :, :2] + kpts1 = data['x2'][:, :, :2] + # kpts0 = normalize_keypoints(kpts=kpts0, image_shape=data['image_shape1']) + # kpts1 = normalize_keypoints(kpts=kpts1, image_shape=data['image_shape2']) + scores0 = data['x1'][:, :, -1] + scores1 = data['x2'][:, :, -1] + + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + enc0 = self.poseenc(kpts0) + enc1 = self.poseenc(kpts1) + + nB = desc0.shape[0] + nI = self.n_layers + m, n = desc0.shape[1], desc1.shape[1] + dev = desc0.device + ind0 = torch.arange(0, m, device=dev)[None] + ind1 = torch.arange(0, n, device=dev)[None] + do_pooling = True + + for ni in range(nI): + desc0, desc1, att_score00, att_score11 = self.self_attn[ni](desc0, desc1, enc0, enc1) + desc0, desc1, att_score01, att_score10 = self.cross_attn[ni](desc0, desc1) + + att_score0 = torch.cat([att_score00.unsqueeze(-1), att_score01.unsqueeze(-1)], dim=-1) + att_score1 = torch.cat([att_score11.unsqueeze(-1), att_score10.unsqueeze(-1)], dim=-1) + + conf0 = self.pooling[ni](desc0, att_score0).squeeze(-1) + conf1 = self.pooling[ni](desc1, att_score1).squeeze(-1) + + if do_pooling and ni >= 1: + if desc0.shape[1] >= self.n_min_tokens: + mask0 = conf0 > self.confidence_threshold(layer_index=ni) + ind0 = ind0[mask0][None] + desc0 = desc0[mask0][None] + enc0 = enc0[:, :, mask0][:, None] + + if desc1.shape[1] >= self.n_min_tokens: + mask1 = conf1 > self.confidence_threshold(layer_index=ni) + ind1 = ind1[mask1][None] + desc1 = desc1[mask1][None] + enc1 = enc1[:, :, mask1][:, None] + if desc0.shape[1] <= 5 or desc1.shape[1] <= 5: + return { + 'index0': torch.zeros(size=(1,), device=desc0.device).long(), + 'index1': torch.zeros(size=(1,), device=desc1.device).long(), + } + + if self.check_if_stop(confidences0=conf0, confidences1=conf1, layer_index=ni, + num_points=m + n): + break + + if ni == nI: ni = -1 + d = desc0.shape[-1] + mdesc0 = self.out_proj[ni](desc0) / d ** .25 + mdesc1 = self.out_proj[ni](desc1) / d ** .25 + + dist = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1) + score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations) + indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p) + valid = indices0 > -1 + m_indices0 = torch.where(valid)[1] + m_indices1 = indices0[valid] + + mind0 = ind0[0, m_indices0] + mind1 = ind1[0, m_indices1] + + output = { + # 'p': score, + 'index0': mind0, + 'index1': mind1, + } + + return output + + def compute_score(self, dist, dustbin, iteration): + if self.with_sinkhorn: + score = sink_algorithm(M=dist, dustbin=dustbin, + iteration=iteration) # [nI * nB, N, M] + else: + score = dual_softmax(M=dist, dustbin=dustbin) + return score + + def compute_matches(self, scores, p=0.2): + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + indices0, indices1 = max0.indices, max1.indices + mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) + mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) + zero = scores.new_tensor(0) + # mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores0 = torch.where(mutual0, max0.values, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) + # valid0 = mutual0 & (mscores0 > self.config['match_threshold']) + valid0 = mutual0 & (mscores0 > p) + valid1 = mutual1 & valid0.gather(1, indices1) + indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) + indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + + 
return indices0, indices1, mscores0, mscores1 + + def confidence_threshold(self, layer_index: int): + """scaled confidence threshold""" + # threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers) + threshold = 0.5 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers) + return np.clip(threshold, 0, 1) + + def check_if_stop(self, + confidences0: torch.Tensor, + confidences1: torch.Tensor, + layer_index: int, num_points: int) -> torch.Tensor: + """ evaluate stopping condition""" + confidences = torch.cat([confidences0, confidences1], -1) + threshold = self.confidence_threshold(layer_index) + pos = 1.0 - (confidences < threshold).float().sum() / num_points + # print('check_stop: ', pos) + return pos > 0.95 + + def stop_iteration(self, m_last, n_last, m_current, n_current, confidence=0.975): + prob = (m_current + n_current) / (m_last + n_last) + # print('prob: ', prob) + return prob > confidence diff --git a/imcui/third_party/pram/nets/gm.py b/imcui/third_party/pram/nets/gm.py new file mode 100644 index 0000000000000000000000000000000000000000..232a364ce60acb49cb6af26b72a881cbec18c1a9 --- /dev/null +++ b/imcui/third_party/pram/nets/gm.py @@ -0,0 +1,264 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> gm +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 10:47 +==================================================''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from nets.layers import KeypointEncoder, AttentionalPropagation +from nets.utils import normalize_keypoints, arange_like + +eps = 1e-8 + + +def dual_softmax(M, dustbin): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1) + return torch.exp(score) + + +def sinkhorn(M, r, c, iteration): + p = torch.softmax(M, dim=-1) + u = torch.ones_like(r) + v = torch.ones_like(c) + for _ in range(iteration): + u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps) + v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps) + p = p * u.unsqueeze(-1) * v.unsqueeze(-2) + return p + + +def sink_algorithm(M, dustbin, iteration): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + r = torch.ones([M.shape[0], M.shape[1] - 1], device='cuda') + r = torch.cat([r, torch.ones([M.shape[0], 1], device='cuda') * M.shape[1]], dim=-1) + c = torch.ones([M.shape[0], M.shape[2] - 1], device='cuda') + c = torch.cat([c, torch.ones([M.shape[0], 1], device='cuda') * M.shape[2]], dim=-1) + p = sinkhorn(M, r, c, iteration) + return p + + +class AttentionalGNN(nn.Module): + def __init__(self, feature_dim: int, layer_names: list, hidden_dim: int = 256, ac_fn: str = 'relu', + norm_fn: str = 'bn'): + super().__init__() + self.layers = nn.ModuleList([ + AttentionalPropagation(feature_dim=feature_dim, num_heads=4, hidden_dim=hidden_dim, ac_fn=ac_fn, + norm_fn=norm_fn) + for _ in range(len(layer_names))]) + self.names = layer_names + + def forward(self, desc0, desc1): + # desc0s = [] + # desc1s = [] + + for i, (layer, name) in enumerate(zip(self.layers, self.names)): + if name == 'cross': + src0, src1 = desc1, desc0 + else: + src0, src1 = desc0, desc1 + delta0 = layer(desc0, src0) + # prob0 = layer.attn.prob + delta1 = layer(desc1, src1) + # prob1 = layer.attn.prob + desc0, desc1 = (desc0 + delta0), (desc1 + delta1) + + # if name == 'cross': + # 
desc0s.append(desc0) + # desc1s.append(desc1) + return [desc0], [desc1] + + def predict(self, desc0, desc1, n_it=-1): + for i, (layer, name) in enumerate(zip(self.layers, self.names)): + if name == 'cross': + src0, src1 = desc1, desc0 + else: + src0, src1 = desc0, desc1 + delta0 = layer(desc0, src0) + # prob0 = layer.attn.prob + delta1 = layer(desc1, src1) + # prob1 = layer.attn.prob + desc0, desc1 = (desc0 + delta0), (desc1 + delta1) + + if name == 'cross' and i == n_it: + break + return [desc0], [desc1] + + +class GM(nn.Module): + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + + 'ac_fn': 'relu', + 'norm_fn': 'bn', + 'weight_path': None, + } + + required_inputs = [ + 'image0', 'keypoints0', 'scores0', 'descriptors0', + 'image1', 'keypoints1', 'scores1', 'descriptors1', + ] + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + print('gm: ', self.config) + + self.n_layers = self.config['n_layers'] + + self.with_sinkhorn = self.config['with_sinkhorn'] + self.match_threshold = self.config['match_threshold'] + + self.sinkhorn_iterations = self.config['sinkhorn_iterations'] + self.kenc = KeypointEncoder( + self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, + self.config['keypoint_encoder'], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn']) + self.gnn = AttentionalGNN( + feature_dim=self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, + hidden_dim=self.config['hidden_dim'], + layer_names=self.config['GNN_layers'], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'], + ) + + self.final_proj = nn.ModuleList([nn.Conv1d( + self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, + self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, + kernel_size=1, bias=True) for _ in range(self.n_layers)]) + + bin_score = torch.nn.Parameter(torch.tensor(1.)) + self.register_parameter('bin_score', bin_score) + + self.match_net = None # GraphLoss(config=self.config) + + self.self_prob0 = None + self.self_prob1 = None + self.cross_prob0 = None + self.cross_prob1 = None + + self.desc_compressor = None + + def forward_train(self, data): + pass + + def produce_matches(self, data, p=0.2, n_it=-1, **kwargs): + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + scores0, scores1 = data['scores0'], data['scores1'] + if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints + shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] + return { + 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0], + 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0], + 'matching_scores0': kpts0.new_zeros(shape0)[0], + 'matching_scores1': kpts1.new_zeros(shape1)[0], + 'skip_train': True + } + + if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys(): + norm_kpts0 = data['norm_keypoints0'] + norm_kpts1 = data['norm_keypoints1'] + elif 'image0' in data.keys() and 'image1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape) + norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape) + elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0']) + norm_kpts1 = 
normalize_keypoints(kpts1, data['image_shape1']) + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + # Keypoint MLP encoder. + enc0, enc1 = self.encode_keypoint(norm_kpts0=norm_kpts0, norm_kpts1=norm_kpts1, scores0=scores0, + scores1=scores1) + + if self.config['descriptor_dim'] > 0: + desc0, desc1 = data['descriptors0'], data['descriptors1'] + desc0 = desc0.permute(0, 2, 1) # [B, N, D] -> [B, D, N] + desc1 = desc1.permute(0, 2, 1) # [B, N, D] -> [B, D, N] + with torch.no_grad(): + if desc0.shape[1] != self.config['descriptor_dim']: + desc0 = self.desc_compressor(desc0) + if desc1.shape[1] != self.config['descriptor_dim']: + desc1 = self.desc_compressor(desc1) + desc0 = desc0 + enc0 + desc1 = desc1 + enc1 + else: + desc0 = enc0 + desc1 = enc1 + + desc0s, desc1s = self.gnn.predict(desc0, desc1, n_it=n_it) + + mdescs0 = self.final_proj[n_it](desc0s[-1]) + mdescs1 = self.final_proj[n_it](desc1s[-1]) + dist = torch.einsum('bdn,bdm->bnm', mdescs0, mdescs1) + if self.config['descriptor_dim'] > 0: + dist = dist / self.config['descriptor_dim'] ** .5 + else: + dist = dist / 128 ** .5 + score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations) + + indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p) + + output = { + 'matches0': indices0, # use -1 for invalid match + 'matches1': indices1, # use -1 for invalid match + 'matching_scores0': mscores0, + 'matching_scores1': mscores1, + } + + return output + + def forward(self, data, mode=0): + if not self.training: + return self.produce_matches(data=data, n_it=-1) + return self.forward_train(data=data) + + def encode_keypoint(self, norm_kpts0, norm_kpts1, scores0, scores1): + return self.kenc(norm_kpts0, scores0), self.kenc(norm_kpts1, scores1) + + def compute_distance(self, desc0, desc1, layer_id=-1): + mdesc0 = self.final_proj[layer_id](desc0) + mdesc1 = self.final_proj[layer_id](desc1) + dist = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) + dist = dist / self.config['descriptor_dim'] ** .5 + return dist + + def compute_score(self, dist, dustbin, iteration): + if self.with_sinkhorn: + score = sink_algorithm(M=dist, dustbin=dustbin, + iteration=iteration) # [nI * nB, N, M] + else: + score = dual_softmax(M=dist, dustbin=dustbin) + return score + + def compute_matches(self, scores, p=0.2): + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + indices0, indices1 = max0.indices, max1.indices + mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) + mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) + zero = scores.new_tensor(0) + # mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores0 = torch.where(mutual0, max0.values, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) + # valid0 = mutual0 & (mscores0 > self.config['match_threshold']) + valid0 = mutual0 & (mscores0 > p) + valid1 = mutual1 & valid0.gather(1, indices1) + indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) + indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + + return indices0, indices1, mscores0, mscores1 diff --git a/imcui/third_party/pram/nets/gml.py b/imcui/third_party/pram/nets/gml.py new file mode 100644 index 0000000000000000000000000000000000000000..996de5f01211e0a315f7f9b4ce35d561dfc74b2f --- /dev/null +++ b/imcui/third_party/pram/nets/gml.py @@ -0,0 +1,319 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project 
-> File pram -> gml +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 10:56 +==================================================''' +import torch +from torch import nn +import torch.nn.functional as F +from typing import Callable +from .utils import arange_like, normalize_keypoints +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.backends.cudnn.deterministic = True + +eps = 1e-8 + + +def dual_softmax(M, dustbin): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1) + return torch.exp(score) + + +def sinkhorn(M, r, c, iteration): + p = torch.softmax(M, dim=-1) + u = torch.ones_like(r) + v = torch.ones_like(c) + for _ in range(iteration): + u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps) + v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps) + p = p * u.unsqueeze(-1) * v.unsqueeze(-2) + return p + + +def sink_algorithm(M, dustbin, iteration): + M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) + M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) + r = torch.ones([M.shape[0], M.shape[1] - 1], device=device) + r = torch.cat([r, torch.ones([M.shape[0], 1], device=device) * M.shape[1]], dim=-1) + c = torch.ones([M.shape[0], M.shape[2] - 1], device=device) + c = torch.cat([c, torch.ones([M.shape[0], 1], device=device) * M.shape[2]], dim=-1) + p = sinkhorn(M, r, c, iteration) + return p + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb( + freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, + gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ encode position vector """ + projected = self.Wr(x) + cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + + def __init__(self): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(3, 32), + nn.LayerNorm(32, elementwise_affine=True), + nn.GELU(), + nn.Linear(32, 64), + nn.LayerNorm(64, elementwise_affine=True), + nn.GELU(), + nn.Linear(64, 128), + nn.LayerNorm(128, elementwise_affine=True), + nn.GELU(), + nn.Linear(128, 256), + ) + + def forward(self, kpts, scores): + inputs = [kpts, scores.unsqueeze(2)] # [B, N, 2] + [B, N, 1] + return self.encoder(torch.cat(inputs, dim=-1)) + + +class Attention(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + s = q.shape[-1] ** -0.5 + attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1) + return torch.einsum('...ij,...jd->...id', attn, v) + + +class SelfMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + + assert feat_dim % 
num_heads == 0 + self.head_dim = feat_dim // num_heads + self.qkv = nn.Linear(feat_dim, hidden_dim * 3) + self.attn = Attention() + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim) + ) + + def forward_(self, x, encoding=None): + qkv = self.qkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + if encoding is not None: + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + attn = self.attn(q, k, v) + message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2)) + return x + self.mlp(torch.cat([x, message], -1)) + + def forward(self, x0, x1, encoding0=None, encoding1=None): + return self.forward_(x0, encoding0), self.forward_(x1, encoding1) + + +class CrossMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + assert hidden_dim % num_heads == 0 + dim_head = hidden_dim // num_heads + self.scale = dim_head ** -0.5 + self.to_qk = nn.Linear(feat_dim, hidden_dim) + self.to_v = nn.Linear(feat_dim, hidden_dim) + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim), + ) + + def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): + return func(x0), func(x1) + + def forward(self, x0, x1): + qk0 = self.to_qk(x0) + qk1 = self.to_qk(x1) + v0 = self.to_v(x0) + v1 = self.to_v(x1) + + qk0, qk1, v0, v1 = map( + lambda t: t.unflatten(-1, (self.num_heads, -1)).transpose(1, 2), + (qk0, qk1, v0, v1)) + + qk0, qk1 = qk0 * self.scale ** 0.5, qk1 * self.scale ** 0.5 + sim = torch.einsum('b h i d, b h j d -> b h i j', qk0, qk1) + attn01 = F.softmax(sim, dim=-1) + attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) + m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1) + m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0) + + m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), + m0, m1) + m0, m1 = self.map_(self.proj, m0, m1) + x0 = x0 + self.mlp(torch.cat([x0, m0], -1)) + x1 = x1 + self.mlp(torch.cat([x1, m1], -1)) + return x0, x1 + + +class GML(nn.Module): + ''' + the architecture of lightglue, but trained with imp + ''' + default_config = { + 'descriptor_dim': 128, + 'hidden_dim': 256, + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], + 'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 
9 in total + 'sinkhorn_iterations': 20, + 'match_threshold': 0.2, + 'with_pose': False, + 'n_layers': 9, + 'n_min_tokens': 256, + 'with_sinkhorn': True, + + 'ac_fn': 'relu', + 'norm_fn': 'bn', + + } + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + self.n_layers = self.config['n_layers'] + + self.with_sinkhorn = self.config['with_sinkhorn'] + self.match_threshold = self.config['match_threshold'] + self.sinkhorn_iterations = self.config['sinkhorn_iterations'] + + self.input_proj = nn.Linear(self.config['descriptor_dim'], self.config['hidden_dim']) + + self.self_attn = nn.ModuleList( + [SelfMultiHeadAttention(feat_dim=self.config['hidden_dim'], + hidden_dim=self.config['hidden_dim'], + num_heads=4) for _ in range(self.n_layers)] + ) + self.cross_attn = nn.ModuleList( + [CrossMultiHeadAttention(feat_dim=self.config['hidden_dim'], + hidden_dim=self.config['hidden_dim'], + num_heads=4) for _ in range(self.n_layers)] + ) + + head_dim = self.config['hidden_dim'] // 4 + self.poseenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim) + self.out_proj = nn.ModuleList( + [nn.Linear(self.config['hidden_dim'], self.config['hidden_dim']) for _ in range(self.n_layers)] + ) + + bin_score = torch.nn.Parameter(torch.tensor(1.)) + self.register_parameter('bin_score', bin_score) + + def forward(self, data, mode=0): + if not self.training: + return self.produce_matches(data=data) + return self.forward_train(data=data) + + def forward_train(self, data: dict, p=0.2, **kwargs): + pass + + def produce_matches(self, data: dict, p=0.2, **kwargs): + desc0, desc1 = data['descriptors0'], data['descriptors1'] + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + # Keypoint normalization. + if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys(): + norm_kpts0 = data['norm_keypoints0'] + norm_kpts1 = data['norm_keypoints1'] + elif 'image0' in data.keys() and 'image1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape).float() + norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape).float() + elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys(): + norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0']).float() + norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1']).float() + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + enc0 = self.poseenc(norm_kpts0) + enc1 = self.poseenc(norm_kpts1) + + nI = self.n_layers + # nI = 5 + + for i in range(nI): + desc0, desc1 = self.self_attn[i](desc0, desc1, enc0, enc1) + desc0, desc1 = self.cross_attn[i](desc0, desc1) + + d = desc0.shape[-1] + mdesc0 = self.out_proj[nI - 1](desc0) / d ** .25 + mdesc1 = self.out_proj[nI - 1](desc1) / d ** .25 + + dist = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1) + + score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations) + indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p) + + output = { + 'matches0': indices0, # use -1 for invalid match + 'matches1': indices1, # use -1 for invalid match + 'matching_scores0': mscores0, + 'matching_scores1': mscores1, + } + + return output + + def compute_score(self, dist, dustbin, iteration): + if self.with_sinkhorn: + score = sink_algorithm(M=dist, dustbin=dustbin, + iteration=iteration) # [nI * nB, N, M] + else: + score = dual_softmax(M=dist, dustbin=dustbin) + return score + + def 
compute_matches(self, scores, p=0.2): + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + indices0, indices1 = max0.indices, max1.indices + mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) + mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) + zero = scores.new_tensor(0) + # mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores0 = torch.where(mutual0, max0.values, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) + # valid0 = mutual0 & (mscores0 > self.config['match_threshold']) + valid0 = mutual0 & (mscores0 > p) + valid1 = mutual1 & valid0.gather(1, indices1) + indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) + indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + + return indices0, indices1, mscores0, mscores1 diff --git a/imcui/third_party/pram/nets/layers.py b/imcui/third_party/pram/nets/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..417488e6a163327895eb435567c4255c7827bca2 --- /dev/null +++ b/imcui/third_party/pram/nets/layers.py @@ -0,0 +1,109 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> layers +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:46 +==================================================''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +from einops import rearrange + + +def MLP(channels: list, do_bn=True, ac_fn='relu', norm_fn='bn'): + """ Multi-layer perceptron """ + n = len(channels) + layers = [] + for i in range(1, n): + layers.append( + nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) + if i < (n - 1): + if norm_fn == 'in': + layers.append(nn.InstanceNorm1d(channels[i], eps=1e-3)) + elif norm_fn == 'bn': + layers.append(nn.BatchNorm1d(channels[i], eps=1e-3)) + if ac_fn == 'relu': + layers.append(nn.ReLU()) + elif ac_fn == 'gelu': + layers.append(nn.GELU()) + elif ac_fn == 'lrelu': + layers.append(nn.LeakyReLU(negative_slope=0.1)) + # if norm_fn == 'ln': + # layers.append(nn.LayerNorm(channels[i])) + return nn.Sequential(*layers) + + +class MultiHeadedAttention(nn.Module): + def __init__(self, num_heads: int, d_model: int): + super().__init__() + assert d_model % num_heads == 0 + self.dim = d_model // num_heads + self.num_heads = num_heads + self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) + self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) + + def forward(self, query, key, value, M=None): + ''' + :param query: [B, D, N] + :param key: [B, D, M] + :param value: [B, D, M] + :param M: [B, N, M] + :return: + ''' + + batch_dim = query.size(0) + query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) + for l, x in zip(self.proj, (query, key, value))] # [B, D, NH, N] + dim = query.shape[1] + scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5 + + if M is not None: + # print('M: ', scores.shape, M.shape, torch.sum(M, dim=2)) + # scores = scores * M[:, None, :, :].expand_as(scores) + # with torch.no_grad(): + mask = (1 - M[:, None, :, :]).repeat(1, scores.shape[1], 1, 1).bool() # [B, H, N, M] + scores = scores.masked_fill(mask, -torch.finfo(scores.dtype).max) + prob = F.softmax(scores, dim=-1) # * (~mask).float() # * mask.float() + else: + prob = F.softmax(scores, dim=-1) + + x = torch.einsum('bhnm,bdhm->bdhn', prob, value) + self.prob = prob + + out = self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, 
-1)) + + return out + + +class AttentionalPropagation(nn.Module): + def __init__(self, feature_dim: int, num_heads: int, hidden_dim: int = None, ac_fn='relu', norm_fn='bn'): + # hidden_dim is accepted for compatibility with callers that pass it (e.g. gm.AttentionalGNN); the MLP width is derived from feature_dim. + super().__init__() + self.attn = MultiHeadedAttention(num_heads, feature_dim) + self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim], ac_fn=ac_fn, norm_fn=norm_fn) + nn.init.constant_(self.mlp[-1].bias, 0.0) + + def forward(self, x, source, M=None): + message = self.attn(x, source, source, M=M) + self.prob = self.attn.prob + + out = self.mlp(torch.cat([x, message], dim=1)) + return out + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + + def __init__(self, input_dim, feature_dim, layers, ac_fn='relu', norm_fn='bn'): + super().__init__() + self.input_dim = input_dim + self.encoder = MLP([input_dim] + layers + [feature_dim], ac_fn=ac_fn, norm_fn=norm_fn) + nn.init.constant_(self.encoder[-1].bias, 0.0) + + def forward(self, kpts, scores=None): + if self.input_dim == 2: + return self.encoder(kpts.transpose(1, 2)) + else: + inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] # [B, 2, N] + [B, 1, N] + return self.encoder(torch.cat(inputs, dim=1)) diff --git a/imcui/third_party/pram/nets/load_segnet.py b/imcui/third_party/pram/nets/load_segnet.py new file mode 100644 index 0000000000000000000000000000000000000000..51b8c5bc3fc1c25a8e52dd21cc6f3f4e79b418aa --- /dev/null +++ b/imcui/third_party/pram/nets/load_segnet.py @@ -0,0 +1,31 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> load_segnet +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 09/04/2024 15:39 +==================================================''' +from nets.segnet import SegNet +from nets.segnetvit import SegNetViT + + +def load_segnet(network, n_class, desc_dim, n_layers, output_dim): + model_config = { + 'network': { + 'descriptor_dim': desc_dim, + 'n_layers': n_layers, + 'n_class': n_class, + 'output_dim': output_dim, + 'with_score': False, + } + } + + if network == 'segnet': + model = SegNet(model_config.get('network', {})) + # config['with_cls'] = False + elif network == 'segnetvit': + model = SegNetViT(model_config.get('network', {})) + else: + raise ValueError('ERROR! 
{:s} model does not exist'.format(network)) + + return model diff --git a/imcui/third_party/pram/nets/retnet.py b/imcui/third_party/pram/nets/retnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f3346fcd82193683ec72d0e55a2429d18a974b --- /dev/null +++ b/imcui/third_party/pram/nets/retnet.py @@ -0,0 +1,174 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> retnet +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 22/02/2024 15:23 +==================================================''' +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File glretrieve -> retnet +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 15/02/2024 10:55 +==================================================''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class ResBlock(nn.Module): + def __init__(self, inplanes, outplanes, stride=1, groups=32, dilation=1, norm_layer=None, ac_fn=None): + super(ResBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = conv1x1(inplanes, outplanes) + self.bn1 = norm_layer(outplanes) + self.conv2 = conv3x3(outplanes, outplanes, stride, groups, dilation) + self.bn2 = norm_layer(outplanes) + self.conv3 = conv1x1(outplanes, outplanes) + self.bn3 = norm_layer(outplanes) + if ac_fn is None: + self.ac_fn = nn.ReLU(inplace=True) + else: + self.ac_fn = ac_fn + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.ac_fn(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.ac_fn(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += identity + out = self.ac_fn(out) + + return out + + +class GeneralizedMeanPooling(nn.Module): + r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. + The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` + - At p = infinity, one gets Max Pooling + - At p = 1, one gets Average Pooling + The output is of size H x W, for any input size. + The number of output features is equal to the number of input planes. + Args: + output_size: the target output size of the image of the form H x W. + Can be a tuple (H, W) or a single H for a square image H x H + H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + """ + + def __init__(self, norm, output_size=1, eps=1e-6): + super(GeneralizedMeanPooling, self).__init__() + assert norm > 0 + self.p = float(norm) + self.output_size = output_size + self.eps = eps + + def forward(self, x): + x = x.clamp(min=self.eps).pow(self.p) + return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. 
/ self.p) + + def __repr__(self): + return self.__class__.__name__ + '(' \ + + str(self.p) + ', ' \ + + 'output_size=' + str(self.output_size) + ')' + + +class GeneralizedMeanPoolingP(GeneralizedMeanPooling): + """ Same, but norm is trainable + """ + + def __init__(self, norm=3, output_size=1, eps=1e-6): + super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) + self.p = nn.Parameter(torch.ones(1) * norm) + + +class Flatten(nn.Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +class L2Norm(nn.Module): + def __init__(self, dim=1): + super().__init__() + self.dim = dim + + def forward(self, input): + return F.normalize(input, p=2, dim=self.dim) + + +class RetNet(nn.Module): + def __init__(self, indim=256, outdim=1024): + super().__init__() + + ac_fn = nn.GELU() + + self.convs = nn.Sequential( + # no batch normalization + + nn.Conv2d(in_channels=indim, out_channels=512, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(512), + # nn.ReLU(), + + ResBlock(512, 512, groups=32, stride=1, ac_fn=ac_fn), + ResBlock(512, 512, groups=32, stride=1, ac_fn=ac_fn), + + nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(1024), + # nn.ReLU(), + ResBlock(inplanes=1024, outplanes=1024, groups=32, stride=1, ac_fn=ac_fn), + ResBlock(inplanes=1024, outplanes=1024, groups=32, stride=1, ac_fn=ac_fn), + ) + + self.pool = GeneralizedMeanPoolingP() + self.fc = nn.Linear(1024, out_features=outdim) + + def initialize(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + out = self.convs(x) + out = self.pool(out).reshape(x.shape[0], -1) + out = self.fc(out) + out = F.normalize(out, p=2, dim=1) + return out + + +if __name__ == '__main__': + mode = RetNet(indim=256, outdim=1024) + state_dict = mode.state_dict() + keys = state_dict.keys() + print(keys) + shapes = [state_dict[v].shape for v in keys] + print(shapes) diff --git a/imcui/third_party/pram/nets/segnet.py b/imcui/third_party/pram/nets/segnet.py new file mode 100644 index 0000000000000000000000000000000000000000..632a38cb83ca77a23b5c1e1276996bd5574c3a0b --- /dev/null +++ b/imcui/third_party/pram/nets/segnet.py @@ -0,0 +1,120 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> segnet +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:46 +==================================================''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from nets.layers import MLP, KeypointEncoder +from nets.layers import AttentionalPropagation +from nets.utils import normalize_keypoints + + +class SegGNN(nn.Module): + def __init__(self, feature_dim: int, n_layers: int, ac_fn: str = 'relu', norm_fn: str = 'bn', **kwargs): + super().__init__() + self.layers = nn.ModuleList([ + AttentionalPropagation(feature_dim, 4, ac_fn=ac_fn, norm_fn=norm_fn) + for _ in range(n_layers) + ]) + + def forward(self, desc): + for i, layer in enumerate(self.layers): + delta = layer(desc, desc) + desc = desc + delta + + return desc + + +class SegNet(nn.Module): + default_config = { + 'descriptor_dim': 256, + 'output_dim': 1024, + 'n_class': 512, + 
'keypoint_encoder': [32, 64, 128, 256], + 'n_layers': 9, + 'ac_fn': 'relu', + 'norm_fn': 'in', + 'with_score': False, + # 'with_global': False, + 'with_cls': False, + 'with_sc': False, + } + + def __init__(self, config={}): + super().__init__() + self.config = {**self.default_config, **config} + self.with_cls = self.config['with_cls'] + self.with_sc = self.config['with_sc'] + + self.n_layers = self.config['n_layers'] + self.gnn = SegGNN( + feature_dim=self.config['descriptor_dim'], + n_layers=self.config['n_layers'], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'], + ) + + self.with_score = self.config['with_score'] + self.kenc = KeypointEncoder( + input_dim=3 if self.with_score else 2, + feature_dim=self.config['descriptor_dim'], + layers=self.config['keypoint_encoder'], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'] + ) + + self.seg = MLP(channels=[self.config['descriptor_dim'], + self.config['output_dim'], + self.config['n_class']], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'] + ) + + if self.with_sc: + self.sc = MLP(channels=[self.config['descriptor_dim'], + self.config['output_dim'], + 3], + ac_fn=self.config['ac_fn'], + norm_fn=self.config['norm_fn'] + ) + + def preprocess(self, data): + desc0 = data['seg_descriptors'] + desc0 = desc0.transpose(1, 2) # [B, N, D] - > [B, D, N] + + if 'norm_keypoints' in data.keys(): + norm_kpts0 = data['norm_keypoints'] + elif 'image' in data.keys(): + kpts0 = data['keypoints'] + norm_kpts0 = normalize_keypoints(kpts0, data['image'].shape) + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + # Keypoint MLP encoder. + if self.with_score: + scores0 = data['scores'] + else: + scores0 = None + enc0 = self.kenc(norm_kpts0, scores0) + + return desc0, enc0 + + def forward(self, data): + desc, enc = self.preprocess(data=data) + desc = desc + enc + + desc = self.gnn(desc) + cls_output = self.seg(desc) # [B, C, N] + output = { + 'prediction': cls_output.transpose(-1, -2).contiguous(), + } + + if self.with_sc: + sc_output = self.sc(desc) + output['sc'] = sc_output + + return output diff --git a/imcui/third_party/pram/nets/segnetvit.py b/imcui/third_party/pram/nets/segnetvit.py new file mode 100644 index 0000000000000000000000000000000000000000..7919b545c26d3098df84d2e8e909d7ed69809dcd --- /dev/null +++ b/imcui/third_party/pram/nets/segnetvit.py @@ -0,0 +1,203 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> segnetvit +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 14:52 +==================================================''' + +import torch +from torch import nn +import torch.nn.functional as F +from nets.utils import normalize_keypoints + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb( + freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, + gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ encode position vector """ + projected = self.Wr(x) 
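+ # Build Fourier features: cosine/sine of the learned projection, stacked and repeated to match the rotary-embedding layout consumed by apply_cached_rotary_emb.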
+ cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + + def __init__(self): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(2, 32), + nn.LayerNorm(32, elementwise_affine=True), + nn.GELU(), + nn.Linear(32, 64), + nn.LayerNorm(64, elementwise_affine=True), + nn.GELU(), + nn.Linear(64, 128), + nn.LayerNorm(128, elementwise_affine=True), + nn.GELU(), + nn.Linear(128, 256), + ) + + def forward(self, kpts, scores=None): + if scores is not None: + inputs = [kpts, scores.unsqueeze(2)] # [B, N, 2] + [B, N, 1] + return self.encoder(torch.cat(inputs, dim=-1)) + else: + return self.encoder(kpts) + + +class Attention(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + s = q.shape[-1] ** -0.5 + attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1) + return torch.einsum('...ij,...jd->...id', attn, v) + + +class SelfMultiHeadAttention(nn.Module): + def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + + assert feat_dim % num_heads == 0 + self.head_dim = feat_dim // num_heads + self.qkv = nn.Linear(feat_dim, hidden_dim * 3) + self.attn = Attention() + self.proj = nn.Linear(hidden_dim, hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(feat_dim + hidden_dim, feat_dim * 2), + nn.LayerNorm(feat_dim * 2, elementwise_affine=True), + nn.GELU(), + nn.Linear(feat_dim * 2, feat_dim) + ) + + def forward(self, x, encoding=None): + qkv = self.qkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + if encoding is not None: + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + attn = self.attn(q, k, v) + message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2)) + return x + self.mlp(torch.cat([x, message], -1)) + + +class SegGNNViT(nn.Module): + def __init__(self, feature_dim: int, n_layers: int, hidden_dim: int = 256, num_heads: int = 4, **kwargs): + super(SegGNNViT, self).__init__() + self.layers = nn.ModuleList([ + SelfMultiHeadAttention(feat_dim=feature_dim, hidden_dim=hidden_dim, num_heads=num_heads) + for _ in range(n_layers) + ]) + + def forward(self, desc, encoding=None): + for i, layer in enumerate(self.layers): + desc = layer(desc, encoding) + # desc = desc + delta // should be removed as this is already done in self-attention + return desc + + +class SegNetViT(nn.Module): + default_config = { + 'descriptor_dim': 256, + 'output_dim': 1024, + 'n_class': 512, + 'keypoint_encoder': [32, 64, 128, 256], + 'n_layers': 15, + 'num_heads': 4, + 'hidden_dim': 256, + 'with_score': False, + 'with_global': False, + 'with_cls': False, + 'with_sc': False, + } + + def __init__(self, config={}): + super(SegNetViT, self).__init__() + self.config = {**self.default_config, **config} + self.with_cls = self.config['with_cls'] + self.with_sc = self.config['with_sc'] + + self.n_layers = self.config['n_layers'] + self.gnn = SegGNNViT( + feature_dim=self.config['hidden_dim'], + n_layers=self.config['n_layers'], + hidden_dim=self.config['hidden_dim'], + num_heads=self.config['num_heads'], + ) + + self.with_score = self.config['with_score'] + self.kenc = LearnableFourierPositionalEncoding(2, self.config['hidden_dim'] // self.config['num_heads'], + 
self.config['hidden_dim'] // self.config['num_heads']) + + self.input_proj = nn.Linear(in_features=self.config['descriptor_dim'], + out_features=self.config['hidden_dim']) + self.seg = nn.Sequential( + nn.Linear(in_features=self.config['hidden_dim'], out_features=self.config['output_dim']), + nn.LayerNorm(self.config['output_dim'], elementwise_affine=True), + nn.GELU(), + nn.Linear(self.config['output_dim'], self.config['n_class']) + ) + + if self.with_sc: + self.sc = nn.Sequential( + nn.Linear(in_features=config['hidden_dim'], out_features=self.config['output_dim']), + nn.LayerNorm(self.config['output_dim'], elementwise_affine=True), + nn.GELU(), + nn.Linear(self.config['output_dim'], 3) + ) + + def preprocess(self, data): + desc0 = data['seg_descriptors'] + if 'norm_keypoints' in data.keys(): + norm_kpts0 = data['norm_keypoints'] + elif 'image' in data.keys(): + kpts0 = data['keypoints'] + norm_kpts0 = normalize_keypoints(kpts0, data['image'].shape) + else: + raise ValueError('Require image shape for keypoint coordinate normalization') + + enc0 = self.kenc(norm_kpts0) + + return desc0, enc0 + + def forward(self, data): + desc, enc = self.preprocess(data=data) + desc = self.input_proj(desc) + + desc = self.gnn(desc, enc) + seg_output = self.seg(desc) # [B, N, C] + + output = { + 'prediction': seg_output, + } + + if self.with_sc: + sc_output = self.sc(desc) + output['sc'] = sc_output + + return output diff --git a/imcui/third_party/pram/nets/sfd2.py b/imcui/third_party/pram/nets/sfd2.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c5a099b001ed9cf9e8a82b1b77dc9f7d9e31c8 --- /dev/null +++ b/imcui/third_party/pram/nets/sfd2.py @@ -0,0 +1,596 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> sfd2 +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 14:53 +==================================================''' +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import torchvision.transforms as tvf + +RGB_mean = [0.485, 0.456, 0.406] +RGB_std = [0.229, 0.224, 0.225] + +norm_RGB = tvf.Compose([tvf.Normalize(mean=RGB_mean, std=RGB_std)]) + + +def simple_nms(scores, nms_radius: int): + """ Fast Non-maximum suppression to remove nearby points """ + assert (nms_radius >= 0) + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def remove_borders(keypoints, scores, border: int, height: int, width: int): + """ Removes keypoints too close to the border """ + mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) + mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) + mask = mask_h & mask_w + return keypoints[mask], scores[mask] + + +def top_k_keypoints(keypoints, scores, k: int): + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0) + return keypoints[indices], scores + + +def sample_descriptors(keypoints, descriptors, s: int = 8): + """ Interpolate descriptors at keypoint locations """ + b, c, h, w = descriptors.shape + keypoints = keypoints - s / 2 + 0.5 + 
keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to(keypoints)[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', align_corners=True) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + return descriptors + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bn=False, groups=1, dilation=1): + if not use_bn: + return nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation), + nn.ReLU(inplace=True), + ) + else: + return nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + +class ResBlock(nn.Module): + def __init__(self, inplanes, outplanes, stride=1, groups=32, dilation=1, norm_layer=None): + super(ResBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = conv1x1(inplanes, outplanes) + self.bn1 = norm_layer(outplanes) + self.conv2 = conv3x3(outplanes, outplanes, stride, groups, dilation) + self.bn2 = norm_layer(outplanes) + self.conv3 = conv1x1(outplanes, outplanes) + self.bn3 = norm_layer(outplanes) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += identity + out = self.relu(out) + + return out + + +class ResNet4x(nn.Module): + default_config = { + 'conf_th': 0.005, + 'remove_borders': 4, + 'min_keypoints': 128, + 'max_keypoints': 4096, + } + + def __init__(self, inputdim=3, outdim=128, desc_compressor=None): + super().__init__() + self.outdim = outdim + self.desc_compressor = desc_compressor + + d1, d2, d3, d4, d5, d6 = 64, 128, 256, 256, 256, 256 + self.conv1a = conv(in_channels=inputdim, out_channels=d1, kernel_size=3, use_bn=True) + self.conv1b = conv(in_channels=d1, out_channels=d1, kernel_size=3, stride=2, use_bn=True) + + self.conv2a = conv(in_channels=d1, out_channels=d2, kernel_size=3, use_bn=True) + self.conv2b = conv(in_channels=d2, out_channels=d2, kernel_size=3, stride=2, use_bn=True) + + self.conv3a = conv(in_channels=d2, out_channels=d3, kernel_size=3, use_bn=True) + self.conv3b = conv(in_channels=d3, out_channels=d3, kernel_size=3, use_bn=True) + + self.conv4 = nn.Sequential( + ResBlock(inplanes=256, outplanes=256, groups=32), + ResBlock(inplanes=256, outplanes=256, groups=32), + ResBlock(inplanes=256, outplanes=256, groups=32), + ) + + self.convPa = nn.Sequential( + torch.nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), + ) + self.convDa = nn.Sequential( + nn.Conv2d(256, 256, 
kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + ) + + self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0) + self.convDb = torch.nn.Conv2d(256, outdim, kernel_size=1, stride=1, padding=0) + + def det(self, x): + out1a = self.conv1a(x) + out1b = self.conv1b(out1a) + + out2a = self.conv2a(out1b) + out2b = self.conv2b(out2a) + + out3a = self.conv3a(out2b) + out3b = self.conv3b(out3a) + + out4 = self.conv4(out3b) + + cPa = self.convPa(out4) + logits = self.convPb(cPa) + full_semi = torch.softmax(logits, dim=1) + semi = full_semi[:, :-1, :, :] + Hc, Wc = semi.size(2), semi.size(3) + score = semi.permute([0, 2, 3, 1]) + score = score.view(score.size(0), Hc, Wc, 8, 8) + score = score.permute([0, 1, 3, 2, 4]) + score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8) + + # Descriptor Head + cDa = self.convDa(out4) + desc = self.convDb(cDa) + desc = F.normalize(desc, dim=1) + + return score, desc + + def forward(self, batch): + out1a = self.conv1a(batch['image']) + out1b = self.conv1b(out1a) + + out2a = self.conv2a(out1b) + out2b = self.conv2b(out2a) + + out3a = self.conv3a(out2b) + out3b = self.conv3b(out3a) + + out4 = self.conv4(out3b) + + cPa = self.convPa(out4) + logits = self.convPb(cPa) + full_semi = torch.softmax(logits, dim=1) + semi = full_semi[:, :-1, :, :] + Hc, Wc = semi.size(2), semi.size(3) + score = semi.permute([0, 2, 3, 1]) + score = score.view(score.size(0), Hc, Wc, 8, 8) + score = score.permute([0, 1, 3, 2, 4]) + score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8) + + # Descriptor Head + cDa = self.convDa(out4) + desc = self.convDb(cDa) + desc = F.normalize(desc, dim=1) + + return { + 'dense_features': desc, + 'scores': score, + 'logits': logits, + 'semi_map': semi, + } + + def extract_patches(self, batch): + out1a = self.conv1a(batch['image']) + out1b = self.conv1b(out1a) + + out2a = self.conv2a(out1b) + out2b = self.conv2b(out2a) + + out3a = self.conv3a(out2b) + out3b = self.conv3b(out3a) + + out4 = self.conv4(out3b) + + cPa = self.convPa(out4) + logits = self.convPb(cPa) + full_semi = torch.softmax(logits, dim=1) + semi = full_semi[:, :-1, :, :] + Hc, Wc = semi.size(2), semi.size(3) + score = semi.permute([0, 2, 3, 1]) + score = score.view(score.size(0), Hc, Wc, 8, 8) + score = score.permute([0, 1, 3, 2, 4]) + score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8) + + # Descriptor Head + cDa = self.convDa(out4) + desc = self.convDb(cDa) + desc = F.normalize(desc, dim=1) + + return { + 'dense_features': desc, + 'scores': score, + 'logits': logits, + 'semi_map': semi, + } + + def extract_local_global(self, data, + config={ + 'conf_th': 0.005, + 'remove_borders': 4, + 'min_keypoints': 128, + 'max_keypoints': 4096, + } + ): + + config = {**self.default_config, **config} + + b, ic, ih, iw = data['image'].shape + out1a = self.conv1a(data['image']) + out1b = self.conv1b(out1a) # 64 + + out2a = self.conv2a(out1b) + out2b = self.conv2b(out2a) # 128 + + out3a = self.conv3a(out2b) + out3b = self.conv3b(out3a) # 256 + + out4 = self.conv4(out3b) # 256 + + cPa = self.convPa(out4) + logits = self.convPb(cPa) + full_semi = torch.softmax(logits, dim=1) + semi = full_semi[:, :-1, :, :] + Hc, Wc = semi.size(2), semi.size(3) + score = semi.permute([0, 2, 3, 1]) + score = score.view(score.size(0), Hc, Wc, 8, 8) + score = score.permute([0, 1, 3, 2, 4]) + score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8) + if Hc * 8 != ih or Wc * 8 != iw: + 
score = F.interpolate(score.unsqueeze(1), size=[ih, iw], align_corners=True, mode='bilinear') + score = score.squeeze(1) + # extract keypoints + nms_scores = simple_nms(scores=score, nms_radius=4) + keypoints = [ + torch.nonzero(s >= config['conf_th']) + for s in nms_scores] + scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)] + + if len(scores[0]) <= config['min_keypoints']: + keypoints = [ + torch.nonzero(s >= config['conf_th'] * 0.5) + for s in nms_scores] + scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)] + + # Discard keypoints near the image borders + keypoints, scores = list(zip(*[ + remove_borders(k, s, config['remove_borders'], ih, iw) + for k, s in zip(keypoints, scores)])) + + # Keep the k keypoints with highest score + if config['max_keypoints'] >= 0: + keypoints, scores = list(zip(*[ + top_k_keypoints(k, s, config['max_keypoints']) + for k, s in zip(keypoints, scores)])) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + # Descriptor Head + cDa = self.convDa(out4) + desc_map = self.convDb(cDa) + desc_map = F.normalize(desc_map, dim=1) + + descriptors = [sample_descriptors(k[None], d[None], 4)[0] + for k, d in zip(keypoints, desc_map)] + + return { + 'score_map': score, + 'desc_map': desc_map, + 'mid_features': out4, + 'global_descriptors': [out1b, out2b, out3b, out4], + 'keypoints': keypoints, + 'scores': scores, + 'descriptors': descriptors, + } + + def sample(self, score_map, semi_descs, kpts, s=4, norm_desc=True): + # print('sample: ', score_map.shape, semi_descs.shape, kpts.shape) + b, c, h, w = semi_descs.shape + norm_kpts = kpts - s / 2 + 0.5 + norm_kpts = norm_kpts / torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to(norm_kpts)[None] + norm_kpts = norm_kpts * 2 - 1 + # args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} + descriptors = torch.nn.functional.grid_sample( + semi_descs, norm_kpts.view(b, 1, -1, 2), mode='bilinear', align_corners=True) + + if norm_desc: + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + else: + descriptors = descriptors.reshape(b, c, -1) + + # print('max: ', torch.min(kpts[:, 1].long()), torch.max(kpts[:, 1].long()), torch.min(kpts[:, 0].long()), + # torch.max(kpts[:, 0].long())) + scores = score_map[0, kpts[:, 1].long(), kpts[:, 0].long()] + + return scores, descriptors.squeeze(0) + + +class DescriptorCompressor(nn.Module): + def __init__(self, inputdim: int, outdim: int): + super().__init__() + self.inputdim = inputdim + self.outdim = outdim + self.conv = nn.Conv1d(in_channels=inputdim, out_channels=outdim, kernel_size=1, padding=0, bias=True) + + def forward(self, x): + # b, c, n = x.shape + out = self.conv(x) + out = F.normalize(out, p=2, dim=1) + return out + + +def extract_sfd2_return(model, img, conf_th=0.001, + mask=None, + topK=-1, + min_keypoints=0, + **kwargs): + old_bm = torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = False # speedup + + img = norm_RGB(img.squeeze()) + img = img[None] + img = img.cuda() + + B, one, H, W = img.shape + + all_pts = [] + all_descs = [] + + if 'scales' in kwargs.keys(): + scales = kwargs.get('scales') + else: + scales = [1.0] + + for s in scales: + if s == 1.0: + new_img = img + else: + nh = int(H * s) + nw = int(W * s) + new_img = F.interpolate(img, size=(nh, nw), mode='bilinear', align_corners=True) + nh, nw = new_img.shape[2:] + + with torch.no_grad(): + heatmap, coarse_desc = model.det(new_img) + + # print("nh, nw, heatmap, 
desc: ", nh, nw, heatmap.shape, coarse_desc.shape) + if len(heatmap.size()) == 3: + heatmap = heatmap.unsqueeze(1) + if len(heatmap.size()) == 2: + heatmap = heatmap.unsqueeze(0) + heatmap = heatmap.unsqueeze(1) + # print(heatmap.shape) + if heatmap.size(2) != nh or heatmap.size(3) != nw: + heatmap = F.interpolate(heatmap, size=[nh, nw], mode='bilinear', align_corners=True) + + conf_thresh = conf_th + nms_dist = 3 + border_remove = 4 + scores = simple_nms(heatmap, nms_radius=nms_dist) + keypoints = [ + torch.nonzero(s > conf_thresh) + for s in scores] + scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] + # print('scores in return: ', len(scores[0])) + + # print(keypoints[0].shape) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + scores = scores[0].data.cpu().numpy().squeeze() + keypoints = keypoints[0].data.cpu().numpy().squeeze() + pts = keypoints.transpose() + pts[2, :] = scores + + inds = np.argsort(pts[2, :]) + pts = pts[:, inds[::-1]] # Sort by confidence. + # Remove points along border. + bord = border_remove + toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W - bord)) + toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H - bord)) + toremove = np.logical_or(toremoveW, toremoveH) + pts = pts[:, ~toremove] + + # valid_idex = heatmap > conf_thresh + # valid_score = heatmap[valid_idex] + # """ + # --- Process descriptor. + # coarse_desc = coarse_desc.data.cpu().numpy().squeeze() + D = coarse_desc.size(1) + if pts.shape[1] == 0: + desc = np.zeros((D, 0)) + else: + if coarse_desc.size(2) == nh and coarse_desc.size(3) == nw: + desc = coarse_desc[:, :, pts[1, :], pts[0, :]] + desc = desc.data.cpu().numpy().reshape(D, -1) + else: + # Interpolate into descriptor map using 2D point locations. + samp_pts = torch.from_numpy(pts[:2, :].copy()) + samp_pts[0, :] = (samp_pts[0, :] / (float(nw) / 2.)) - 1. + samp_pts[1, :] = (samp_pts[1, :] / (float(nh) / 2.)) - 1. 
+ samp_pts = samp_pts.transpose(0, 1).contiguous() + samp_pts = samp_pts.view(1, 1, -1, 2) + samp_pts = samp_pts.float() + samp_pts = samp_pts.cuda() + desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts, mode='bilinear', align_corners=True) + desc = desc.data.cpu().numpy().reshape(D, -1) + desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :] + + if pts.shape[1] == 0: + continue + + # print(pts.shape, heatmap.shape, new_img.shape, img.shape, nw, nh, W, H) + pts[0, :] = pts[0, :] * W / nw + pts[1, :] = pts[1, :] * H / nh + all_pts.append(np.transpose(pts, [1, 0])) + all_descs.append(np.transpose(desc, [1, 0])) + + all_pts = np.vstack(all_pts) + all_descs = np.vstack(all_descs) + + torch.backends.cudnn.benchmark = old_bm + + if all_pts.shape[0] == 0: + return None, None, None + + keypoints = all_pts[:, 0:2] + scores = all_pts[:, 2] + descriptors = all_descs + + if mask is not None: + # cv2.imshow("mask", mask) + # cv2.waitKey(0) + labels = [] + others = [] + keypoints_with_labels = [] + scores_with_labels = [] + descriptors_with_labels = [] + keypoints_without_labels = [] + scores_without_labels = [] + descriptors_without_labels = [] + + id_img = np.int32(mask[:, :, 2]) * 256 * 256 + np.int32(mask[:, :, 1]) * 256 + np.int32(mask[:, :, 0]) + # print(img.shape, id_img.shape) + + for i in range(keypoints.shape[0]): + x = keypoints[i, 0] + y = keypoints[i, 1] + # print("x-y", x, y, int(x), int(y)) + gid = id_img[int(y), int(x)] + if gid == 0: + keypoints_without_labels.append(keypoints[i]) + scores_without_labels.append(scores[i]) + descriptors_without_labels.append(descriptors[i]) + others.append(0) + else: + keypoints_with_labels.append(keypoints[i]) + scores_with_labels.append(scores[i]) + descriptors_with_labels.append(descriptors[i]) + labels.append(gid) + + if topK > 0: + if topK <= len(keypoints_with_labels): + idxes = np.array(scores_with_labels, float).argsort()[::-1][:topK] + keypoints = np.array(keypoints_with_labels, float)[idxes] + scores = np.array(scores_with_labels, float)[idxes] + labels = np.array(labels, np.int32)[idxes] + descriptors = np.array(descriptors_with_labels, float)[idxes] + elif topK >= len(keypoints_with_labels) + len(keypoints_without_labels): + # keypoints = np.vstack([keypoints_with_labels, keypoints_without_labels]) + # scores = np.vstack([scorescc_with_labels, scores_without_labels]) + # descriptors = np.vstack([descriptors_with_labels, descriptors_without_labels]) + # labels = np.vstack([labels, others]) + keypoints = keypoints_with_labels + scores = scores_with_labels + descriptors = descriptors_with_labels + for i in range(len(others)): + keypoints.append(keypoints_without_labels[i]) + scores.append(scores_without_labels[i]) + descriptors.append(descriptors_without_labels[i]) + labels.append(others[i]) + else: + n = topK - len(keypoints_with_labels) + idxes = np.array(scores_without_labels, float).argsort()[::-1][:n] + keypoints = keypoints_with_labels + scores = scores_with_labels + descriptors = descriptors_with_labels + for i in idxes: + keypoints.append(keypoints_without_labels[i]) + scores.append(scores_without_labels[i]) + descriptors.append(descriptors_without_labels[i]) + labels.append(others[i]) + keypoints = np.array(keypoints, float) + descriptors = np.array(descriptors, float) + # print(keypoints.shape, descriptors.shape) + return {"keypoints": np.array(keypoints, float), + "descriptors": np.array(descriptors, float), + "scores": np.array(scores, np.float), + "labels": np.array(labels, np.int32), + } + else: + # print(topK) + if topK > 
0: + idxes = np.array(scores, dtype=float).argsort()[::-1][:topK] + keypoints = np.array(keypoints[idxes], dtype=float) + scores = np.array(scores[idxes], dtype=float) + descriptors = np.array(descriptors[idxes], dtype=float) + + keypoints = np.array(keypoints, dtype=float) + scores = np.array(scores, dtype=float) + descriptors = np.array(descriptors, dtype=float) + + # print(keypoints.shape, descriptors.shape) + + return {"keypoints": np.array(keypoints, dtype=float), + "descriptors": descriptors, + "scores": scores, + } + + +def load_sfd2(weight_path): + net = ResNet4x(inputdim=3, outdim=128) + net.load_state_dict(torch.load(weight_path, map_location='cpu')['state_dict'], strict=True) + # print('Load sfd2 from {:s}'.format(weight_path)) + return net diff --git a/imcui/third_party/pram/nets/superpoint.py b/imcui/third_party/pram/nets/superpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..6751016bd71cbbbb072243b3c1aebc100f632693 --- /dev/null +++ b/imcui/third_party/pram/nets/superpoint.py @@ -0,0 +1,607 @@ +# %BANNER_BEGIN% +# --------------------------------------------------------------------- +# %COPYRIGHT_BEGIN% +# +# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL +# +# Unpublished Copyright (c) 2020 +# Magic Leap, Inc., All Rights Reserved. +# +# NOTICE: All information contained herein is, and remains the property +# of COMPANY. The intellectual and technical concepts contained herein +# are proprietary to COMPANY and may be covered by U.S. and Foreign +# Patents, patents in process, and are protected by trade secret or +# copyright law. Dissemination of this information or reproduction of +# this material is strictly forbidden unless prior written permission is +# obtained from COMPANY. Access to the source code contained herein is +# hereby forbidden to anyone except current COMPANY employees, managers +# or contractors who have executed Confidentiality and Non-disclosure +# agreements explicitly covering such access. +# +# The copyright notice above does not evidence any actual or intended +# publication or disclosure of this source code, which includes +# information that is confidential and/or proprietary, and is a trade +# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, +# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS +# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS +# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND +# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE +# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS +# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, +# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
+# +# %COPYRIGHT_END% +# ---------------------------------------------------------------------- +# %AUTHORS_BEGIN% +# +# Originating Authors: Paul-Edouard Sarlin +# +# %AUTHORS_END% +# --------------------------------------------------------------------*/ +# %BANNER_END% + +from pathlib import Path +import torch +from torch import nn +import numpy as np +import cv2 +import torch.nn.functional as F + + +def simple_nms(scores, nms_radius: int): + """ Fast Non-maximum suppression to remove nearby points """ + assert (nms_radius >= 0) + + def max_pool(x): + return torch.nn.functional.max_pool2d( + x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) + + zeros = torch.zeros_like(scores) + max_mask = scores == max_pool(scores) + for _ in range(2): + supp_mask = max_pool(max_mask.float()) > 0 + supp_scores = torch.where(supp_mask, zeros, scores) + new_max_mask = supp_scores == max_pool(supp_scores) + max_mask = max_mask | (new_max_mask & (~supp_mask)) + return torch.where(max_mask, scores, zeros) + + +def remove_borders(keypoints, scores, border: int, height: int, width: int): + """ Removes keypoints too close to the border """ + mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) + mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) + mask = mask_h & mask_w + return keypoints[mask], scores[mask] + + +def top_k_keypoints(keypoints, scores, k: int): + if k >= len(keypoints): + return keypoints, scores + scores, indices = torch.topk(scores, k, dim=0) + return keypoints[indices], scores + + +def sample_descriptors(keypoints, descriptors, s: int = 8): + """ Interpolate descriptors at keypoint locations """ + b, c, h, w = descriptors.shape + keypoints = keypoints - s / 2 + 0.5 + keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to(keypoints)[None] + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + return descriptors + + +class SuperPoint(nn.Module): + """SuperPoint Convolutional Detector and Descriptor + + SuperPoint: Self-Supervised Interest Point Detection and + Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew + Rabinovich. In CVPRW, 2019. 
https://arxiv.org/abs/1712.07629 + + """ + default_config = { + 'descriptor_dim': 256, + 'nms_radius': 3, + 'keypoint_threshold': 0.001, + 'max_keypoints': -1, + 'min_keypoints': 32, + 'remove_borders': 4, + } + + def __init__(self, config): + super().__init__() + self.config = {**self.default_config, **config} + + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 + + self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) # 64 + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) # 64 + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) # 128 + self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) # 128 + + self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 256 + self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) + + self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 256 + self.convDb = nn.Conv2d( + c5, self.config['descriptor_dim'], + kernel_size=1, stride=1, padding=0) + + # path = Path(__file__).parent / 'weights/superpoint_v1.pth' + path = config['weight_path'] + self.load_state_dict(torch.load(str(path), map_location='cpu'), strict=True) + + mk = self.config['max_keypoints'] + if mk == 0 or mk < -1: + raise ValueError('\"max_keypoints\" must be positive or \"-1\"') + + print('Loaded SuperPoint model') + + def extract_global(self, data): + # Shared Encoder + x0 = self.relu(self.conv1a(data['image'])) + x0 = self.relu(self.conv1b(x0)) + x0 = self.pool(x0) + x1 = self.relu(self.conv2a(x0)) + x1 = self.relu(self.conv2b(x1)) + x1 = self.pool(x1) + x2 = self.relu(self.conv3a(x1)) + x2 = self.relu(self.conv3b(x2)) + x2 = self.pool(x2) + x3 = self.relu(self.conv4a(x2)) + x3 = self.relu(self.conv4b(x3)) + + x4 = self.relu(self.convDa(x3)) + + # print('ex_g: ', x0.shape, x1.shape, x2.shape, x3.shape, x4.shape) + + return [x0, x1, x2, x3, x4] + + def extract_local_global(self, data): + # Shared Encoder + b, ic, ih, iw = data['image'].shape + x0 = self.relu(self.conv1a(data['image'])) + x0 = self.relu(self.conv1b(x0)) + x0 = self.pool(x0) + x1 = self.relu(self.conv2a(x0)) + x1 = self.relu(self.conv2b(x1)) + x1 = self.pool(x1) + x2 = self.relu(self.conv3a(x1)) + x2 = self.relu(self.conv3b(x2)) + x2 = self.pool(x2) + x3 = self.relu(self.conv4a(x2)) + x3 = self.relu(self.conv4b(x3)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x3)) + score = self.convPb(cPa) + score = torch.nn.functional.softmax(score, 1)[:, :-1] + # print(scores.shape) + b, _, h, w = score.shape + score = score.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + score = score.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + score = torch.nn.functional.interpolate(score.unsqueeze(1), size=(ih, iw), align_corners=True, + mode='bilinear') + score = score.squeeze(1) + + # extract kpts + nms_scores = simple_nms(scores=score, nms_radius=self.config['nms_radius']) + keypoints = [ + torch.nonzero(s >= self.config['keypoint_threshold']) + for s in nms_scores] + scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)] + + if len(scores[0]) <= self.config['min_keypoints']: + keypoints = [ + torch.nonzero(s >= self.config['keypoint_threshold'] * 0.5) 
+ for s in nms_scores] + scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)] + + # Discard keypoints near the image borders + keypoints, scores = list(zip(*[ + remove_borders(k, s, self.config['remove_borders'], ih, iw) + for k, s in zip(keypoints, scores)])) + + # Keep the k keypoints with the highest score + if self.config['max_keypoints'] >= 0: + keypoints, scores = list(zip(*[ + top_k_keypoints(k, s, self.config['max_keypoints']) + for k, s in zip(keypoints, scores)])) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x3)) + desc_map = self.convDb(cDa) + desc_map = torch.nn.functional.normalize(desc_map, p=2, dim=1) + descriptors = [sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, desc_map)] + + return { + 'score_map': score, + 'desc_map': desc_map, + 'mid_features': cDa, # 256 + 'global_descriptors': [x0, x1, x2, x3, cDa], + 'keypoints': keypoints, + 'scores': scores, + 'descriptors': descriptors, + } + + def sample(self, score_map, semi_descs, kpts, s=8, norm_desc=True): + # print('sample: ', score_map.shape, semi_descs.shape, kpts.shape) + b, c, h, w = semi_descs.shape + norm_kpts = kpts - s / 2 + 0.5 + norm_kpts = norm_kpts / torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], + ).to(norm_kpts)[None] + norm_kpts = norm_kpts * 2 - 1 + # args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} + descriptors = torch.nn.functional.grid_sample( + semi_descs, norm_kpts.view(b, 1, -1, 2), mode='bilinear', align_corners=True) + if norm_desc: + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1) + else: + descriptors = descriptors.reshape(b, c, -1) + + # print('max: ', torch.min(kpts[:, 1].long()), torch.max(kpts[:, 1].long()), torch.min(kpts[:, 0].long()), + # torch.max(kpts[:, 0].long())) + scores = score_map[0, kpts[:, 1].long(), kpts[:, 0].long()] + + return scores, descriptors.squeeze(0) + + def extract(self, data): + """ Compute keypoints, scores, descriptors for image """ + # Shared Encoder + x = self.relu(self.conv1a(data['image'])) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + b, _, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) + + return scores, descriptors + + def det(self, image): + """ Compute keypoints, scores, descriptors for image """ + # Shared Encoder + x = self.relu(self.conv1a(image)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 
1)[:, :-1] + # print(scores.shape) + b, _, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) + + return scores, descriptors + + def forward(self, data): + """ Compute keypoints, scores, descriptors for image """ + # Shared Encoder + x = self.relu(self.conv1a(data['image'])) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + # print(scores.shape) + b, _, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + scores = simple_nms(scores, self.config['nms_radius']) + + # Extract keypoints + keypoints = [ + torch.nonzero(s > self.config['keypoint_threshold']) + for s in scores] + scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] + + # Discard keypoints near the image borders + keypoints, scores = list(zip(*[ + remove_borders(k, s, self.config['remove_borders'], h * 8, w * 8) + for k, s in zip(keypoints, scores)])) + + # Keep the k keypoints with highest score + if self.config['max_keypoints'] >= 0: + keypoints, scores = list(zip(*[ + top_k_keypoints(k, s, self.config['max_keypoints']) + for k, s in zip(keypoints, scores)])) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) + + # Extract descriptors + # print(keypoints[0].shape) + descriptors = [sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, descriptors)] + + return { + 'keypoints': keypoints, + 'scores': scores, + 'descriptors': descriptors, + 'global_descriptor': x, + } + + +def extract_descriptor(sample_pts, coarse_desc, H, W): + ''' + :param samplt_pts: + :param coarse_desc: + :return: + ''' + with torch.no_grad(): + norm_sample_pts = torch.zeros_like(sample_pts) + norm_sample_pts[0, :] = (sample_pts[0, :] / (float(W) / 2.)) - 1. # x + norm_sample_pts[1, :] = (sample_pts[1, :] / (float(H) / 2.)) - 1. 
# y + norm_sample_pts = norm_sample_pts.transpose(0, 1).contiguous() + norm_sample_pts = norm_sample_pts.view(1, 1, -1, 2).float() + sample_desc = torch.nn.functional.grid_sample(coarse_desc[None], norm_sample_pts, mode='bilinear', + align_corners=False) + sample_desc = torch.nn.functional.normalize(sample_desc, dim=1).squeeze(2).squeeze(0) + return sample_desc + + +def extract_sp_return(model, img, conf_th=0.005, + mask=None, + topK=-1, + **kwargs): + old_bm = torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = False # speedup + + # print(img.shape) + img = img.cuda() + # if len(img.shape) == 3: # gray image + # img = img[None] + + B, one, H, W = img.shape + + all_pts = [] + all_descs = [] + + if 'scales' in kwargs.keys(): + scales = kwargs.get('scales') + else: + scales = [1.0] + + for s in scales: + if s == 1.0: + new_img = img + else: + nh = int(H * s) + nw = int(W * s) + new_img = F.interpolate(img, size=(nh, nw), mode='bilinear', align_corners=True) + nh, nw = new_img.shape[2:] + + with torch.no_grad(): + heatmap, coarse_desc = model.det(new_img) + + # print("nh, nw, heatmap, desc: ", nh, nw, heatmap.shape, coarse_desc.shape) + if len(heatmap.size()) == 3: + heatmap = heatmap.unsqueeze(1) + if len(heatmap.size()) == 2: + heatmap = heatmap.unsqueeze(0) + heatmap = heatmap.unsqueeze(1) + # print(heatmap.shape) + if heatmap.size(2) != nh or heatmap.size(3) != nw: + heatmap = F.interpolate(heatmap, size=[nh, nw], mode='bilinear', align_corners=True) + + conf_thresh = conf_th + nms_dist = 4 + border_remove = 4 + scores = simple_nms(heatmap, nms_radius=nms_dist) + keypoints = [ + torch.nonzero(s > conf_thresh) + for s in scores] + scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] + # print(keypoints[0].shape) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + scores = scores[0].data.cpu().numpy().squeeze() + keypoints = keypoints[0].data.cpu().numpy().squeeze() + pts = keypoints.transpose() + pts[2, :] = scores + + inds = np.argsort(pts[2, :]) + pts = pts[:, inds[::-1]] # Sort by confidence. + # Remove points along border. + bord = border_remove + toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W - bord)) + toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H - bord)) + toremove = np.logical_or(toremoveW, toremoveH) + pts = pts[:, ~toremove] + + # valid_idex = heatmap > conf_thresh + # valid_score = heatmap[valid_idex] + # """ + # --- Process descriptor. + # coarse_desc = coarse_desc.data.cpu().numpy().squeeze() + D = coarse_desc.size(1) + if pts.shape[1] == 0: + desc = np.zeros((D, 0)) + else: + if coarse_desc.size(2) == nh and coarse_desc.size(3) == nw: + desc = coarse_desc[:, :, pts[1, :], pts[0, :]] + desc = desc.data.cpu().numpy().reshape(D, -1) + else: + # Interpolate into descriptor map using 2D point locations. + samp_pts = torch.from_numpy(pts[:2, :].copy()) + samp_pts[0, :] = (samp_pts[0, :] / (float(nw) / 2.)) - 1. + samp_pts[1, :] = (samp_pts[1, :] / (float(nh) / 2.)) - 1. 
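# The `heatmap` consumed above is produced in det() by a 65-way softmax over
# each 8x8 cell (64 pixel positions plus one "dustbin" channel); the dustbin
# is dropped and the remaining 64 channels are unfolded back to image
# resolution. A minimal sketch of that unfolding; `cells_to_scores` is an
# illustrative name, and the commented usage feeds random logits as a stand-in.
import torch

def cells_to_scores(semi):
    # semi: [B, 64, Hc, Wc] per-cell probabilities with the dustbin removed.
    b, _, hc, wc = semi.shape
    s = semi.permute(0, 2, 3, 1).reshape(b, hc, wc, 8, 8)
    s = s.permute(0, 1, 3, 2, 4).reshape(b, hc * 8, wc * 8)
    return s  # [B, Hc*8, Wc*8] keypoint score map at image resolution

# The same unfolding can also be written with pixel_shuffle:
#   semi = torch.softmax(torch.rand(1, 65, 30, 40), dim=1)[:, :-1]
#   torch.allclose(cells_to_scores(semi),
#                  torch.nn.functional.pixel_shuffle(semi, 8).squeeze(1))  # True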
+ samp_pts = samp_pts.transpose(0, 1).contiguous() + samp_pts = samp_pts.view(1, 1, -1, 2) + samp_pts = samp_pts.float() + samp_pts = samp_pts.cuda() + desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts, mode='bilinear', align_corners=True) + desc = desc.data.cpu().numpy().reshape(D, -1) + desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :] + + if pts.shape[1] == 0: + continue + + # print(pts.shape, heatmap.shape, new_img.shape, img.shape, nw, nh, W, H) + pts[0, :] = pts[0, :] * W / nw + pts[1, :] = pts[1, :] * H / nh + all_pts.append(np.transpose(pts, [1, 0])) + all_descs.append(np.transpose(desc, [1, 0])) + + all_pts = np.vstack(all_pts) + all_descs = np.vstack(all_descs) + + torch.backends.cudnn.benchmark = old_bm + + if all_pts.shape[0] == 0: + return None, None, None + + keypoints = all_pts[:, 0:2] + scores = all_pts[:, 2] + descriptors = all_descs + + if mask is not None: + # cv2.imshow("mask", mask) + # cv2.waitKey(0) + labels = [] + others = [] + keypoints_with_labels = [] + scores_with_labels = [] + descriptors_with_labels = [] + keypoints_without_labels = [] + scores_without_labels = [] + descriptors_without_labels = [] + + id_img = np.int32(mask[:, :, 2]) * 256 * 256 + np.int32(mask[:, :, 1]) * 256 + np.int32(mask[:, :, 0]) + # print(img.shape, id_img.shape) + + for i in range(keypoints.shape[0]): + x = keypoints[i, 0] + y = keypoints[i, 1] + # print("x-y", x, y, int(x), int(y)) + gid = id_img[int(y), int(x)] + if gid == 0: + keypoints_without_labels.append(keypoints[i]) + scores_without_labels.append(scores[i]) + descriptors_without_labels.append(descriptors[i]) + others.append(0) + else: + keypoints_with_labels.append(keypoints[i]) + scores_with_labels.append(scores[i]) + descriptors_with_labels.append(descriptors[i]) + labels.append(gid) + + if topK > 0: + if topK <= len(keypoints_with_labels): + idxes = np.array(scores_with_labels, float).argsort()[::-1][:topK] + keypoints = np.array(keypoints_with_labels, float)[idxes] + scores = np.array(scores_with_labels, float)[idxes] + labels = np.array(labels, np.int32)[idxes] + descriptors = np.array(descriptors_with_labels, float)[idxes] + elif topK >= len(keypoints_with_labels) + len(keypoints_without_labels): + # keypoints = np.vstack([keypoints_with_labels, keypoints_without_labels]) + # scores = np.vstack([scorescc_with_labels, scores_without_labels]) + # descriptors = np.vstack([descriptors_with_labels, descriptors_without_labels]) + # labels = np.vstack([labels, others]) + keypoints = keypoints_with_labels + scores = scores_with_labels + descriptors = descriptors_with_labels + for i in range(len(others)): + keypoints.append(keypoints_without_labels[i]) + scores.append(scores_without_labels[i]) + descriptors.append(descriptors_without_labels[i]) + labels.append(others[i]) + else: + n = topK - len(keypoints_with_labels) + idxes = np.array(scores_without_labels, float).argsort()[::-1][:n] + keypoints = keypoints_with_labels + scores = scores_with_labels + descriptors = descriptors_with_labels + for i in idxes: + keypoints.append(keypoints_without_labels[i]) + scores.append(scores_without_labels[i]) + descriptors.append(descriptors_without_labels[i]) + labels.append(others[i]) + keypoints = np.array(keypoints, float) + descriptors = np.array(descriptors, float) + # print(keypoints.shape, descriptors.shape) + return {"keypoints": np.array(keypoints, float), + "descriptors": np.array(descriptors, float), + "scores": np.array(scores, float), + "labels": np.array(labels, np.int32), + } + else: + # print(topK) + if topK > 0: 
+ idxes = np.array(scores, dtype=float).argsort()[::-1][:topK] + keypoints = np.array(keypoints[idxes], dtype=float) + scores = np.array(scores[idxes], dtype=float) + descriptors = np.array(descriptors[idxes], dtype=float) + + keypoints = np.array(keypoints, dtype=float) + scores = np.array(scores, dtype=float) + descriptors = np.array(descriptors, dtype=float) + + # print(keypoints.shape, descriptors.shape) + + return {"keypoints": np.array(keypoints, dtype=float), + "descriptors": descriptors, + "scores": scores, + } diff --git a/imcui/third_party/pram/nets/utils.py b/imcui/third_party/pram/nets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..066a00510c19e0c87cf5d07a36cea2a90dd0e3eb --- /dev/null +++ b/imcui/third_party/pram/nets/utils.py @@ -0,0 +1,24 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> utils +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 10:48 +==================================================''' +import torch + +eps = 1e-8 + + +def arange_like(x, dim: int): + return x.new_ones(x.shape[dim]).cumsum(0) - 1 + + +def normalize_keypoints(kpts, image_shape): + """ Normalize keypoints locations based on image image_shape""" + _, _, height, width = image_shape + one = kpts.new_tensor(1) + size = torch.stack([one * width, one * height])[None] + center = size / 2 + scaling = size.max(1, keepdim=True).values * 0.7 + return (kpts - center[:, None, :]) / scaling[:, None, :] diff --git a/imcui/third_party/pram/recognition/recmap.py b/imcui/third_party/pram/recognition/recmap.py new file mode 100644 index 0000000000000000000000000000000000000000..c159de286e96fdb594428e88e370e1a7edbecb79 --- /dev/null +++ b/imcui/third_party/pram/recognition/recmap.py @@ -0,0 +1,1118 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> recmap +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 11:02 +==================================================''' +import argparse +import torch +import os +import os.path as osp +import numpy as np +import cv2 +import yaml +import multiprocessing as mp +from copy import deepcopy +import logging +import h5py +from tqdm import tqdm +import open3d as o3d +from sklearn.cluster import KMeans, Birch +from collections import defaultdict +from colmap_utils.read_write_model import read_model, qvec2rotmat, write_cameras_binary, write_images_binary +from colmap_utils.read_write_model import write_points3d_binary, Image, Point3D, Camera +from colmap_utils.read_write_model import write_compressed_points3d_binary, write_compressed_images_binary +from recognition.vis_seg import generate_color_dic, vis_seg_point, plot_kpts + + +class RecMap: + def __init__(self): + self.cameras = None + self.images = None + self.points3D = None + self.pcd = o3d.geometry.PointCloud() + self.seg_color_dict = generate_color_dic(n_seg=1000) + + def load_sfm_model(self, path: str, ext='.bin'): + self.cameras, self.images, self.points3D = read_model(path, ext) + self.name_to_id = {image.name: i for i, image in self.images.items()} + print('Load {:d} cameras, {:d} images, {:d} points'.format(len(self.cameras), len(self.images), + len(self.points3D))) + + def remove_statics_outlier(self, nb_neighbors: int = 20, std_ratio: float = 2.0): + xyzs = [] + p3d_ids = [] + for p3d_id in self.points3D.keys(): + xyzs.append(self.points3D[p3d_id].xyz) + p3d_ids.append(p3d_id) + + xyzs = np.array(xyzs) + pcd = o3d.geometry.PointCloud() + pcd.points = 
o3d.utility.Vector3dVector(xyzs) + new_pcd, inlier_ids = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio) + + new_point3Ds = {} + for i in inlier_ids: + new_point3Ds[p3d_ids[i]] = self.points3D[p3d_ids[i]] + self.points3D = new_point3Ds + n_outlier = xyzs.shape[0] - len(inlier_ids) + ratio = n_outlier / xyzs.shape[0] + print('Remove {:d} - {:d} = {:d}/{:.2f}% points'.format(xyzs.shape[0], len(inlier_ids), n_outlier, ratio * 100)) + + def load_segmentation(self, path: str): + data = np.load(path, allow_pickle=True)[()] + p3d_id = data['id'] + seg_id = data['label'] + self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + self.seg_p3d = {} + for pid in self.p3d_seg.keys(): + sid = self.p3d_seg[pid] + if sid not in self.seg_p3d.keys(): + self.seg_p3d[sid] = [pid] + else: + self.seg_p3d[sid].append(pid) + + if 'xyz' not in data.keys(): + all_xyz = [] + for pid in p3d_id: + xyz = self.points3D[pid].xyz + all_xyz.append(xyz) + data['xyz'] = np.array(all_xyz) + np.save(path, data) + print('Add xyz to ', path) + + def cluster(self, k=512, mode='xyz', min_obs=3, save_fn=None, method='kmeans', **kwargs): + if save_fn is not None: + if osp.isfile(save_fn): + print('{:s} exists.'.format(save_fn)) + return + all_xyz = [] + point3D_ids = [] + for p3d in self.points3D.values(): + track_len = len(p3d.point2D_idxs) + if track_len < min_obs: + continue + all_xyz.append(p3d.xyz) + point3D_ids.append(p3d.id) + + xyz = np.array(all_xyz) + point3D_ids = np.array(point3D_ids) + + if mode.find('x') < 0: + xyz[:, 0] = 0 + if mode.find('y') < 0: + xyz[:, 1] = 0 + if mode.find('z') < 0: + xyz[:, 2] = 0 + + if method == 'kmeans': + model = KMeans(n_clusters=k, random_state=0, verbose=True).fit(xyz) + elif method == 'birch': + model = Birch(threshold=kwargs.get('threshold'), n_clusters=k).fit(xyz) # 0.01 for indoor + else: + print('Method {:s} for clustering does not exist'.format(method)) + exit(0) + labels = np.array(model.labels_).reshape(-1) + if save_fn is not None: + np.save(save_fn, { + 'id': np.array(point3D_ids), # should be assigned to self.points3D_ids + 'label': np.array(labels), + 'xyz': np.array(all_xyz), + }) + + def assign_point3D_descriptor(self, feature_fn: str, save_fn=None, n_process=1): + ''' + assign each 3d point a descriptor for localization + :param feature_fn: file name of features [h5py] + :param save_fn: + :param n_process: + :return: + ''' + + def run(start_id, end_id, points3D_desc): + for pi in tqdm(range(start_id, end_id), total=end_id - start_id): + p3d_id = all_p3d_ids[pi] + img_list = self.points3D[p3d_id].image_ids + kpt_ids = self.points3D[p3d_id].point2D_idxs + all_descs = [] + for img_id, p2d_id in zip(img_list, kpt_ids): + if img_id not in self.images.keys(): + continue + img_fn = self.images[img_id].name + desc = feat_file[img_fn]['descriptors'][()].transpose()[p2d_id] + all_descs.append(desc) + + if len(all_descs) == 1: + points3D_desc[p3d_id] = all_descs[0] + else: + all_descs = np.array(all_descs) # [n, d] + dist = all_descs @ all_descs.transpose() # [n, n] + dist = 2 - 2 * dist + md_dist = np.median(dist, axis=-1) # [n] + min_id = np.argmin(md_dist) + points3D_desc[p3d_id] = all_descs[min_id] + + if osp.isfile(save_fn): + print('{:s} exists.'.format(save_fn)) + return + p3D_desc = {} + feat_file = h5py.File(feature_fn, 'r') + all_p3d_ids = sorted(self.points3D.keys()) + + if n_process > 1: + if len(all_p3d_ids) <= n_process: + run(start_id=0, end_id=len(all_p3d_ids), points3D_desc=p3D_desc) + else: + manager = mp.Manager() + 
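# Each 3D point keeps a single descriptor: among all observations in its
# track, run() above selects the one whose median squared distance to the
# others is smallest, i.e. the most central descriptor. For unit-norm
# descriptors ||a - b||^2 = 2 - 2 * a.b, which is where the `2 - 2 * dist`
# expression comes from. A minimal sketch with illustrative names only:
import numpy as np

def most_central_descriptor(descs):
    # descs: [n, d] L2-normalised descriptors observed for one 3D point.
    if descs.shape[0] == 1:
        return descs[0]
    sq_dist = 2.0 - 2.0 * descs @ descs.T  # [n, n] pairwise squared L2 distances
    med = np.median(sq_dist, axis=-1)      # [n] median distance per observation
    return descs[np.argmin(med)]           # the most representative observation

# e.g. descs = np.random.randn(6, 128)
#      descs /= np.linalg.norm(descs, axis=1, keepdims=True)
#      most_central_descriptor(descs).shape  # -> (128,)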
output = manager.dict() # necessary otherwise empty + n_sample_per_process = len(all_p3d_ids) // n_process + jobs = [] + for i in range(n_process): + start_id = i * n_sample_per_process + if i == n_process - 1: + end_id = len(all_p3d_ids) + else: + end_id = (i + 1) * n_sample_per_process + p = mp.Process( + target=run, + args=(start_id, end_id, output), + ) + jobs.append(p) + p.start() + + for p in jobs: + p.join() + + p3D_desc = {} + for k in output.keys(): + p3D_desc[k] = output[k] + else: + run(start_id=0, end_id=len(all_p3d_ids), points3D_desc=p3D_desc) + + if save_fn is not None: + np.save(save_fn, p3D_desc) + + def reproject(self, img_id, xyzs): + qvec = self.images[img_id].qvec + Rcw = qvec2rotmat(qvec=qvec) + tvec = self.images[img_id].tvec + tcw = tvec.reshape(3, ) + Tcw = np.eye(4, dtype=float) + Tcw[:3, :3] = Rcw + Tcw[:3, 3] = tcw + # intrinsics + cam = self.cameras[self.images[img_id].camera_id] + K = self.get_intrinsics_from_camera(camera=cam) + + xyzs_homo = np.hstack([xyzs, np.ones(shape=(xyzs.shape[0], 1), dtype=float)]) + kpts = K @ ((Tcw @ xyzs_homo.transpose())[:3, :]) # [3, N] + kpts = kpts.transpose() # [N, 3] + kpts[:, 0] = kpts[:, 0] / kpts[:, 2] + kpts[:, 1] = kpts[:, 1] / kpts[:, 2] + + return kpts + + def find_covisible_frame_ids(self, image_id, images, points3D): + covis = defaultdict(int) + p3d_ids = images[image_id].point3D_ids + + for pid in p3d_ids: + if pid == -1: + continue + if pid not in points3D.keys(): + continue + for im in points3D[pid].image_ids: + covis[im] += 1 + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + ind_top = np.argsort(covis_num)[::-1] + sorted_covis_ids = [covis_ids[i] for i in ind_top] + return sorted_covis_ids + + def create_virtual_frame_3(self, save_fn=None, save_vrf_dir=None, show_time=-1, ignored_cameras=[], + min_cover_ratio=0.9, + depth_scale=1.2, + radius=15, + min_obs=120, + topk_imgs=500, + n_vrf=10, + covisible_frame=20, + **kwargs): + def reproject(img_id, xyzs): + qvec = self.images[img_id].qvec + Rcw = qvec2rotmat(qvec=qvec) + tvec = self.images[img_id].tvec + tcw = tvec.reshape(3, ) + Tcw = np.eye(4, dtype=float) + Tcw[:3, :3] = Rcw + Tcw[:3, 3] = tcw + # intrinsics + cam = self.cameras[self.images[img_id].camera_id] + K = self.get_intrinsics_from_camera(camera=cam) + + xyzs_homo = np.hstack([xyzs, np.ones(shape=(xyzs.shape[0], 1), dtype=float)]) + kpts = K @ ((Tcw @ xyzs_homo.transpose())[:3, :]) # [3, N] + kpts = kpts.transpose() # [N, 3] + kpts[:, 0] = kpts[:, 0] / kpts[:, 2] + kpts[:, 1] = kpts[:, 1] / kpts[:, 2] + + return kpts + + def find_best_vrf_by_covisibility(p3d_id_list): + all_img_ids = [] + all_xyzs = [] + + img_ids_full = [] + img_id_obs = {} + for pid in p3d_id_list: + if pid not in self.points3D.keys(): + continue + all_xyzs.append(self.points3D[pid].xyz) + + img_ids = self.points3D[pid].image_ids + for iid in img_ids: + if iid in all_img_ids: + continue + # valid_p3ds = [v for v in self.images[iid].point3D_ids if v > 0 and v in p3d_id_list] + if len(ignored_cameras) > 0: + ignore = False + img_name = self.images[iid].name + for c in ignored_cameras: + if img_name.find(c) >= 0: + ignore = True + break + if ignore: + continue + # valid_p3ds = np.intersect1d(np.array(self.images[iid].point3D_ids), np.array(p3d_id_list)).tolist() + valid_p3ds = [v for v in self.images[iid].point3D_ids if v > 0] + img_ids_full.append(iid) + if len(valid_p3ds) < min_obs: + continue + + all_img_ids.append(iid) + img_id_obs[iid] = len(valid_p3ds) + all_xyzs = 
np.array(all_xyzs) + + print('Find {} 3D points and {} images'.format(len(p3d_id_list), len(img_id_obs.keys()))) + top_img_ids_by_obs = sorted(img_id_obs.items(), key=lambda item: item[1], reverse=True) # [(key, value), ] + all_img_ids = [] + for item in top_img_ids_by_obs: + all_img_ids.append(item[0]) + if len(all_img_ids) >= topk_imgs: + break + + # all_img_ids = all_img_ids[:200] + if len(all_img_ids) == 0: + print('no valid img ids with obs over {:d}'.format(min_obs)) + all_img_ids = img_ids_full + + img_observations = {} + p3d_id_array = np.array(p3d_id_list) + for idx, img_id in enumerate(all_img_ids): + valid_p3ds = [v for v in self.images[img_id].point3D_ids if v > 0] + mask = np.array([False for i in range(len(p3d_id_list))]) + for pid in valid_p3ds: + found_idx = np.where(p3d_id_array == pid)[0] + if found_idx.shape[0] == 0: + continue + mask[found_idx[0]] = True + + img_observations[img_id] = mask + + unobserved_p3d_ids = np.array([True for i in range(len(p3d_id_list))]) + + candidate_img_ids = [] + total_cover_ratio = 0 + while total_cover_ratio < min_cover_ratio: + best_img_id = -1 + best_img_obs = -1 + for idx, im_id in enumerate(all_img_ids): + if im_id in candidate_img_ids: + continue + obs_i = np.sum(img_observations[im_id] * unobserved_p3d_ids) + if obs_i > best_img_obs: + best_img_id = im_id + best_img_obs = obs_i + + if best_img_id >= 0: + # keep the valid img_id + candidate_img_ids.append(best_img_id) + # update the unobserved mask + unobserved_p3d_ids[img_observations[best_img_id]] = False + total_cover_ratio = 1 - np.sum(unobserved_p3d_ids) / len(p3d_id_list) + print(len(candidate_img_ids), best_img_obs, best_img_obs / len(p3d_id_list), total_cover_ratio) + + if best_img_obs / len(p3d_id_list) < 0.01: + break + + if len(candidate_img_ids) >= n_vrf: + break + else: + break + + return candidate_img_ids + # return [(v, img_observations[v]) for v in candidate_img_ids] + + if save_vrf_dir is not None: + os.makedirs(save_vrf_dir, exist_ok=True) + + seg_ref = {} + for sid in self.seg_p3d.keys(): + if sid == -1: # ignore invalid segment + continue + all_p3d_ids = self.seg_p3d[sid] + candidate_img_ids = find_best_vrf_by_covisibility(p3d_id_list=all_p3d_ids) + + seg_ref[sid] = {} + for can_idx, img_id in enumerate(candidate_img_ids): + cam = self.cameras[self.images[img_id].camera_id] + width = cam.width + height = cam.height + qvec = self.images[img_id].qvec + tvec = self.images[img_id].tvec + + img_name = self.images[img_id].name + orig_p3d_ids = [p for p in self.images[img_id].point3D_ids if p in self.points3D.keys() and p >= 0] + orig_xyzs = [] + new_xyzs = [] + for pid in all_p3d_ids: + if pid in orig_p3d_ids: + orig_xyzs.append(self.points3D[pid].xyz) + else: + if pid in self.points3D.keys(): + new_xyzs.append(self.points3D[pid].xyz) + + if len(orig_xyzs) == 0: + continue + + orig_xyzs = np.array(orig_xyzs) + new_xyzs = np.array(new_xyzs) + + print('img: ', osp.join(kwargs.get('image_root'), img_name)) + img = cv2.imread(osp.join(kwargs.get('image_root'), img_name)) + orig_kpts = reproject(img_id=img_id, xyzs=orig_xyzs) + max_depth = depth_scale * np.max(orig_kpts[:, 2]) + orig_kpts = orig_kpts[:, :2] + mask_ori = (orig_kpts[:, 0] >= 0) & (orig_kpts[:, 0] < width) & (orig_kpts[:, 1] >= 0) & ( + orig_kpts[:, 1] < height) + orig_kpts = orig_kpts[mask_ori] + + if orig_kpts.shape[0] == 0: + continue + + img_kpt = plot_kpts(img=img, kpts=orig_kpts, radius=[3 for i in range(orig_kpts.shape[0])], + colors=[(0, 0, 255) for i in range(orig_kpts.shape[0])], thickness=-1) + if 
new_xyzs.shape[0] == 0: + img_all = img_kpt + else: + new_kpts = reproject(img_id=img_id, xyzs=new_xyzs) + mask_depth = (new_kpts[:, 2] > 0) & (new_kpts[:, 2] <= max_depth) + mask_in_img = (new_kpts[:, 0] >= 0) & (new_kpts[:, 0] < width) & (new_kpts[:, 1] >= 0) & ( + new_kpts[:, 1] < height) + dist_all_orig = torch.from_numpy(new_kpts[:, :2])[..., None] - \ + torch.from_numpy(orig_kpts[:, :2].transpose())[None] + dist_all_orig = torch.sqrt(torch.sum(dist_all_orig ** 2, dim=1)) # [N, M] + min_dist = torch.min(dist_all_orig, dim=1)[0].numpy() + mask_close_to_img = (min_dist <= radius) + + mask_new = (mask_depth & mask_in_img & mask_close_to_img) + + cover_ratio = np.sum(mask_ori) + np.sum(mask_new) + cover_ratio = cover_ratio / len(all_p3d_ids) + + print('idx: {:d}, img: ori {:d}/{:d}/{:.2f}, new {:d}/{:d}'.format(can_idx, + orig_kpts.shape[0], + np.sum(mask_ori), + cover_ratio * 100, + new_kpts.shape[0], + np.sum(mask_new))) + + new_kpts = new_kpts[mask_new] + + # img_all = img_kpt + img_all = plot_kpts(img=img_kpt, kpts=new_kpts, radius=[3 for i in range(new_kpts.shape[0])], + colors=[(0, 255, 0) for i in range(new_kpts.shape[0])], thickness=-1) + + cv2.namedWindow('img', cv2.WINDOW_NORMAL) + cv2.imshow('img', img_all) + + if save_vrf_dir is not None: + cv2.imwrite(osp.join(save_vrf_dir, + 'seg-{:05d}_can-{:05d}_'.format(sid, can_idx) + img_name.replace('/', '+')), + img_all) + + key = cv2.waitKey(show_time) + if key == ord('q'): + cv2.destroyAllWindows() + exit(0) + + covisile_frame_ids = self.find_covisible_frame_ids(image_id=img_id, images=self.images, + points3D=self.points3D) + seg_ref[sid][can_idx] = { + 'image_name': img_name, + 'image_id': img_id, + 'qvec': deepcopy(qvec), + 'tvec': deepcopy(tvec), + 'camera': { + 'model': cam.model, + 'params': cam.params, + 'width': cam.width, + 'height': cam.height, + }, + 'original_points3d': np.array( + [v for v in self.images[img_id].point3D_ids if v >= 0 and v in self.points3D.keys()]), + 'covisible_frame_ids': np.array(covisile_frame_ids[:covisible_frame]), + } + # save vrf info + if save_fn is not None: + print('Save {} segments with virtual reference image information to {}'.format(len(seg_ref.keys()), + save_fn)) + np.save(save_fn, seg_ref) + + def visualize_3Dpoints(self): + xyz = [] + rgb = [] + for point3D in self.points3D.values(): + xyz.append(point3D.xyz) + rgb.append(point3D.rgb / 255) + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(xyz) + pcd.colors = o3d.utility.Vector3dVector(rgb) + o3d.visualization.draw_geometries([pcd]) + + def visualize_segmentation(self, p3d_segs, points3D): + p3d_ids = p3d_segs.keys() + xyzs = [] + rgbs = [] + for pid in p3d_ids: + xyzs.append(points3D[pid].xyz) + seg_color = self.seg_color_dict[p3d_segs[pid]] + rgbs.append(np.array([seg_color[2], seg_color[1], seg_color[0]]) / 255) + xyzs = np.array(xyzs) + rgbs = np.array(rgbs) + + self.pcd.points = o3d.utility.Vector3dVector(xyzs) + self.pcd.colors = o3d.utility.Vector3dVector(rgbs) + + o3d.visualization.draw_geometries([self.pcd]) + + def visualize_segmentation_on_image(self, p3d_segs, image_path, feat_path): + vis_color = generate_color_dic(n_seg=1024) + feat_file = h5py.File(feat_path, 'r') + + cv2.namedWindow('img', cv2.WINDOW_NORMAL) + for mi in sorted(self.images.keys()): + im = self.images[mi] + im_name = im.name + p3d_ids = im.point3D_ids + p2ds = feat_file[im_name]['keypoints'][()] + image = cv2.imread(osp.join(image_path, im_name)) + print('img_name: ', im_name) + + sems = [] + for pid in p3d_ids: + if pid in 
p3d_segs.keys(): + sems.append(p3d_segs[pid] + 1) + else: + sems.append(0) + sems = np.array(sems) + + sems = np.array(sems) + mask = sems > 0 + img_seg = vis_seg_point(img=image, kpts=p2ds[mask], segs=sems[mask], seg_color=vis_color) + + cv2.imshow('img', img_seg) + key = cv2.waitKey(0) + if key == ord('q'): + exit(0) + elif key == ord('r'): + # cv2.destroyAllWindows() + return + + def extract_query_p3ds(self, log_fn, feat_fn, save_fn=None): + if save_fn is not None: + if osp.isfile(save_fn): + print('{:s} exists'.format(save_fn)) + return + + loc_log = np.load(log_fn, allow_pickle=True)[()] + fns = loc_log.keys() + feat_file = h5py.File(feat_fn, 'r') + + out = {} + for fn in tqdm(fns, total=len(fns)): + matched_kpts = loc_log[fn]['keypoints_query'] + matched_p3ds = loc_log[fn]['points3D_ids'] + + query_kpts = feat_file[fn]['keypoints'][()].astype(float) + query_p3d_ids = np.zeros(shape=(query_kpts.shape[0],), dtype=int) - 1 + print('matched kpts: {}, query kpts: {}'.format(matched_kpts.shape[0], query_kpts.shape[0])) + + if matched_kpts.shape[0] > 0: + # [M, 2, 1] - [1, 2, N] = [M, 2, N] + dist = torch.from_numpy(matched_kpts).unsqueeze(-1) - torch.from_numpy( + query_kpts.transpose()).unsqueeze(0) + dist = torch.sum(dist ** 2, dim=1) # [M, N] + values, idxes = torch.topk(dist, dim=1, largest=False, k=1) # find the matches kpts with dist of 0 + values = values.numpy() + idxes = idxes.numpy() + for i in range(values.shape[0]): + if values[i, 0] < 1: + query_p3d_ids[idxes[i, 0]] = matched_p3ds[i] + + out[fn] = query_p3d_ids + np.save(save_fn, out) + feat_file.close() + + def compute_mean_scale_p3ds(self, min_obs=5, save_fn=None): + if save_fn is not None: + if osp.isfile(save_fn): + with open(save_fn, 'r') as f: + lines = f.readlines() + l = lines[0].strip().split() + self.mean_xyz = np.array([float(v) for v in l[:3]]) + self.scale_xyz = np.array([float(v) for v in l[3:]]) + print('{} exists'.format(save_fn)) + return + + all_xyzs = [] + for pid in self.points3D.keys(): + p3d = self.points3D[pid] + obs = len(p3d.point2D_idxs) + if obs < min_obs: + continue + all_xyzs.append(p3d.xyz) + + all_xyzs = np.array(all_xyzs) + mean_xyz = np.ceil(np.mean(all_xyzs, axis=0)) + all_xyz_ = all_xyzs - mean_xyz + + dx = np.max(abs(all_xyz_[:, 0])) + dy = np.max(abs(all_xyz_[:, 1])) + dz = np.max(abs(all_xyz_[:, 2])) + scale_xyz = np.ceil(np.array([dx, dy, dz], dtype=float).reshape(3, )) + scale_xyz[scale_xyz < 1] = 1 + scale_xyz[scale_xyz == 0] = 1 + + # self.mean_xyz = mean_xyz + # self.scale_xyz = scale_xyz + # + # if save_fn is not None: + # with open(save_fn, 'w') as f: + # text = '{:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}'.format(mean_xyz[0], mean_xyz[1], mean_xyz[2], + # scale_xyz[0], scale_xyz[1], scale_xyz[2]) + # f.write(text + '\n') + + def compute_statics_inlier(self, xyz, nb_neighbors=20, std_ratio=2.0): + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(xyz) + + new_pcd, inlier_ids = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio) + return inlier_ids + + def export_features_to_directory(self, feat_fn, save_dir, with_descriptors=True): + def print_grp_name(grp_name, object): + try: + n_subgroups = len(object.keys()) + except: + n_subgroups = 0 + dataset_list.append(object.name) + + dataset_list = [] + feat_file = h5py.File(feat_fn, 'r') + feat_file.visititems(print_grp_name) + all_keys = [] + os.makedirs(save_dir, exist_ok=True) + for fn in dataset_list: + subs = fn[1:].split('/')[:-1] # remove the first '/' + subs = '/'.join(map(str, 
subs)) + if subs in all_keys: + continue + all_keys.append(subs) + + for fn in tqdm(all_keys, total=len(all_keys)): + feat = feat_file[fn] + data = { + # 'descriptors': feat['descriptors'][()].transpose(), + 'scores': feat['scores'][()], + 'keypoints': feat['keypoints'][()], + 'image_size': feat['image_size'][()] + } + np.save(osp.join(save_dir, fn.replace('/', '+')), data) + feat_file.close() + + def get_intrinsics_from_camera(self, camera): + if camera.model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): + fx = fy = camera.params[0] + cx = camera.params[1] + cy = camera.params[2] + elif camera.model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): + fx = camera.params[0] + fy = camera.params[1] + cx = camera.params[2] + cy = camera.params[3] + else: + raise Exception("Camera model not supported") + + # intrinsics + K = np.identity(3) + K[0, 0] = fx + K[1, 1] = fy + K[0, 2] = cx + K[1, 2] = cy + return K + + def compress_map_by_projection_v2(self, vrf_path, point3d_desc_path, vrf_frames=1, covisible_frames=20, radius=20, + nkpts=-1, save_dir=None): + def sparsify_by_grid(h, w, uvs, scores): + nh = np.ceil(h / radius).astype(int) + nw = np.ceil(w / radius).astype(int) + grid = {} + for ip in range(uvs.shape[0]): + p = uvs[ip] + iw = np.rint(p[0] // radius).astype(int) + ih = np.rint(p[1] // radius).astype(int) + idx = ih * nw + iw + if idx in grid.keys(): + if scores[ip] <= grid[idx]['score']: + continue + else: + grid[idx]['score'] = scores[ip] + grid[idx]['ip'] = ip + else: + grid[idx] = { + 'score': scores[ip], + 'ip': ip + } + + retained_ips = [grid[v]['ip'] for v in grid.keys()] + retained_ips = np.array(retained_ips) + return retained_ips + + def choose_valid_p3ds(current_frame_id, covisible_frame_ids, reserved_images): + curr_p3d_ids = [] + curr_xyzs = [] + for pid in self.images[current_frame_id].point3D_ids: + if pid == -1: + continue + if pid not in self.points3D.keys(): + continue + curr_p3d_ids.append(pid) + curr_xyzs.append(self.points3D[pid].xyz) + curr_xyzs = np.array(curr_xyzs) # [N, 3] + curr_xyzs_homo = np.hstack([curr_xyzs, np.ones((curr_xyzs.shape[0], 1), dtype=curr_xyzs.dtype)]) # [N, 4] + + curr_mask = np.array([True for mi in range(curr_xyzs.shape[0])]) # keep all at first + for iim in covisible_frame_ids: + cam_id = self.images[iim].camera_id + width = self.cameras[cam_id].width + height = self.cameras[cam_id].height + qvec = self.images[iim].qvec + tcw = self.images[iim].tvec + Rcw = qvec2rotmat(qvec=qvec) + Tcw = np.eye(4, dtype=float) + Tcw[:3, :3] = Rcw + Tcw[:3, 3] = tcw.reshape(3, ) + + uvs = reserved_images[iim]['xys'] + K = self.get_intrinsics_from_camera(camera=self.cameras[cam_id]) + proj_xys = K @ (Tcw @ curr_xyzs_homo.transpose())[:3, :] # [3, ] + proj_xys = proj_xys.transpose() + depth = proj_xys[:, 2] + proj_xys[:, 0] = proj_xys[:, 0] / depth + proj_xys[:, 1] = proj_xys[:, 1] / depth + + mask_in_image = (proj_xys[:, 0] >= 0) * (proj_xys[:, 0] < width) * (proj_xys[:, 1] >= 0) * ( + proj_xys[:, 1] < height) + mask_depth = proj_xys[:, 2] > 0 + + dist_proj_uv = torch.from_numpy(proj_xys[:, :2])[..., None] - \ + torch.from_numpy(uvs[:, :2].transpose())[None] + dist_proj_uv = torch.sqrt(torch.sum(dist_proj_uv ** 2, dim=1)) # [N, M] + min_dist = torch.min(dist_proj_uv, dim=1)[0].numpy() + mask_close_to_img = (min_dist <= radius) + + mask = mask_in_image * mask_depth * mask_close_to_img # p3ds to be discarded + + curr_mask = curr_mask * (1 - mask) + + chosen_p3d_ids = [] + for mi in range(curr_mask.shape[0]): + if curr_mask[mi]: + 
chosen_p3d_ids.append(curr_p3d_ids[mi]) + + return chosen_p3d_ids + + vrf_data = np.load(vrf_path, allow_pickle=True)[()] + p3d_ids_in_vrf = [] + image_ids_in_vrf = [] + for sid in vrf_data.keys(): + svrf = vrf_data[sid] + svrf_keys = [vi for vi in range(vrf_frames)] + for vi in svrf_keys: + if vi not in svrf.keys(): + continue + image_id = svrf[vi]['image_id'] + if image_id in image_ids_in_vrf: + continue + image_ids_in_vrf.append(image_id) + for pid in svrf[vi]['original_points3d']: + if pid in p3d_ids_in_vrf: + continue + p3d_ids_in_vrf.append(pid) + + print('Find {:d} images and {:d} 3D points in vrf'.format(len(image_ids_in_vrf), len(p3d_ids_in_vrf))) + + # first_vrf_images_covis = {} + retained_image_ids = {} + for frame_id in image_ids_in_vrf: + observed = self.images[frame_id].point3D_ids + xys = self.images[frame_id].xys + covis = defaultdict(int) + valid_xys = [] + valid_p3d_ids = [] + for xy, pid in zip(xys, observed): + if pid == -1: + continue + if pid not in self.points3D.keys(): + continue + valid_xys.append(xy) + valid_p3d_ids.append(pid) + for img_id in self.points3D[pid].image_ids: + covis[img_id] += 1 + + retained_image_ids[frame_id] = { + 'xys': np.array(valid_xys), + 'p3d_ids': valid_p3d_ids, + } + + print('Find {:d} valid connected frames'.format(len(covis.keys()))) + + covis_ids = np.array(list(covis.keys())) + covis_num = np.array([covis[i] for i in covis_ids]) + + if len(covis_ids) <= covisible_frames: + sel_covis_ids = covis_ids[np.argsort(-covis_num)] + else: + ind_top = np.argpartition(covis_num, -covisible_frames) + ind_top = ind_top[-covisible_frames:] # unsorted top k + ind_top = ind_top[np.argsort(-covis_num[ind_top])] + sel_covis_ids = [covis_ids[i] for i in ind_top] + + covis_frame_ids = [frame_id] + for iim in sel_covis_ids: + if iim == frame_id: + continue + if iim in retained_image_ids.keys(): + covis_frame_ids.append(iim) + continue + + chosen_p3d_ids = choose_valid_p3ds(current_frame_id=iim, covisible_frame_ids=covis_frame_ids, + reserved_images=retained_image_ids) + if len(chosen_p3d_ids) == 0: + continue + + xys = [] + for xy, pid in zip(self.images[iim].xys, self.images[iim].point3D_ids): + if pid in chosen_p3d_ids: + xys.append(xy) + xys = np.array(xys) + + covis_frame_ids.append(iim) + retained_image_ids[iim] = { + 'xys': xys, + 'p3d_ids': chosen_p3d_ids, + } + + new_images = {} + new_point3Ds = {} + new_cameras = {} + for iim in retained_image_ids.keys(): + p3d_ids = retained_image_ids[iim]['p3d_ids'] + ''' this step reduces the performance + for v in self.images[iim].point3D_ids: + if v == -1 or v not in self.points3D: + continue + if v in p3d_ids: + continue + p3d_ids.append(v) + ''' + + xyzs = np.array([self.points3D[pid].xyz for pid in p3d_ids]) + obs = np.array([len(self.points3D[pid].point2D_idxs) for pid in p3d_ids]) + xys = self.images[iim].xys + cam_id = self.images[iim].camera_id + name = self.images[iim].name + qvec = self.images[iim].qvec + tvec = self.images[iim].tvec + + if nkpts > 0 and len(p3d_ids) > nkpts: + proj_uvs = self.reproject(img_id=iim, xyzs=xyzs) + width = self.cameras[cam_id].width + height = self.cameras[cam_id].height + sparsified_idxs = sparsify_by_grid(h=height, w=width, uvs=proj_uvs[:, :2], scores=obs) + + print('org / new kpts: ', len(p3d_ids), sparsified_idxs.shape) + + p3d_ids = [p3d_ids[k] for k in sparsified_idxs] + + new_images[iim] = Image(id=iim, qvec=qvec, tvec=tvec, + camera_id=cam_id, + name=name, + xys=np.array([]), + point3D_ids=np.array(p3d_ids)) + + if cam_id not in new_cameras.keys(): + 
new_cameras[cam_id] = self.cameras[cam_id] + + for pid in p3d_ids: + if pid in new_point3Ds.keys(): + new_point3Ds[pid]['image_ids'].append(iim) + else: + xyz = self.points3D[pid].xyz + rgb = self.points3D[pid].rgb + error = self.points3D[pid].error + + new_point3Ds[pid] = { + 'image_ids': [iim], + 'rgb': rgb, + 'xyz': xyz, + 'error': error + } + + new_point3Ds_to_save = {} + for pid in new_point3Ds.keys(): + image_ids = new_point3Ds[pid]['image_ids'] + if len(image_ids) == 0: + continue + xyz = new_point3Ds[pid]['xyz'] + rgb = new_point3Ds[pid]['rgb'] + error = new_point3Ds[pid]['error'] + + new_point3Ds_to_save[pid] = Point3D(id=pid, xyz=xyz, rgb=rgb, error=error, image_ids=np.array(image_ids), + point2D_idxs=np.array([])) + + print('Retain {:d}/{:d} images and {:d}/{:d} 3D points'.format(len(new_images), len(self.images), + len(new_point3Ds), len(self.points3D))) + + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + # write_images_binary(images=new_image_ids, + # path_to_model_file=osp.join(save_dir, 'images.bin')) + # write_points3d_binary(points3D=new_point3Ds, + # path_to_model_file=osp.join(save_dir, 'points3D.bin')) + write_compressed_images_binary(images=new_images, + path_to_model_file=osp.join(save_dir, 'images.bin')) + write_cameras_binary(cameras=new_cameras, + path_to_model_file=osp.join(save_dir, 'cameras.bin')) + write_compressed_points3d_binary(points3D=new_point3Ds_to_save, + path_to_model_file=osp.join(save_dir, 'points3D.bin')) + + # Save 3d descriptors + p3d_desc = np.load(point3d_desc_path, allow_pickle=True)[()] + comp_p3d_desc = {} + for k in new_point3Ds_to_save.keys(): + if k not in p3d_desc.keys(): + print(k) + continue + comp_p3d_desc[k] = deepcopy(p3d_desc[k]) + np.save(osp.join(save_dir, point3d_desc_path.split('/')[-1]), comp_p3d_desc) + print('Save data to {:s}'.format(save_dir)) + + +def process_dataset(dataset, dataset_dir, sfm_dir, save_dir, feature='sfd2', matcher='gml'): + # dataset_dir = '/scratches/flyer_3/fx221/dataset' + # sfm_dir = '/scratches/flyer_2/fx221/localization/outputs' # your sfm results (cameras, images, points3D) and features + # save_dir = '/scratches/flyer_3/fx221/exp/localizer' + # local_feat = 'sfd2' + # matcher = 'gml' + # hloc_results_dir = '/scratches/flyer_2/fx221/exp/sgd2' + + # config_path = 'configs/datasets/CUED.yaml' + # config_path = 'configs/datasets/7Scenes.yaml' + # config_path = 'configs/datasets/12Scenes.yaml' + # config_path = 'configs/datasets/CambridgeLandmarks.yaml' + # config_path = 'configs/datasets/Aachen.yaml' + + # config_path = 'configs/datasets/Aria.yaml' + # config_path = 'configs/datasets/DarwinRGB.yaml' + # config_path = 'configs/datasets/ACUED.yaml' + # config_path = 'configs/datasets/JesusCollege.yaml' + # config_path = 'configs/datasets/CUED2Kings.yaml' + + config_path = 'configs/datasets/{:s}.yaml'.format(dataset) + with open(config_path, 'rt') as f: + configs = yaml.load(f, Loader=yaml.Loader) + print(configs) + + dataset = configs['dataset'] + all_scenes = configs['scenes'] + for scene in all_scenes: + n_cluster = configs[scene]['n_cluster'] + cluster_mode = configs[scene]['cluster_mode'] + cluster_method = configs[scene]['cluster_method'] + # if scene not in ['heads']: + # continue + + print('scene: ', scene, cluster_mode, cluster_method) + # hloc_path = osp.join(hloc_root, dataset, scene) + sfm_path = osp.join(sfm_dir, scene) + save_path = osp.join(save_dir, feature + '-' + matcher, dataset, scene) + + n_vrf = 1 + n_cov = 30 + radius = 20 + n_kpts = 0 + + if dataset in ['Aachen']: 
+ image_path = osp.join(dataset_dir, scene, 'images/images_upright') + min_obs = 250 + filtering_outliers = True + threshold = 0.2 + radius = 32 + + elif dataset in ['CambridgeLandmarks', ]: + image_path = osp.join(dataset_dir, scene) + min_obs = 250 + filtering_outliers = True + threshold = 0.2 + radius = 64 + elif dataset in ['Aria']: + image_path = osp.join(dataset_dir, scene) + min_obs = 150 + filtering_outliers = False + threshold = 0.01 + radius = 15 + elif dataset in ['DarwinRGB']: + image_path = osp.join(dataset_dir, scene) + min_obs = 150 + filtering_outliers = True + threshold = 0.2 + radius = 16 + elif dataset in ['ACUED']: + image_path = osp.join(dataset_dir, scene) + min_obs = 250 + filtering_outliers = True + threshold = 0.2 + radius = 32 + elif dataset in ['7Scenes', '12Scenes']: + image_path = osp.join(dataset_dir, scene) + min_obs = 150 + filtering_outliers = False + threshold = 0.01 + radius = 15 + else: + image_path = osp.join(dataset_dir, scene) + min_obs = 250 + filtering_outliers = True + threshold = 0.2 + radius = 32 + + # comp_map_sub_path = 'comp_model_n{:d}_{:s}_{:s}_vrf{:d}_cov{:d}_r{:d}_np{:d}_projection_v2'.format(n_cluster, + # cluster_mode, + # cluster_method, + # n_vrf, + # n_cov, + # radius, + # n_kpts) + comp_map_sub_path = 'compress_model_{:s}'.format(cluster_method) + seg_fn = osp.join(save_path, + 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_cluster, cluster_mode, cluster_method)) + vrf_fn = osp.join(save_path, + 'point3D_vrf_n{:d}_{:s}_{:s}.npy'.format(n_cluster, cluster_mode, cluster_method)) + vrf_img_dir = osp.join(save_path, + 'point3D_vrf_n{:d}_{:s}_{:s}'.format(n_cluster, cluster_mode, cluster_method)) + # p3d_query_fn = osp.join(save_path, + # 'point3D_query_n{:d}_{:s}_{:s}.npy'.format(n_cluster, cluster_mode, cluster_method)) + comp_map_path = osp.join(save_path, comp_map_sub_path) + + os.makedirs(save_path, exist_ok=True) + + rmap = RecMap() + rmap.load_sfm_model(path=osp.join(sfm_path, 'sfm_{:s}-{:s}'.format(feature, matcher))) + if filtering_outliers: + rmap.remove_statics_outlier(nb_neighbors=20, std_ratio=2.0) + + # extract keypoints to train the recognition model (descriptors are recomputed from augmented db images) + # we do this for ddp training (reading h5py file is not supported) + rmap.export_features_to_directory(feat_fn=osp.join(sfm_path, 'feats-{:s}.h5'.format(feature)), + save_dir=osp.join(save_path, 'feats')) # only once for training + + rmap.cluster(k=n_cluster, mode=cluster_mode, save_fn=seg_fn, method=cluster_method, threshold=threshold) + # rmap.visualize_3Dpoints() + rmap.load_segmentation(path=seg_fn) + # rmap.visualize_segmentation(p3d_segs=rmap.p3d_seg, points3D=rmap.points3D) + + # Assign each 3D point a desciptor and discard all 2D images and descriptors - for localization + rmap.assign_point3D_descriptor( + feature_fn=osp.join(sfm_path, 'feats-{:s}.h5'.format(feature)), + save_fn=osp.join(save_path, 'point3D_desc.npy'.format(n_cluster, cluster_mode)), + n_process=32) # only once + + # exit(0) + # rmap.visualize_segmentation_on_image(p3d_segs=rmap.p3d_seg, image_path=image_path, feat_path=feat_path) + + # for query images only - for evaluation + # rmap.extract_query_p3ds( + # log_fn=osp.join(hloc_path, 'hloc_feats-{:s}_{:s}_loc.npy'.format(local_feat, matcher)), + # feat_fn=osp.join(sfm_path, 'feats-{:s}.h5'.format(local_feat)), + # save_fn=p3d_query_fn) + # continue + + # up-to-date + rmap.create_virtual_frame_3( + save_fn=vrf_fn, + save_vrf_dir=vrf_img_dir, + image_root=image_path, + show_time=5, + 
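+            # Note: this call appears to generate up to 10 candidate reference frames per segment
+            # (n_vrf=10 below), whereas compress_map_by_projection_v2 afterwards only reads the
+            # first `n_vrf` (=1) of them via its `vrf_frames` argument.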
min_cover_ratio=0.9, + radius=radius, + depth_scale=2.5, # 1.2 by default + min_obs=min_obs, + n_vrf=10, + covisible_frame=n_cov, + ignored_cameras=[]) + + # up-to-date + rmap.compress_map_by_projection_v2( + vrf_frames=n_vrf, + vrf_path=vrf_fn, + point3d_desc_path=osp.join(save_path, 'point3D_desc.npy'), + save_dir=comp_map_path, + covisible_frames=n_cov, + radius=radius, + nkpts=n_kpts, + ) + + # exit(0) + # soft_link_compress_path = osp.join(save_path, 'compress_model_{:s}'.format(cluster_method)) + os.chdir(save_path) + # if osp.isdir(soft_link_compress_path): + # os.unlink(soft_link_compress_path) + # os.symlink(comp_map_sub_path, 'compress_model_{:s}'.format(cluster_method)) + # create a soft link of the full model for training + if not osp.isdir('model'): + os.symlink(osp.join(sfm_path, 'sfm_{:s}-{:s}'.format(feature, matcher)), '3D-models') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, required=True, help='dataset name') + parser.add_argument('--dataset_dir', type=str, required=True, help='dataset dir') + parser.add_argument('--sfm_dir', type=str, required=True, help='sfm dir') + parser.add_argument('--save_dir', type=str, required=True, help='dir to save the landmarks data') + parser.add_argument('--feature', type=str, default='sfd2', help='feature name e.g., SP, SFD2') + parser.add_argument('--matcher', type=str, default='gml', help='matcher name e.g., SG, LSG, gml') + + args = parser.parse_args() + + process_dataset( + dataset=args.dataset, + dataset_dir=args.dataset_dir, + sfm_dir=args.sfm_dir, + save_dir=args.save_dir, + feature=args.feature, + matcher=args.matcher) diff --git a/imcui/third_party/pram/recognition/vis_seg.py b/imcui/third_party/pram/recognition/vis_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef9b2365787e5921a66c74ff6c0b5ec3e49a31a --- /dev/null +++ b/imcui/third_party/pram/recognition/vis_seg.py @@ -0,0 +1,225 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> vis_seg +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 11:06 +==================================================''' +import cv2 +import numpy as np +from copy import deepcopy + + +def myHash(text: str): + hash = 0 + for ch in text: + hash = (hash * 7879 ^ ord(ch) * 5737) & 0xFFFFFFFF + return hash + + +def generate_color_dic(n_seg=1000): + out = {} + for i in range(n_seg + 1): + sid = i + if sid == 0: + color = (0, 0, 255) # [b, g, r] + else: + # rgb_new = hash(str(sid * 319993)) + rgb_new = myHash(str(sid * 319993)) + r = (rgb_new & 0xFF0000) >> 16 + g = (rgb_new & 0x00FF00) >> 8 + b = rgb_new & 0x0000FF + color = (b, g, r) + out[i] = color + return out + + +def vis_seg_point(img, kpts, segs=None, seg_color=None, radius=7, thickness=-1): + outimg = deepcopy(img) + for i in range(kpts.shape[0]): + # print(kpts[i]) + if segs is not None and seg_color is not None: + color = seg_color[segs[i]] + else: + color = (0, 255, 0) + outimg = cv2.circle(outimg, + center=(int(kpts[i, 0]), int(kpts[i, 1])), + color=color, + radius=radius, + thickness=thickness, ) + + return outimg + + +def vis_corr_incorr_point(img, kpts, pred_segs, gt_segs, radius=7, thickness=-1): + outimg = deepcopy(img) + for i in range(kpts.shape[0]): + # print(kpts[i]) + p_seg = pred_segs[i] + g_seg = gt_segs[i] + if p_seg == g_seg: + if g_seg != 0: + color = (0, 255, 0) + else: + color = (255, 0, 0) + else: + color = (0, 0, 255) + outimg = cv2.circle(outimg, + center=(int(kpts[i, 
0]), int(kpts[i, 1])), + color=color, + radius=radius, + thickness=thickness, ) + return outimg + + +def vis_inlier(img, kpts, inliers, radius=7, thickness=1, with_outlier=True): + outimg = deepcopy(img) + for i in range(kpts.shape[0]): + if not with_outlier: + if not inliers[i]: + continue + if inliers[i]: + color = (0, 255, 0) + else: + color = (0, 0, 255) + outimg = cv2.rectangle(outimg, + pt1=(int(kpts[i, 0] - radius), int(kpts[i, 1] - radius)), + pt2=(int(kpts[i, 0] + radius), int(kpts[i, 1] + radius)), + color=color, + thickness=thickness, ) + + return outimg + + +def vis_global_seg(cls, seg_color, radius=7, thickness=-1): + all_patches = [] + for i in range(cls.shape[0]): + if cls[i] == 0: + continue + color = seg_color[i] + patch = np.zeros(shape=(radius, radius, 3), dtype=np.uint8) + patch[..., 0] = color[0] + patch[..., 1] = color[1] + patch[..., 2] = color[2] + + all_patches.append(patch) + if len(all_patches) == 0: + color = seg_color[0] + patch = np.zeros(shape=(radius, radius, 3), dtype=np.uint8) + patch[..., 0] = color[0] + patch[..., 1] = color[1] + patch[..., 2] = color[2] + all_patches.append(patch) + return np.vstack(all_patches) + + +def plot_matches(img1, img2, pts1, pts2, inliers, radius=3, line_thickness=2, horizon=True, plot_outlier=False, + confs=None): + rows1 = img1.shape[0] + cols1 = img1.shape[1] + rows2 = img2.shape[0] + cols2 = img2.shape[1] + # r = 3 + if horizon: + img_out = np.zeros((max([rows1, rows2]), cols1 + cols2, 3), dtype='uint8') + # Place the first image to the left + img_out[:rows1, :cols1] = img1 + # Place the next image to the right of it + img_out[:rows2, cols1:] = img2 # np.dstack([img2, img2, img2]) + for idx in range(inliers.shape[0]): + # if idx % 10 > 0: + # continue + if inliers[idx]: + color = (0, 255, 0) + else: + if not plot_outlier: + continue + color = (0, 0, 255) + pt1 = pts1[idx] + pt2 = pts2[idx] + + if confs is not None: + nr = int(radius * confs[idx]) + else: + nr = radius + img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2) + + img_out = cv2.circle(img_out, (int(pt2[0]) + cols1, int(pt2[1])), nr, color, 2) + + img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]) + cols1, int(pt2[1])), color, + line_thickness) + else: + img_out = np.zeros((rows1 + rows2, max([cols1, cols2]), 3), dtype='uint8') + # Place the first image to the left + img_out[:rows1, :cols1] = img1 + # Place the next image to the right of it + img_out[rows1:, :cols2] = img2 # np.dstack([img2, img2, img2]) + + for idx in range(inliers.shape[0]): + # print("idx: ", inliers[idx]) + # if idx % 10 > 0: + # continue + if inliers[idx]: + color = (0, 255, 0) + else: + if not plot_outlier: + continue + color = (0, 0, 255) + + if confs is not None: + nr = int(radius * confs[idx]) + else: + nr = radius + + pt1 = pts1[idx] + pt2 = pts2[idx] + img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2) + + img_out = cv2.circle(img_out, (int(pt2[0]), int(pt2[1]) + rows1), nr, color, 2) + + img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1]) + rows1), color, + line_thickness) + + return img_out + + +def plot_kpts(img, kpts, radius=None, colors=None, r=3, color=(0, 0, 255), nh=-1, nw=-1, shape='o', show_text=None, + thickness=5): + img_out = deepcopy(img) + for i in range(kpts.shape[0]): + pt = kpts[i] + if radius is not None: + if shape == 'o': + img_out = cv2.circle(img_out, center=(int(pt[0]), int(pt[1])), radius=radius[i], + color=color if colors is None else colors[i], + thickness=thickness) + 
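+            # Note: the '+' branch below hard-codes thickness=5 for the horizontal stroke and uses
+            # the `color` argument (ignoring per-point `colors`) for the vertical stroke.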
elif shape == '+': + img_out = cv2.line(img_out, pt1=(int(pt[0] - radius[i]), int(pt[1])), + pt2=(int(pt[0] + radius[i]), int(pt[1])), + color=color if colors is None else colors[i], + thickness=5) + img_out = cv2.line(img_out, pt1=(int(pt[0]), int(pt[1] - radius[i])), + pt2=(int(pt[0]), int(pt[1] + radius[i])), color=color, + thickness=thickness) + else: + if shape == 'o': + img_out = cv2.circle(img_out, center=(int(pt[0]), int(pt[1])), radius=r, + color=color if colors is None else colors[i], + thickness=thickness) + elif shape == '+': + img_out = cv2.line(img_out, pt1=(int(pt[0] - r), int(pt[1])), + pt2=(int(pt[0] + r), int(pt[1])), color=color if colors is None else colors[i], + thickness=thickness) + img_out = cv2.line(img_out, pt1=(int(pt[0]), int(pt[1] - r)), + pt2=(int(pt[0]), int(pt[1] + r)), color=color if colors is None else colors[i], + thickness=thickness) + + if show_text is not None: + img_out = cv2.putText(img_out, show_text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, + (0, 0, 255), 3) + if nh == -1 and nw == -1: + return img_out + if nh > 0: + return cv2.resize(img_out, dsize=(int(img.shape[1] / img.shape[0] * nh), nh)) + if nw > 0: + return cv2.resize(img_out, dsize=(nw, int(img.shape[0] / img.shape[1] * nw))) diff --git a/imcui/third_party/pram/tools/common.py b/imcui/third_party/pram/tools/common.py new file mode 100644 index 0000000000000000000000000000000000000000..8990012575324ed593ebc07bec88d47602005d5f --- /dev/null +++ b/imcui/third_party/pram/tools/common.py @@ -0,0 +1,125 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> common +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 15:05 +==================================================''' +import os +import torch +import json +import yaml +import cv2 +import numpy as np +from typing import Tuple +from copy import deepcopy + + +def load_args(args, save_path): + with open(save_path, "r") as f: + args.__dict__ = json.load(f) + + +def save_args_yaml(args, save_path): + with open(save_path, 'w') as f: + yaml.dump(args, f) + + +def merge_tags(tags: list, connection='_'): + out = '' + for i, t in enumerate(tags): + if i == 0: + out = out + t + else: + out = out + connection + t + return out + + +def torch_set_gpu(gpus): + if type(gpus) is int: + gpus = [gpus] + + cuda = all(gpu >= 0 for gpu in gpus) + + if cuda: + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus]) + # print(os.environ['CUDA_VISIBLE_DEVICES']) + assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % ( + os.environ['HOSTNAME'], os.environ['CUDA_VISIBLE_DEVICES']) + torch.backends.cudnn.benchmark = True # speed-up cudnn + torch.backends.cudnn.fastest = True # even more speed-up? 
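+        # Note: cudnn.benchmark lets cuDNN auto-tune convolution algorithms, which is faster when
+        # input sizes are fixed but makes runs slightly non-deterministic; `cudnn.fastest` is not
+        # a documented setting in recent PyTorch releases and may have no effect.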
+ print('Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES']) + + else: + print('Launching on CPU') + + return cuda + + +def resize_img(img, nh=-1, nw=-1, rmax=-1, mode=cv2.INTER_NEAREST): + assert nh > 0 or nw > 0 or rmax > 0 + if nh > 0: + return cv2.resize(img, dsize=(int(img.shape[1] / img.shape[0] * nh), nh), interpolation=mode) + if nw > 0: + return cv2.resize(img, dsize=(nw, int(img.shape[0] / img.shape[1] * nw)), interpolation=mode) + if rmax > 0: + oh, ow = img.shape[0], img.shape[1] + if oh > ow: + return cv2.resize(img, dsize=(int(img.shape[1] / img.shape[0] * rmax), rmax), interpolation=mode) + else: + return cv2.resize(img, dsize=(rmax, int(img.shape[0] / img.shape[1] * rmax)), interpolation=mode) + + return cv2.resize(img, dsize=(nw, nh), interpolation=mode) + + +def resize_image_with_padding(image: np.array, nw: int, nh: int, padding_color: Tuple[int] = (0, 0, 0)) -> np.array: + """Maintains aspect ratio and resizes with padding. + Params: + image: Image to be resized. + new_shape: Expected (width, height) of new image. + padding_color: Tuple in BGR of padding color + Returns: + image: Resized image with padding + """ + original_shape = (image.shape[1], image.shape[0]) # (w, h) + ratio_w = nw / original_shape[0] + ratio_h = nh / original_shape[1] + + if ratio_w == ratio_h: + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_NEAREST) + + ratio = ratio_w if ratio_w < ratio_h else ratio_h + + new_size = tuple([int(x * ratio) for x in original_shape]) + image = cv2.resize(image, new_size, interpolation=cv2.INTER_NEAREST) + delta_w = nw - new_size[0] if nw > new_size[0] else new_size[0] - nw + delta_h = nh - new_size[1] if nh > new_size[1] else new_size[1] - nh + + left, right = delta_w // 2, delta_w - (delta_w // 2) + top, bottom = delta_h // 2, delta_h - (delta_h // 2) + + # print('top, bottom, left, right: ', top, bottom, left, right) + image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padding_color) + return image + + +def puttext_with_background(image, text, org=(0, 0), fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=1, text_color=(0, 0, 255), + thickness=2, lineType=cv2.LINE_AA, bg_color=None): + out_img = deepcopy(image) + if bg_color is not None: + (text_width, text_height), baseline = cv2.getTextSize(text, + fontFace, + fontScale=fontScale, + thickness=thickness) + box_coords = ( + (org[0], org[1] + baseline), + (org[0] + text_width + 2, org[1] - text_height - 2)) + + cv2.rectangle(out_img, box_coords[0], box_coords[1], bg_color, cv2.FILLED) + out_img = cv2.putText(img=out_img, text=text, + org=org, + fontFace=fontFace, + fontScale=fontScale, color=text_color, + thickness=thickness, lineType=lineType) + return out_img diff --git a/imcui/third_party/pram/tools/geometry.py b/imcui/third_party/pram/tools/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..d781a4172dd7f6ad8a4a26e252f614483ebd01e3 --- /dev/null +++ b/imcui/third_party/pram/tools/geometry.py @@ -0,0 +1,74 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> geometry +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/02/2024 11:08 +==================================================''' +import numpy as np + + +def nms_fast(in_corners, H, W, dist_thresh): + """ + Run a faster approximate Non-Max-Suppression on numpy corners shaped: + 3xN [x_i,y_i,conf_i]^T + + Algo summary: Create a grid sized HxW. Assign each corner location a 1, rest + are zeros. 
Iterate through all the 1's and convert them either to -1 or 0. + Suppress points by setting nearby values to 0. + + Grid Value Legend: + -1 : Kept. + 0 : Empty or suppressed. + 1 : To be processed (converted to either kept or supressed). + + NOTE: The NMS first rounds points to integers, so NMS distance might not + be exactly dist_thresh. It also assumes points are within image boundaries. + + Inputs + in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T. + H - Image height. + W - Image width. + dist_thresh - Distance to suppress, measured as an infinty norm distance. + Returns + nmsed_corners - 3xN numpy matrix with surviving corners. + nmsed_inds - N length numpy vector with surviving corner indices. + """ + grid = np.zeros((H, W)).astype(int) # Track NMS data. + inds = np.zeros((H, W)).astype(int) # Store indices of points. + # Sort by confidence and round to nearest int. + inds1 = np.argsort(-in_corners[2, :]) + corners = in_corners[:, inds1] + rcorners = corners[:2, :].round().astype(int) # Rounded corners. + # Check for edge case of 0 or 1 corners. + if rcorners.shape[1] == 0: + return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int) + if rcorners.shape[1] == 1: + out = np.vstack((rcorners, in_corners[2])).reshape(3, 1) + return out, np.zeros((1)).astype(int) + # Initialize the grid. + for i, rc in enumerate(rcorners.T): + grid[rcorners[1, i], rcorners[0, i]] = 1 + inds[rcorners[1, i], rcorners[0, i]] = i + # Pad the border of the grid, so that we can NMS points near the border. + pad = dist_thresh + grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant') + # Iterate through points, highest to lowest conf, suppress neighborhood. + count = 0 + for i, rc in enumerate(rcorners.T): + # Account for top and left padding. + pt = (rc[0] + pad, rc[1] + pad) + if grid[pt[1], pt[0]] == 1: # If not yet suppressed. + grid[pt[1] - pad:pt[1] + pad + 1, pt[0] - pad:pt[0] + pad + 1] = 0 + grid[pt[1], pt[0]] = -1 + count += 1 + # Get all surviving -1's and return sorted array of remaining corners. 
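+    # Note: unlike the 0/1-corner early exits above, which return a (corners, indices) tuple,
+    # the code below returns only the surviving corner indices.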
+ keepy, keepx = np.where(grid == -1) + keepy, keepx = keepy - pad, keepx - pad + inds_keep = inds[keepy, keepx] + out = corners[:, inds_keep] + values = out[-1, :] + inds2 = np.argsort(-values) + out = out[:, inds2] + out_inds = inds1[inds_keep[inds2]] + return out_inds diff --git a/imcui/third_party/pram/tools/image_to_video.py b/imcui/third_party/pram/tools/image_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f281fd2cf0ef5eb2752117610c042b8764f5f1 --- /dev/null +++ b/imcui/third_party/pram/tools/image_to_video.py @@ -0,0 +1,66 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File localizer -> image_to_video +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 07/09/2023 20:15 +==================================================''' +import cv2 +import os +import os.path as osp + +import numpy as np +from tqdm import tqdm +import argparse + +from tools.common import resize_img + +parser = argparse.ArgumentParser(description='Image2Video', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--image_dir', type=str, required=True) +parser.add_argument('--video_path', type=str, required=True) +parser.add_argument('--height', type=int, default=-1) +parser.add_argument('--fps', type=int, default=30) + + +def imgs2video(img_dir, video_path, fps=30, height=1024): + img_fns = os.listdir(img_dir) + # print(img_fns) + img_fns = [v for v in img_fns if v.split('.')[-1] in ['jpg', 'png']] + img_fns = sorted(img_fns) + # print(img_fns) + # 输出视频路径 + # fps = 1 + + img = cv2.imread(osp.join(img_dir, img_fns[0])) + if height == -1: + height = img.shape[1] + new_img = resize_img(img=img, nh=height) + img_size = (new_img.shape[1], height) + + # fourcc = cv2.cv.CV_FOURCC('M','J','P','G')#opencv2.4 + # fourcc = cv2.VideoWriter_fourcc('I','4','2','0') + + fourcc = cv2.VideoWriter_fourcc(*'MP4V') # 设置输出视频为mp4格式 + # fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V') # 设置输出视频为mp4格式 + videoWriter = cv2.VideoWriter(video_path, fourcc, fps, img_size) + + for i in tqdm(range(3700, len(img_fns)), total=len(img_fns)): + # fn = img_fns[i].split('-') + im_name = os.path.join(img_dir, img_fns[i]) + print(im_name) + frame = cv2.imread(im_name, 1) + frame = np.flip(frame, 0) + + frame = cv2.resize(frame, dsize=img_size) + # print(frame.shape) + # exit(0) + cv2.imshow("frame", frame) + cv2.waitKey(1) + videoWriter.write(frame) + + videoWriter.release() + + +if __name__ == '__main__': + args = parser.parse_args() + imgs2video(img_dir=args.image_dir, video_path=args.video_path, fps=args.fps, height=args.height) diff --git a/imcui/third_party/pram/tools/metrics.py b/imcui/third_party/pram/tools/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..22e14374931fa9ba4151632b65b41c65d6ba55f7 --- /dev/null +++ b/imcui/third_party/pram/tools/metrics.py @@ -0,0 +1,216 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> metrics +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 16:32 +==================================================''' +import torch +import numpy as np +import torch.nn.functional as F + + +class SeqIOU: + def __init__(self, n_class, ignored_sids=[]): + self.n_class = n_class + self.ignored_sids = ignored_sids + self.class_iou = np.zeros(n_class) + self.precisions = [] + + def add(self, pred, target): + for i in range(self.n_class): + inter = np.sum((pred == target) * (target == i)) + union = np.sum(target == i) + np.sum(pred == i) - 
inter + if union > 0: + self.class_iou[i] = inter / union + + acc = (pred == target) + if len(self.ignored_sids) == 0: + acc_ratio = np.sum(acc) / pred.shape[0] + else: + pred_mask = (pred >= 0) + target_mask = (target >= 0) + for i in self.ignored_sids: + pred_mask = pred_mask & (pred == i) + target_mask = target_mask & (target == i) + + acc = acc & (1 - pred_mask) + tgt = (1 - target_mask) + if np.sum(tgt) == 0: + acc_ratio = 0 + else: + acc_ratio = np.sum(acc) / np.sum(tgt) + + self.precisions.append(acc_ratio) + + def get_mean_iou(self): + return np.mean(self.class_iou) + + def get_mean_precision(self): + return np.mean(self.precisions) + + def clear(self): + self.precisions = [] + self.class_iou = np.zeros(self.n_class) + + +def compute_iou(pred: np.ndarray, target: np.ndarray, n_class: int, ignored_ids=[]) -> float: + class_iou = np.zeros(n_class) + for i in range(n_class): + if i in ignored_ids: + continue + inter = np.sum((pred == target) * (target == i)) + union = np.sum(target == i) + np.sum(pred == i) - inter + if union > 0: + class_iou[i] = inter / union + + return np.mean(class_iou) + # return class_iou + + +def compute_precision(pred: np.ndarray, target: np.ndarray, ignored_ids: list = []) -> float: + acc = (pred == target) + if len(ignored_ids) == 0: + return np.sum(acc) / pred.shape[0] + else: + pred_mask = (pred >= 0) + target_mask = (target >= 0) + for i in ignored_ids: + pred_mask = pred_mask & (pred == i) + target_mask = target_mask & (target == i) + + acc = acc & (1 - pred_mask) + tgt = (1 - target_mask) + if np.sum(tgt) == 0: + return 0 + return np.sum(acc) / np.sum(tgt) + + +def compute_cls_corr(pred: torch.Tensor, target: torch.Tensor, k: int = 20) -> torch.Tensor: + bs = pred.shape[0] + _, target_ids = torch.topk(target, k=k, dim=1) + target_ids = target_ids.cpu().numpy() + _, top_ids = torch.topk(pred, k=k, dim=1) # [B, k, 1] + top_ids = top_ids.cpu().numpy() + acc = 0 + for i in range(bs): + # print('top_ids: ', i, top_ids[i], target_ids[i]) + overlap = [v for v in top_ids[i] if v in target_ids[i] and v >= 0] + acc = acc + len(overlap) / k + acc = acc / bs + return torch.from_numpy(np.array([acc])).to(pred.device) + + +def compute_corr_incorr(pred: torch.Tensor, target: torch.Tensor, ignored_ids: list = []) -> tuple: + ''' + :param pred: [B, N, C] + :param target: [B, N] + :param ignored_ids: [] + :return: + ''' + pred_ids = torch.max(pred, dim=-1)[1] + if len(ignored_ids) == 0: + acc = (pred_ids == target) + inacc = torch.logical_not(acc) + acc_ratio = torch.sum(acc) / torch.numel(target) + inacc_ratio = torch.sum(inacc) / torch.numel(target) + else: + acc = (pred_ids == target) + inacc = torch.logical_not(acc) + + mask = torch.zeros_like(acc) + for i in ignored_ids: + mask = torch.logical_and(mask, (target == i)) + + acc = torch.logical_and(acc, torch.logical_not(mask)) + acc_ratio = torch.sum(acc) / torch.numel(target) + inacc_ratio = torch.sum(inacc) / torch.numel(target) + + return acc_ratio, inacc_ratio + + +def compute_seg_loss_weight(pred: torch.Tensor, + target: torch.Tensor, + background_id: int = 0, + weight_background: float = 0.1) -> torch.Tensor: + ''' + :param pred: [B, C, N] + :param target: [B, N] + :param background_id: + :param weight_background: + :return: + ''' + pred = pred.transpose(-2, -1).contiguous() # [B, N, C] -> [B, C, N] + weight = torch.ones(size=(pred.shape[1],), device=pred.device).float() + pred = torch.log_softmax(pred, dim=1) + weight[background_id] = weight_background + seg_loss = F.cross_entropy(pred, target.long(), 
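+    # Note: `pred` already holds log-probabilities at this point; F.cross_entropy applies
+    # log_softmax again internally, which is redundant but leaves the loss unchanged
+    # (log_softmax is idempotent along the class dimension).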
weight=weight) + return seg_loss + + +def compute_cls_loss_ce(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + cls_loss = torch.zeros(size=[], device=pred.device) + if len(pred.shape) == 2: + n_valid = torch.sum(target > 0) + cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred, target, reduction='sum') + cls_loss = cls_loss / n_valid + else: + for i in range(pred.shape[-1]): + cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred[..., i], target[..., i], reduction='sum') + n_valid = torch.sum(target > 0) + cls_loss = cls_loss / n_valid + + return cls_loss + + +def compute_cls_loss_kl(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + cls_loss = torch.zeros(size=[], device=pred.device) + if len(pred.shape) == 2: + cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred, dim=-1), + torch.softmax(target, dim=-1), + reduction='sum') + else: + for i in range(pred.shape[-1]): + cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred[..., i], dim=-1), + torch.softmax(target[..., i], dim=-1), + reduction='sum') + + cls_loss = cls_loss / pred.shape[-1] + + return cls_loss + + +def compute_sc_loss_l1(pred: torch.Tensor, target: torch.Tensor, mean_xyz=None, scale_xyz=None, mask=None): + ''' + :param pred: [B, N, C] + :param target: [B, N, C] + :param mean_xyz: + :param scale_xyz: + :param mask: + :return: + ''' + loss = (pred - target) + loss = torch.abs(loss).mean(dim=1) + if mask is not None: + return torch.mean(loss[mask]) + else: + return torch.mean(loss) + + +def compute_sc_loss_geo(pred: torch.Tensor, P, K, p2ds, mean_xyz, scale_xyz, max_value=20, mask=None): + b, c, n = pred.shape + p3ds = (pred * scale_xyz[..., None].repeat(1, 1, n) + mean_xyz[..., None].repeat(1, 1, n)) + p3ds_homo = torch.cat( + [pred, torch.ones(size=(p3ds.shape[0], 1, p3ds.shape[2]), dtype=p3ds.dtype, device=p3ds.device)], + dim=1) # [B, 4, N] + p3ds = torch.matmul(K, torch.matmul(P, p3ds_homo)[:, :3, :]) # [B, 3, N] + # print('p3ds: ', p3ds.shape, P.shape, K.shape, p2ds.shape) + + p2ds_ = p3ds[:, :2, :] / p3ds[:, 2:, :] + + loss = ((p2ds_ - p2ds.permute(0, 2, 1)) ** 2).sum(1) + loss = torch.clamp_max(loss, max=max_value) + if mask is not None: + return torch.mean(loss[mask]) + else: + return torch.mean(loss) diff --git a/imcui/third_party/pram/tools/video_to_image.py b/imcui/third_party/pram/tools/video_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..7283f3ba24d432410ea326a7d9aedbe011b60ed2 --- /dev/null +++ b/imcui/third_party/pram/tools/video_to_image.py @@ -0,0 +1,38 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File localizer -> video_to_image +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 13/01/2024 15:29 +==================================================''' +import argparse +import os +import os.path as osp +import cv2 + +parser = argparse.ArgumentParser(description='Image2Video', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--image_path', type=str, required=True) +parser.add_argument('--video_path', type=str, required=True) +parser.add_argument('--height', type=int, default=-1) +parser.add_argument('--sample_ratio', type=int, default=-1) + + +def main(args): + video = cv2.VideoCapture(args.video_path) + nframe = 0 + while True: + ret, frame = video.read() + if ret: + if args.sample_ratio > 0: + if nframe % args.sample_ratio != 0: + nframe += 1 + continue + cv2.imwrite(osp.join(args.image_path, '{:06d}.png'.format(nframe)), frame) + 
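+            # Note: args.image_path must already exist; cv2.imwrite silently returns False
+            # (and writes nothing) if the directory is missing.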
nframe += 1 + else: + break + + +if __name__ == '__main__': + args = parser.parse_args() + main(args=args) diff --git a/imcui/third_party/pram/tools/visualize_landmarks.py b/imcui/third_party/pram/tools/visualize_landmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..7f8bcba35c14b929de1159c3a9491a98e1f0aebb --- /dev/null +++ b/imcui/third_party/pram/tools/visualize_landmarks.py @@ -0,0 +1,171 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> visualize_landmarks +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 22/03/2024 10:39 +==================================================''' +import os +import os.path as osp +import numpy as np +from tqdm import tqdm +from colmap_utils.read_write_model import read_model, write_model, Point3D, Image, read_compressed_model +from recognition.vis_seg import generate_color_dic + + +def reconstruct_map(valid_image_ids, valid_p3d_ids, cameras, images, point3Ds, p3d_seg: dict): + new_point3Ds = {} + new_images = {} + + valid_p3d_ids_ = [] + for pid in tqdm(valid_p3d_ids, total=len(valid_p3d_ids)): + + if pid == -1: + continue + if pid not in point3Ds.keys(): + continue + + if pid not in p3d_seg.keys(): + continue + + sid = map_seg[pid] + if sid == -1: + continue + valid_p3d_ids_.append(pid) + + valid_p3d_ids = valid_p3d_ids_ + print('valid_p3ds: ', len(valid_p3d_ids)) + + # for im_id in tqdm(images.keys(), total=len(images.keys())): + for im_id in tqdm(valid_image_ids, total=len(valid_image_ids)): + im = images[im_id] + # print('im: ', im) + # exit(0) + pids = im.point3D_ids + valid_pids = [] + # for v in pids: + # if v not in valid_p3d_ids: + # valid_pids.append(-1) + # else: + # valid_pids.append(v) + + new_im = Image(id=im_id, qvec=im.qvec, tvec=im.tvec, camera_id=im.camera_id, name=im.name, xys=im.xys, + point3D_ids=pids) + new_images[im_id] = new_im + + for pid in tqdm(valid_p3d_ids, total=len(valid_p3d_ids)): + sid = map_seg[pid] + + xyz = points3D[pid].xyz + if show_2D: + xyz[1] = 0 + rgb = points3D[pid].rgb + else: + bgr = seg_color[sid + sid_start] + rgb = np.array([bgr[2], bgr[1], bgr[0]]) + + error = points3D[pid].error + + p3d = Point3D(id=pid, xyz=xyz, rgb=rgb, error=error, + image_ids=points3D[pid].image_ids, + point2D_idxs=points3D[pid].point2D_idxs) + new_point3Ds[pid] = p3d + + return cameras, new_images, new_point3Ds + + +if __name__ == '__main__': + save_root = '/scratches/flyer_3/fx221/exp/localizer/vis_clustering/' + seg_color = generate_color_dic(n_seg=2000) + data_root = '/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gm' + show_2D = False + + compress_map = False + # compress_map = True + + # scene = 'Aachen/Aachenv11' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n512_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + # vrf_file_name = 'point3D_vrf_n512_xz_birch.npy' + + # + # scene = 'CambridgeLandmarks/GreatCourt' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n32_xy_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + + # scene = 'CambridgeLandmarks/KingsCollege' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n32_xy_birch.npy'), allow_pickle=True)[()] + # sid_start = 33 + # vrf_file_name = 'point3D_vrf_n32_xy_birch.npy' + + # scene = 'CambridgeLandmarks/StMarysChurch' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n32_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 32 * 4 + 1 + # vrf_file_name = 'point3D_vrf_n32_xz_birch.npy' + 
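+    # (The commented-out blocks above and below are alternative per-scene settings kept for
+    #  reference; only the DarwinRGB configuration further down is active.)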
+ # scene = '7Scenes/office' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 33 + + # scene = '7Scenes/chess' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + # vrf_file_name = 'point3D_vrf_n16_xz_birch.npy' + + # scene = '7Scenes/redkitchen' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xz_birch.npy'), allow_pickle=True)[()] + # sid_start = 16 * 5 + 1 + # vrf_file_name = 'point3D_vrf_n16_xz_birch.npy' + + # scene = '12Scenes/apt1/kitchen' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xy_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + # vrf_file_name = 'point3D_vrf_n16_xy_birch.npy' + + # data_root = '/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gml2' + # scene = 'JesusCollege/jesuscollege' + # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n256_xy_birch.npy'), allow_pickle=True)[()] + # sid_start = 1 + # vrf_file_name = 'point3D_vrf_n256_xy_birch.npy' + + scene = 'DarwinRGB/darwin' + seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n128_xy_birch.npy'), allow_pickle=True)[()] + sid_start = 1 + vrf_file_name = 'point3D_vrf_n128_xy_birch.npy' + + cameras, images, points3D = read_model(osp.join(data_root, scene, 'model'), ext='.bin') + print('Load {:d} 3D points from map'.format(len(points3D.keys()))) + + if compress_map: + vrf_data = np.load(osp.join(data_root, scene, vrf_file_name), allow_pickle=True)[()] + valid_image_ids = [vrf_data[v][0]['image_id'] for v in vrf_data.keys()] + else: + valid_image_ids = list(images.keys()) + + if compress_map: + _, _, compress_points3D = read_compressed_model(osp.join(data_root, scene, 'compress_model_birch'), + ext='.bin') + print('Load {:d} 3D points from compressed map'.format(len(compress_points3D.keys()))) + valid_p3d_ids = list(compress_points3D.keys()) + else: + valid_p3d_ids = list(points3D.keys()) + + save_path = osp.join(save_root, scene) + + if compress_map: + save_path = save_path + '_comp' + if show_2D: + save_path = save_path + '_2D' + + os.makedirs(save_path, exist_ok=True) + p3d_id = seg_data['id'] + seg_id = seg_data['label'] + map_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])} + + new_cameras, new_images, new_point3Ds = reconstruct_map(valid_image_ids=valid_image_ids, + valid_p3d_ids=valid_p3d_ids, cameras=cameras, images=images, + point3Ds=points3D, p3d_seg=map_seg) + + # write_model(cameras=cameras, images=images, points3D=new_point3Ds, + # path=save_path, ext='.bin') + write_model(cameras=new_cameras, images=new_images, points3D=new_point3Ds, path=save_path, ext='.bin') diff --git a/imcui/third_party/pram/train.py b/imcui/third_party/pram/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2657f455d29c7c7c5417d8efa7aacaef4207ed --- /dev/null +++ b/imcui/third_party/pram/train.py @@ -0,0 +1,170 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> train +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 03/04/2024 16:33 +==================================================''' +import argparse +import os +import os.path as osp +import torch +import torchvision.transforms.transforms as tvt +import yaml +import torch.utils.data as Data +import torch.multiprocessing as mp +import torch.distributed as dist + +from nets.sfd2 import load_sfd2 +from nets.segnet import SegNet +from 
nets.segnetvit import SegNetViT +from nets.load_segnet import load_segnet +from dataset.utils import collect_batch +from dataset.get_dataset import compose_datasets +from tools.common import torch_set_gpu +from trainer import Trainer + + +def get_model(config): + desc_dim = 256 if config['feature'] == 'spp' else 128 + if config['use_mid_feature']: + desc_dim = 256 + model_config = { + 'network': { + 'descriptor_dim': desc_dim, + 'n_layers': config['layers'], + 'ac_fn': config['ac_fn'], + 'norm_fn': config['norm_fn'], + 'n_class': config['n_class'], + 'output_dim': config['output_dim'], + # 'with_cls': config['with_cls'], + # 'with_sc': config['with_sc'], + 'with_score': config['with_score'], + } + } + + if config['network'] == 'segnet': + model = SegNet(model_config.get('network', {})) + config['with_cls'] = False + elif config['network'] == 'segnetvit': + model = SegNetViT(model_config.get('network', {})) + config['with_cls'] = False + else: + raise 'ERROR! {:s} model does not exist'.format(config['network']) + + return model + + +parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--config', type=str, required=True, help='config of specifications') +# parser.add_argument('--landmark_path', type=str, required=True, help='path of landmarks') +parser.add_argument('--feat_weight_path', type=str, default='weights/sfd2_20230511_210205_resnet4x.79.pth') + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def train_DDP(rank, world_size, model, config, train_set, test_set, feat_model, img_transforms): + print('In train_DDP..., rank: ', rank) + torch.cuda.set_device(rank) + + device = torch.device(f'cuda:{rank}') + if feat_model is not None: + feat_model.to(device) + model.to(device) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + setup(rank=rank, world_size=world_size) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, + shuffle=True, + rank=rank, + num_replicas=world_size, + drop_last=True, # important? 
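+        # drop_last=True makes the sampler truncate the dataset so every rank receives the same
+        # number of samples, keeping the DDP ranks in step with each other.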
+ ) + train_loader = torch.utils.data.DataLoader(train_set, + batch_size=config['batch_size'] // world_size, + num_workers=config['workers'] // world_size, + # num_workers=1, + pin_memory=True, + # persistent_workers=True, + shuffle=False, # must be False + drop_last=True, + collate_fn=collect_batch, + prefetch_factor=4, + sampler=train_sampler) + config['local_rank'] = rank + + if rank == 0: + test_set = test_set + else: + test_set = None + + trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_set, + config=config, img_transforms=img_transforms) + trainer.train() + + +if __name__ == '__main__': + args = parser.parse_args() + with open(args.config, 'rt') as f: + config = yaml.load(f, Loader=yaml.Loader) + torch_set_gpu(gpus=config['gpu']) + if config['local_rank'] == 0: + print(config) + + img_transforms = [] + img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) + img_transforms = tvt.Compose(img_transforms) + + feat_model = load_sfd2(weight_path=args.feat_weight_path).cuda().eval() + print('Load SFD2 weight from {:s}'.format(args.feat_weight_path)) + + dataset = config['dataset'] + train_set = compose_datasets(datasets=dataset, config=config, train=True, sample_ratio=None) + if config['do_eval']: + test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=None) + else: + test_set = None + config['n_class'] = train_set.n_class + # model = get_model(config=config) + model = load_segnet(network=config['network'], + n_class=config['n_class'], + desc_dim=256 if config['use_mid_feature'] else 128, + n_layers=config['layers'], + output_dim=config['output_dim']) + if config['local_rank'] == 0: + if config['resume_path'] is not None: # only for training + model.load_state_dict( + torch.load(osp.join(config['save_path'], config['resume_path']), map_location='cpu')['model'], + strict=True) + print('Load resume weight from {:s}'.format(osp.join(config['save_path'], config['resume_path']))) + + if not config['with_dist'] or len(config['gpu']) == 1: + config['with_dist'] = False + model = model.cuda() + train_loader = Data.DataLoader(dataset=train_set, + shuffle=True, + batch_size=config['batch_size'], + drop_last=True, + collate_fn=collect_batch, + num_workers=config['workers']) + if test_set is not None: + test_loader = Data.DataLoader(dataset=test_set, + shuffle=False, + batch_size=1, + drop_last=False, + collate_fn=collect_batch, + num_workers=4) + else: + test_loader = None + trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_loader, + config=config, img_transforms=img_transforms) + trainer.train() + else: + mp.spawn(train_DDP, nprocs=len(config['gpu']), + args=(len(config['gpu']), model, config, train_set, test_set, feat_model, img_transforms), + join=True) diff --git a/imcui/third_party/pram/trainer.py b/imcui/third_party/pram/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..002e349323ec587843ea4119a0bc32b343bd34dd --- /dev/null +++ b/imcui/third_party/pram/trainer.py @@ -0,0 +1,404 @@ +# -*- coding: UTF-8 -*- +'''================================================= +@Project -> File pram -> trainer +@IDE PyCharm +@Author fx221@cam.ac.uk +@Date 29/01/2024 15:04 +==================================================''' +import datetime +import os +import os.path as osp +import numpy as np +from pathlib import Path +from tensorboardX import SummaryWriter +from tqdm import tqdm +import torch.optim as optim +import 
torch.nn.functional as F + +import shutil +import torch +from torch.autograd import Variable +from tools.common import save_args_yaml, merge_tags +from tools.metrics import compute_iou, compute_precision, SeqIOU, compute_corr_incorr, compute_seg_loss_weight +from tools.metrics import compute_cls_loss_ce, compute_cls_corr + + +class Trainer: + def __init__(self, model, train_loader, feat_model=None, eval_loader=None, config=None, img_transforms=None): + self.model = model + self.train_loader = train_loader + self.eval_loader = eval_loader + self.config = config + self.with_aug = self.config['with_aug'] + self.with_cls = False # self.config['with_cls'] + self.with_sc = False # self.config['with_sc'] + self.img_transforms = img_transforms + self.feat_model = feat_model.cuda().eval() if feat_model is not None else None + + self.init_lr = self.config['lr'] + self.min_lr = self.config['min_lr'] + + params = [p for p in self.model.parameters() if p.requires_grad] + self.optimizer = optim.AdamW(params=params, lr=self.init_lr) + self.num_epochs = self.config['epochs'] + + if config['resume_path'] is not None: + log_dir = config['resume_path'].split('/')[-2] + resume_log = torch.load(osp.join(osp.join(config['save_path'], config['resume_path'])), map_location='cpu') + self.epoch = resume_log['epoch'] + 1 + if 'iteration' in resume_log.keys(): + self.iteration = resume_log['iteration'] + else: + self.iteration = len(self.train_loader) * self.epoch + self.min_loss = resume_log['min_loss'] + else: + self.iteration = 0 + self.epoch = 0 + self.min_loss = 1e10 + + now = datetime.datetime.now() + all_tags = [now.strftime("%Y%m%d_%H%M%S")] + dataset_name = merge_tags(self.config['dataset'], '') + all_tags = all_tags + [self.config['network'], 'L' + str(self.config['layers']), + dataset_name, + str(self.config['feature']), 'B' + str(self.config['batch_size']), + 'K' + str(self.config['max_keypoints']), 'od' + str(self.config['output_dim']), + 'nc' + str(self.config['n_class'])] + if self.config['use_mid_feature']: + all_tags.append('md') + # if self.with_cls: + # all_tags.append(self.config['cls_loss']) + # if self.with_sc: + # all_tags.append(self.config['sc_loss']) + if self.with_aug: + all_tags.append('A') + + all_tags.append(self.config['cluster_method']) + log_dir = merge_tags(tags=all_tags, connection='_') + + if config['local_rank'] == 0: + self.save_dir = osp.join(self.config['save_path'], log_dir) + os.makedirs(self.save_dir, exist_ok=True) + + print("save_dir: ", self.save_dir) + + self.log_file = open(osp.join(self.save_dir, "log.txt"), "a+") + save_args_yaml(args=config, save_path=Path(self.save_dir, "args.yaml")) + self.writer = SummaryWriter(self.save_dir) + + self.tag = log_dir + + self.do_eval = self.config['do_eval'] + if self.do_eval: + self.eval_fun = None + self.seq_metric = SeqIOU(n_class=self.config['n_class'], ignored_sids=[0]) + + def preprocess_input(self, pred): + for k in pred.keys(): + if k.find('name') >= 0: + continue + if k != 'image' and k != 'depth': + if type(pred[k]) == torch.Tensor: + pred[k] = Variable(pred[k].float().cuda()) + else: + pred[k] = Variable(torch.stack(pred[k]).float().cuda()) + + if self.with_aug: + new_scores = [] + new_descs = [] + global_descs = [] + with torch.no_grad(): + for i, im in enumerate(pred['image']): + img = torch.from_numpy(im[0]).cuda().float().permute(2, 0, 1) + # img = self.img_transforms(img)[None] + if self.img_transforms is not None: + img = self.img_transforms(img)[None] + else: + img = img[None] + out = 
self.feat_model.extract_local_global(data={'image': img}) + global_descs.append(out['global_descriptors']) + + seg_scores, seg_descs = self.feat_model.sample(score_map=out['score_map'], + semi_descs=out['mid_features'] if self.config[ + 'use_mid_feature'] else out['desc_map'], + kpts=pred['keypoints'][i], + norm_desc=self.config['norm_desc']) # [D, N] + new_scores.append(seg_scores[None]) + new_descs.append(seg_descs[None]) + pred['global_descriptors'] = global_descs + pred['scores'] = torch.cat(new_scores, dim=0) + pred['seg_descriptors'] = torch.cat(new_descs, dim=0).permute(0, 2, 1) # -> [B, N, D] + + def process_epoch(self): + self.model.train() + + epoch_cls_losses = [] + epoch_seg_losses = [] + epoch_losses = [] + epoch_acc_corr = [] + epoch_acc_incorr = [] + epoch_cls_acc = [] + + epoch_sc_losses = [] + + for bidx, pred in tqdm(enumerate(self.train_loader), total=len(self.train_loader)): + self.preprocess_input(pred) + if 0 <= self.config['its_per_epoch'] <= bidx: + break + + data = self.model(pred) + for k, v in pred.items(): + pred[k] = v + pred = {**pred, **data} + + seg_loss = compute_seg_loss_weight(pred=pred['prediction'], + target=pred['gt_seg'], + background_id=0, + weight_background=0.1) + acc_corr, acc_incorr = compute_corr_incorr(pred=pred['prediction'], + target=pred['gt_seg'], + ignored_ids=[0]) + + if self.with_cls: + pred_cls_dist = pred['classification'] + gt_cls_dist = pred['gt_cls_dist'] + if len(pred_cls_dist.shape) > 2: + gt_cls_dist_full = gt_cls_dist.unsqueeze(-1).repeat(1, 1, pred_cls_dist.shape[-1]) + else: + gt_cls_dist_full = gt_cls_dist.unsqueeze(-1) + cls_loss = compute_cls_loss_ce(pred=pred_cls_dist, target=gt_cls_dist_full) + loss = seg_loss + cls_loss + + # gt_n_seg = pred['gt_n_seg'] + cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist) + else: + loss = seg_loss + cls_loss = torch.zeros_like(seg_loss) + cls_acc = torch.zeros_like(seg_loss) + + if self.with_sc: + pass + else: + sc_loss = torch.zeros_like(seg_loss) + + epoch_losses.append(loss.item()) + epoch_seg_losses.append(seg_loss.item()) + epoch_cls_losses.append(cls_loss.item()) + epoch_sc_losses.append(sc_loss.item()) + + epoch_acc_corr.append(acc_corr.item()) + epoch_acc_incorr.append(acc_incorr.item()) + epoch_cls_acc.append(cls_acc.item()) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + self.iteration += 1 + + lr = min(self.config['lr'] * self.config['decay_rate'] ** (self.iteration - self.config['decay_iter']), + self.config['lr']) + if lr < self.min_lr: + lr = self.min_lr + + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + + if self.config['local_rank'] == 0 and bidx % self.config['log_intervals'] == 0: + print_text = 'Epoch [{:d}/{:d}], Step [{:d}/{:d}/{:d}], Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]'.format( + self.epoch, + self.num_epochs, bidx, + len(self.train_loader), + self.iteration, + seg_loss.item(), + cls_loss.item(), + sc_loss.item(), + loss.item(), + + np.mean(epoch_acc_corr), + np.mean(epoch_acc_incorr), + np.mean(epoch_cls_acc) + ) + + print(print_text) + self.log_file.write(print_text + '\n') + + info = { + 'lr': lr, + 'loss': loss.item(), + 'cls_loss': cls_loss.item(), + 'sc_loss': sc_loss.item(), + 'acc_corr': acc_corr.item(), + 'acc_incorr': acc_incorr.item(), + 'acc_cls': cls_acc.item(), + } + + for k, v in info.items(): + self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.iteration) + + if self.config['local_rank'] == 0: + print_text = 'Epoch 
[{:d}/{:d}], AVG Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]\n'.format( + self.epoch, + self.num_epochs, + np.mean(epoch_seg_losses), + np.mean(epoch_cls_losses), + np.mean(epoch_sc_losses), + np.mean(epoch_losses), + np.mean(epoch_acc_corr), + np.mean(epoch_acc_incorr), + np.mean(epoch_cls_acc), + ) + print(print_text) + self.log_file.write(print_text + '\n') + self.log_file.flush() + return np.mean(epoch_losses) + + def eval_seg(self, loader): + print('Start to do evaluation...') + + self.model.eval() + self.seq_metric.clear() + mean_iou_day = [] + mean_iou_night = [] + mean_prec_day = [] + mean_prec_night = [] + mean_cls_day = [] + mean_cls_night = [] + + for bid, pred in tqdm(enumerate(loader), total=len(loader)): + for k in pred.keys(): + if k.find('name') >= 0: + continue + if k != 'image' and k != 'depth': + if type(pred[k]) == torch.Tensor: + pred[k] = Variable(pred[k].float().cuda()) + elif type(pred[k]) == np.ndarray: + pred[k] = Variable(torch.from_numpy(pred[k]).float()[None].cuda()) + else: + pred[k] = Variable(torch.stack(pred[k]).float().cuda()) + + if self.with_aug: + with torch.no_grad(): + if isinstance(pred['image'][0], list): + img = pred['image'][0][0] + else: + img = pred['image'][0] + + img = torch.from_numpy(img).cuda().float().permute(2, 0, 1) + if self.img_transforms is not None: + img = self.img_transforms(img)[None] + else: + img = img[None] + + encoder_out = self.feat_model.extract_local_global(data={'image': img}) + global_descriptors = [encoder_out['global_descriptors']] + pred['global_descriptors'] = global_descriptors + if self.config['use_mid_feature']: + scores, descs = self.feat_model.sample(score_map=encoder_out['score_map'], + semi_descs=encoder_out['mid_features'], + kpts=pred['keypoints'][0], + norm_desc=self.config['norm_desc']) + # print('eval: ', scores.shape, descs.shape) + pred['scores'] = scores[None] + pred['seg_descriptors'] = descs[None].permute(0, 2, 1) # -> [B, N, D] + else: + pred['seg_descriptors'] = pred['descriptors'] + + image_name = pred['file_name'][0] + with torch.no_grad(): + out = self.model(pred) + pred = {**pred, **out} + + pred_seg = torch.max(pred['prediction'], dim=-1)[1] # [B, N, C] + pred_seg = pred_seg[0].cpu().numpy() + gt_seg = pred['gt_seg'][0].cpu().numpy() + iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=self.config['n_class'], ignored_ids=[0]) + prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0]) + + if self.with_cls: + pred_cls_dist = pred['classification'] + gt_cls_dist = pred['gt_cls_dist'] + cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist).item() + else: + cls_acc = 0. 
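+            # Per-image metrics are bucketed into day/night below based on whether the file name
+            # contains 'night'; for datasets without night images the *_night averages evaluate
+            # to NaN (mean of an empty list).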
+ + if image_name.find('night') >= 0: + mean_iou_night.append(iou) + mean_prec_night.append(prec) + mean_cls_night.append(cls_acc) + else: + mean_iou_day.append(iou) + mean_prec_day.append(prec) + mean_cls_day.append(cls_acc) + + print_txt = 'Eval Epoch {:d}, iou day/night {:.3f}/{:.3f}, prec day/night {:.3f}/{:.3f}, cls day/night {:.3f}/{:.3f}'.format( + self.epoch, np.mean(mean_iou_day), np.mean(mean_iou_night), + np.mean(mean_prec_day), np.mean(mean_prec_night), + np.mean(mean_cls_day), np.mean(mean_cls_night)) + self.log_file.write(print_txt + '\n') + print(print_txt) + + info = { + 'mean_iou_day': np.mean(mean_iou_day), + 'mean_iou_night': np.mean(mean_iou_night), + 'mean_prec_day': np.mean(mean_prec_day), + 'mean_prec_night': np.mean(mean_prec_night), + } + + for k, v in info.items(): + self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.epoch) + + return np.mean(mean_prec_night) + + def train(self): + if self.config['local_rank'] == 0: + print('Start to train the model from epoch: {:d}'.format(self.epoch)) + hist_values = [] + min_value = self.min_loss + + epoch = self.epoch + while epoch < self.num_epochs: + if self.config['with_dist']: + self.train_loader.sampler.set_epoch(epoch=epoch) + self.epoch = epoch + + train_loss = self.process_epoch() + + # return with loss INF/NAN + if train_loss is None: + continue + + if self.config['local_rank'] == 0: + if self.do_eval and self.epoch % self.config['eval_n_epoch'] == 0: # and self.epoch >= 50: + eval_ratio = self.eval_seg(loader=self.eval_loader) + + hist_values.append(eval_ratio) # higher better + else: + hist_values.append(-train_loss) # lower better + + checkpoint_path = os.path.join(self.save_dir, + '%s.%02d.pth' % (self.config['network'], self.epoch)) + checkpoint = { + 'epoch': self.epoch, + 'iteration': self.iteration, + 'model': self.model.state_dict(), + 'min_loss': min_value, + } + # for multi-gpu training + if len(self.config['gpu']) > 1: + checkpoint['model'] = self.model.module.state_dict() + + torch.save(checkpoint, checkpoint_path) + + if hist_values[-1] < min_value: + min_value = hist_values[-1] + best_checkpoint_path = os.path.join( + self.save_dir, + '%s.best.pth' % (self.tag) + ) + shutil.copy(checkpoint_path, best_checkpoint_path) + # important!!! + epoch += 1 + + if self.config['local_rank'] == 0: + self.log_file.close() diff --git a/imcui/third_party/r2d2/datasets/__init__.py b/imcui/third_party/r2d2/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f11df21be72856ea365f6efd7a389aba267562b --- /dev/null +++ b/imcui/third_party/r2d2/datasets/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2019-present NAVER Corp. 
+# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +from .pair_dataset import CatPairDataset, SyntheticPairDataset, TransformedPairs +from .imgfolder import ImgFolder + +from .web_images import RandomWebImages +from .aachen import * + +# try to instanciate datasets +import sys +try: + web_images = RandomWebImages(0, 52) +except AssertionError as e: + print(f"Dataset web_images not available, reason: {e}", file=sys.stderr) + +try: + aachen_db_images = AachenImages_DB() +except AssertionError as e: + print(f"Dataset aachen_db_images not available, reason: {e}", file=sys.stderr) + +try: + aachen_style_transfer_pairs = AachenPairs_StyleTransferDayNight() +except AssertionError as e: + print(f"Dataset aachen_style_transfer_pairs not available, reason: {e}", file=sys.stderr) + +try: + aachen_flow_pairs = AachenPairs_OpticalFlow() +except AssertionError as e: + print(f"Dataset aachen_flow_pairs not available, reason: {e}", file=sys.stderr) + + diff --git a/imcui/third_party/r2d2/datasets/aachen.py b/imcui/third_party/r2d2/datasets/aachen.py new file mode 100644 index 0000000000000000000000000000000000000000..4ddb324cea01da2430ee89b32c7627b34c01a41f --- /dev/null +++ b/imcui/third_party/r2d2/datasets/aachen.py @@ -0,0 +1,146 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import os, pdb +import numpy as np +from PIL import Image + +from .dataset import Dataset +from .pair_dataset import PairDataset, StillPairDataset + + +class AachenImages (Dataset): + """ Loads all images from the Aachen Day-Night dataset + """ + def __init__(self, select='db day night', root='data/aachen'): + Dataset.__init__(self) + self.root = root + self.img_dir = 'images_upright' + self.select = set(select.split()) + assert self.select, 'Nothing was selected' + + self.imgs = [] + root = os.path.join(root, self.img_dir) + for dirpath, _, filenames in os.walk(root): + r = dirpath[len(root)+1:] + if not(self.select & set(r.split('/'))): continue + self.imgs += [os.path.join(r,f) for f in filenames if f.endswith('.jpg')] + + self.nimg = len(self.imgs) + assert self.nimg, 'Empty Aachen dataset' + + def get_key(self, idx): + return self.imgs[idx] + + + +class AachenImages_DB (AachenImages): + """ Only database (db) images. 
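+        Also builds `db_image_idxs`, a mapping from image tag (file name without
+        extension) to image index, which the pair datasets below rely on.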
+ """ + def __init__(self, **kw): + AachenImages.__init__(self, select='db', **kw) + self.db_image_idxs = {self.get_tag(i) : i for i,f in enumerate(self.imgs)} + + def get_tag(self, idx): + # returns image tag == img number (name) + return os.path.split( self.imgs[idx][:-4] )[1] + + + +class AachenPairs_StyleTransferDayNight (AachenImages_DB, StillPairDataset): + """ synthetic day-night pairs of images + (night images obtained using autoamtic style transfer from web night images) + """ + def __init__(self, root='data/aachen/style_transfer', **kw): + StillPairDataset.__init__(self) + AachenImages_DB.__init__(self, **kw) + old_root = os.path.join(self.root, self.img_dir) + self.root = os.path.commonprefix((old_root, root)) + self.img_dir = '' + + newpath = lambda folder, f: os.path.join(folder, f)[len(self.root):] + self.imgs = [newpath(old_root, f) for f in self.imgs] + + self.image_pairs = [] + for fname in os.listdir(root): + tag = fname.split('.jpg.st_')[0] + self.image_pairs.append((self.db_image_idxs[tag], len(self.imgs))) + self.imgs.append(newpath(root, fname)) + + self.nimg = len(self.imgs) + self.npairs = len(self.image_pairs) + assert self.nimg and self.npairs + + + +class AachenPairs_OpticalFlow (AachenImages_DB, PairDataset): + """ Image pairs from Aachen db with optical flow. + """ + def __init__(self, root='data/aachen/optical_flow', **kw): + PairDataset.__init__(self) + AachenImages_DB.__init__(self, **kw) + self.root_flow = root + + # find out the subsest of valid pairs from the list of flow files + flows = {f for f in os.listdir(os.path.join(root, 'flow')) if f.endswith('.png')} + masks = {f for f in os.listdir(os.path.join(root, 'mask')) if f.endswith('.png')} + assert flows == masks, 'Missing flow or mask pairs' + + make_pair = lambda f: tuple(self.db_image_idxs[v] for v in f[:-4].split('_')) + self.image_pairs = [make_pair(f) for f in flows] + self.npairs = len(self.image_pairs) + assert self.nimg and self.npairs + + def get_mask_filename(self, pair_idx): + tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx]) + return os.path.join(self.root_flow, 'mask', f'{tag_a}_{tag_b}.png') + + def get_mask(self, pair_idx): + return np.asarray(Image.open(self.get_mask_filename(pair_idx))) + + def get_flow_filename(self, pair_idx): + tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx]) + return os.path.join(self.root_flow, 'flow', f'{tag_a}_{tag_b}.png') + + def get_flow(self, pair_idx): + fname = self.get_flow_filename(pair_idx) + try: + return self._png2flow(fname) + except IOError: + flow = open(fname[:-4], 'rb') + help = np.fromfile(flow, np.float32, 1) + assert help == 202021.25 + W, H = np.fromfile(flow, np.int32, 2) + flow = np.fromfile(flow, np.float32).reshape((H, W, 2)) + return self._flow2png(flow, fname) + + def get_pair(self, idx, output=()): + if isinstance(output, str): + output = output.split() + + img1, img2 = map(self.get_image, self.image_pairs[idx]) + meta = {} + + if 'flow' in output or 'aflow' in output: + flow = self.get_flow(idx) + assert flow.shape[:2] == img1.size[::-1] + meta['flow'] = flow + H, W = flow.shape[:2] + meta['aflow'] = flow + np.mgrid[:H,:W][::-1].transpose(1,2,0) + + if 'mask' in output: + mask = self.get_mask(idx) + assert mask.shape[:2] == img1.size[::-1] + meta['mask'] = mask + + return img1, img2, meta + + + + +if __name__ == '__main__': + print(aachen_db_images) + print(aachen_style_transfer_pairs) + print(aachen_flow_pairs) + pdb.set_trace() diff --git a/imcui/third_party/r2d2/datasets/dataset.py 
b/imcui/third_party/r2d2/datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..80d893b8ea4ead7845f35c4fe82c9f5a9b849de3 --- /dev/null +++ b/imcui/third_party/r2d2/datasets/dataset.py @@ -0,0 +1,77 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import os +import json +import pdb +import numpy as np + + +class Dataset(object): + ''' Base class for a dataset. To be overloaded. + ''' + root = '' + img_dir = '' + nimg = 0 + + def __len__(self): + return self.nimg + + def get_key(self, img_idx): + raise NotImplementedError() + + def get_filename(self, img_idx, root=None): + return os.path.join(root or self.root, self.img_dir, self.get_key(img_idx)) + + def get_image(self, img_idx): + from PIL import Image + fname = self.get_filename(img_idx) + try: + return Image.open(fname).convert('RGB') + except Exception as e: + raise IOError("Could not load image %s (reason: %s)" % (fname, str(e))) + + def __repr__(self): + res = 'Dataset: %s\n' % self.__class__.__name__ + res += ' %d images' % self.nimg + res += '\n root: %s...\n' % self.root + return res + + + +class CatDataset (Dataset): + ''' Concatenation of several datasets. + ''' + def __init__(self, *datasets): + assert len(datasets) >= 1 + self.datasets = datasets + offsets = [0] + for db in datasets: + offsets.append(db.nimg) + self.offsets = np.cumsum(offsets) + self.nimg = self.offsets[-1] + self.root = None + + def which(self, i): + pos = np.searchsorted(self.offsets, i, side='right')-1 + assert pos < self.nimg, 'Bad image index %d >= %d' % (i, self.nimg) + return pos, i - self.offsets[pos] + + def get_key(self, i): + b, i = self.which(i) + return self.datasets[b].get_key(i) + + def get_filename(self, i): + b, i = self.which(i) + return self.datasets[b].get_filename(i) + + def __repr__(self): + fmt_str = "CatDataset(" + for db in self.datasets: + fmt_str += str(db).replace("\n"," ") + ', ' + return fmt_str[:-2] + ')' + + + + diff --git a/imcui/third_party/r2d2/datasets/imgfolder.py b/imcui/third_party/r2d2/datasets/imgfolder.py new file mode 100644 index 0000000000000000000000000000000000000000..45f7bc9ee4c3ba5f04380dbc02ad17b6463cf32f --- /dev/null +++ b/imcui/third_party/r2d2/datasets/imgfolder.py @@ -0,0 +1,23 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import os, pdb + +from .dataset import Dataset +from .pair_dataset import SyntheticPairDataset + + +class ImgFolder (Dataset): + """ load all images in a folder (no recursion). + """ + def __init__(self, root, imgs=None, exts=('.jpg','.png','.ppm')): + Dataset.__init__(self) + self.root = root + self.imgs = imgs or [f for f in os.listdir(root) if f.endswith(exts)] + self.nimg = len(self.imgs) + + def get_key(self, idx): + return self.imgs[idx] + + diff --git a/imcui/third_party/r2d2/datasets/pair_dataset.py b/imcui/third_party/r2d2/datasets/pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..aeed98b6700e0ba108bb44abccc20351d16f3295 --- /dev/null +++ b/imcui/third_party/r2d2/datasets/pair_dataset.py @@ -0,0 +1,287 @@ +# Copyright 2019-present NAVER Corp. 
+# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import os, pdb +import numpy as np +from PIL import Image + +from .dataset import Dataset, CatDataset +from tools.transforms import instanciate_transformation +from tools.transforms_tools import persp_apply + + +class PairDataset (Dataset): + """ A dataset that serves image pairs with ground-truth pixel correspondences. + """ + def __init__(self): + Dataset.__init__(self) + self.npairs = 0 + + def get_filename(self, img_idx, root=None): + if is_pair(img_idx): # if img_idx is a pair of indices, we return a pair of filenames + return tuple(Dataset.get_filename(self, i, root) for i in img_idx) + return Dataset.get_filename(self, img_idx, root) + + def get_image(self, img_idx): + if is_pair(img_idx): # if img_idx is a pair of indices, we return a pair of images + return tuple(Dataset.get_image(self, i) for i in img_idx) + return Dataset.get_image(self, img_idx) + + def get_corres_filename(self, pair_idx): + raise NotImplementedError() + + def get_homography_filename(self, pair_idx): + raise NotImplementedError() + + def get_flow_filename(self, pair_idx): + raise NotImplementedError() + + def get_mask_filename(self, pair_idx): + raise NotImplementedError() + + def get_pair(self, idx, output=()): + """ returns (img1, img2, `metadata`) + + `metadata` is a dict() that can contain: + flow: optical flow + aflow: absolute flow + corres: list of 2d-2d correspondences + mask: boolean image of flow validity (in the first image) + ... + """ + raise NotImplementedError() + + def get_paired_images(self): + fns = set() + for i in range(self.npairs): + a,b = self.image_pairs[i] + fns.add(self.get_filename(a)) + fns.add(self.get_filename(b)) + return fns + + def __len__(self): + return self.npairs # size should correspond to the number of pairs, not images + + def __repr__(self): + res = 'Dataset: %s\n' % self.__class__.__name__ + res += ' %d images,' % self.nimg + res += ' %d image pairs' % self.npairs + res += '\n root: %s...\n' % self.root + return res + + @staticmethod + def _flow2png(flow, path): + flow = np.clip(np.around(16*flow), -2**15, 2**15-1) + bytes = np.int16(flow).view(np.uint8) + Image.fromarray(bytes).save(path) + return flow / 16 + + @staticmethod + def _png2flow(path): + try: + flow = np.asarray(Image.open(path)).view(np.int16) + return np.float32(flow) / 16 + except: + raise IOError("Error loading flow for %s" % path) + + + +class StillPairDataset (PairDataset): + """ A dataset of 'still' image pairs. + By overloading a normal image dataset, it appends the get_pair(i) function + that serves trivial image pairs (img1, img2) where img1 == img2 == get_image(i). + """ + def get_pair(self, pair_idx, output=()): + if isinstance(output, str): output = output.split() + img1, img2 = map(self.get_image, self.image_pairs[pair_idx]) + + W,H = img1.size + sx = img2.size[0] / float(W) + sy = img2.size[1] / float(H) + + meta = {} + if 'aflow' in output or 'flow' in output: + mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1,2,0).astype(np.float32) + meta['aflow'] = mgrid * (sx,sy) + meta['flow'] = meta['aflow'] - mgrid + + if 'mask' in output: + meta['mask'] = np.ones((H,W), np.uint8) + + if 'homography' in output: + meta['homography'] = np.diag(np.float32([sx, sy, 1])) + + return img1, img2, meta + + + +class SyntheticPairDataset (PairDataset): + """ A synthetic generator of image pairs. + Given a normal image dataset, it constructs pairs using random homographies & noise. 
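+        get_pair(i) returns (scaled image, distorted image, metadata); depending on the
+        requested output, the metadata contains the absolute flow 'aflow' (the sampled
+        perspective transform applied to the pixel grid) and/or the 3x3 'homography'.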
+ """ + def __init__(self, dataset, scale='', distort=''): + self.attach_dataset(dataset) + self.distort = instanciate_transformation(distort) + self.scale = instanciate_transformation(scale) + + def attach_dataset(self, dataset): + assert isinstance(dataset, Dataset) and not isinstance(dataset, PairDataset) + self.dataset = dataset + self.npairs = dataset.nimg + self.get_image = dataset.get_image + self.get_key = dataset.get_key + self.get_filename = dataset.get_filename + self.root = None + + def make_pair(self, img): + return img, img + + def get_pair(self, i, output=('aflow')): + """ Procedure: + This function applies a series of random transformations to one original image + to form a synthetic image pairs with perfect ground-truth. + """ + if isinstance(output, str): + output = output.split() + + original_img = self.dataset.get_image(i) + + scaled_image = self.scale(original_img) + scaled_image, scaled_image2 = self.make_pair(scaled_image) + scaled_and_distorted_image = self.distort( + dict(img=scaled_image2, persp=(1,0,0,0,1,0,0,0))) + W, H = scaled_image.size + trf = scaled_and_distorted_image['persp'] + + meta = dict() + if 'aflow' in output or 'flow' in output: + # compute optical flow + xy = np.mgrid[0:H,0:W][::-1].reshape(2,H*W).T + aflow = np.float32(persp_apply(trf, xy).reshape(H,W,2)) + meta['flow'] = aflow - xy.reshape(H,W,2) + meta['aflow'] = aflow + + if 'homography' in output: + meta['homography'] = np.float32(trf+(1,)).reshape(3,3) + + return scaled_image, scaled_and_distorted_image['img'], meta + + def __repr__(self): + res = 'Dataset: %s\n' % self.__class__.__name__ + res += ' %d images and pairs' % self.npairs + res += '\n root: %s...' % self.dataset.root + res += '\n Scale: %s' % (repr(self.scale).replace('\n','')) + res += '\n Distort: %s' % (repr(self.distort).replace('\n','')) + return res + '\n' + + + +class TransformedPairs (PairDataset): + """ Automatic data augmentation for pre-existing image pairs. + Given an image pair dataset, it generates synthetically jittered pairs + using random transformations (e.g. homographies & noise). + """ + def __init__(self, dataset, trf=''): + self.attach_dataset(dataset) + self.trf = instanciate_transformation(trf) + + def attach_dataset(self, dataset): + assert isinstance(dataset, PairDataset) + self.dataset = dataset + self.nimg = dataset.nimg + self.npairs = dataset.npairs + self.get_image = dataset.get_image + self.get_key = dataset.get_key + self.get_filename = dataset.get_filename + self.root = None + + def get_pair(self, i, output=''): + """ Procedure: + This function applies a series of random transformations to one original image + to form a synthetic image pairs with perfect ground-truth. 
+ """ + img_a, img_b_, metadata = self.dataset.get_pair(i, output) + + img_b = self.trf({'img': img_b_, 'persp':(1,0,0,0,1,0,0,0)}) + trf = img_b['persp'] + + if 'aflow' in metadata or 'flow' in metadata: + aflow = metadata['aflow'] + aflow[:] = persp_apply(trf, aflow.reshape(-1,2)).reshape(aflow.shape) + W, H = img_a.size + flow = metadata['flow'] + mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1,2,0).astype(np.float32) + flow[:] = aflow - mgrid + + if 'corres' in metadata: + corres = metadata['corres'] + corres[:,1] = persp_apply(trf, corres[:,1]) + + if 'homography' in metadata: + # p_b = homography * p_a + trf_ = np.float32(trf+(1,)).reshape(3,3) + metadata['homography'] = np.float32(trf_ @ metadata['homography']) + + return img_a, img_b['img'], metadata + + def __repr__(self): + res = 'Transformed Pairs from %s\n' % type(self.dataset).__name__ + res += ' %d images and pairs' % self.npairs + res += '\n root: %s...' % self.dataset.root + res += '\n transform: %s' % (repr(self.trf).replace('\n','')) + return res + '\n' + + + +class CatPairDataset (CatDataset): + ''' Concatenation of several pair datasets. + ''' + def __init__(self, *datasets): + CatDataset.__init__(self, *datasets) + pair_offsets = [0] + for db in datasets: + pair_offsets.append(db.npairs) + self.pair_offsets = np.cumsum(pair_offsets) + self.npairs = self.pair_offsets[-1] + + def __len__(self): + return self.npairs + + def __repr__(self): + fmt_str = "CatPairDataset(" + for db in self.datasets: + fmt_str += str(db).replace("\n"," ") + ', ' + return fmt_str[:-2] + ')' + + def pair_which(self, i): + pos = np.searchsorted(self.pair_offsets, i, side='right')-1 + assert pos < self.npairs, 'Bad pair index %d >= %d' % (i, self.npairs) + return pos, i - self.pair_offsets[pos] + + def pair_call(self, func, i, *args, **kwargs): + b, j = self.pair_which(i) + return getattr(self.datasets[b], func)(j, *args, **kwargs) + + def get_pair(self, i, output=()): + b, i = self.pair_which(i) + return self.datasets[b].get_pair(i, output) + + def get_flow_filename(self, pair_idx, *args, **kwargs): + return self.pair_call('get_flow_filename', pair_idx, *args, **kwargs) + + def get_mask_filename(self, pair_idx, *args, **kwargs): + return self.pair_call('get_mask_filename', pair_idx, *args, **kwargs) + + def get_corres_filename(self, pair_idx, *args, **kwargs): + return self.pair_call('get_corres_filename', pair_idx, *args, **kwargs) + + + +def is_pair(x): + if isinstance(x, (tuple,list)) and len(x) == 2: + return True + if isinstance(x, np.ndarray) and x.ndim == 1 and x.shape[0] == 2: + return True + return False + diff --git a/imcui/third_party/r2d2/datasets/web_images.py b/imcui/third_party/r2d2/datasets/web_images.py new file mode 100644 index 0000000000000000000000000000000000000000..7c17fbe956f3b4db25d9a4148e8f7c615f122478 --- /dev/null +++ b/imcui/third_party/r2d2/datasets/web_images.py @@ -0,0 +1,64 @@ +# Copyright 2019-present NAVER Corp. 
+# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import os, pdb +from tqdm import trange + +from .dataset import Dataset + + +class RandomWebImages (Dataset): + """ 1 million distractors from Oxford and Paris Revisited + see http://ptak.felk.cvut.cz/revisitop/revisitop1m/ + """ + def __init__(self, start=0, end=1024, root="data/revisitop1m"): + Dataset.__init__(self) + self.root = root + + bar = None + self.imgs = [] + for i in range(start, end): + try: + # read cached list + img_list_path = os.path.join(self.root, "image_list_%d.txt"%i) + cached_imgs = [e.strip() for e in open(img_list_path)] + assert cached_imgs, f"Cache '{img_list_path}' is empty!" + self.imgs += cached_imgs + + except IOError: + if bar is None: + bar = trange(start, 4*end, desc='Caching') + bar.update(4*i) + + # create it + imgs = [] + for d in range(i*4,(i+1)*4): # 4096 folders in total, on average 256 each + key = hex(d)[2:].zfill(3) + folder = os.path.join(self.root, key) + if not os.path.isdir(folder): continue + imgs += [f for f in os.listdir(folder) if verify_img(folder,f)] + bar.update(1) + assert imgs, f"No images found in {folder}/" + open(img_list_path,'w').write('\n'.join(imgs)) + self.imgs += imgs + + if bar: bar.update(bar.total - bar.n) + self.nimg = len(self.imgs) + + def get_key(self, i): + key = self.imgs[i] + return os.path.join(key[:3], key) + + +def verify_img(folder, f): + path = os.path.join(folder, f) + if not f.endswith('.jpg'): return False + try: + from PIL import Image + Image.open(path).convert('RGB') # try to open it + return True + except: + return False + + diff --git a/imcui/third_party/r2d2/extract.py b/imcui/third_party/r2d2/extract.py new file mode 100644 index 0000000000000000000000000000000000000000..c3fea02f87c0615504e3648bfd590e413ab13898 --- /dev/null +++ b/imcui/third_party/r2d2/extract.py @@ -0,0 +1,183 @@ +# Copyright 2019-present NAVER Corp. 
+# CC BY-NC-SA 3.0 +# Available only for non-commercial use + + +import os, pdb +from PIL import Image +import numpy as np +import torch + +from tools import common +from tools.dataloader import norm_RGB +from nets.patchnet import * + + +def load_network(model_fn): + checkpoint = torch.load(model_fn) + print("\n>> Creating net = " + checkpoint['net']) + net = eval(checkpoint['net']) + nb_of_weights = common.model_size(net) + print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )") + + # initialization + weights = checkpoint['state_dict'] + net.load_state_dict({k.replace('module.',''):v for k,v in weights.items()}) + return net.eval() + + +class NonMaxSuppression (torch.nn.Module): + def __init__(self, rel_thr=0.7, rep_thr=0.7): + nn.Module.__init__(self) + self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1) + self.rel_thr = rel_thr + self.rep_thr = rep_thr + + def forward(self, reliability, repeatability, **kw): + assert len(reliability) == len(repeatability) == 1 + reliability, repeatability = reliability[0], repeatability[0] + + # local maxima + maxima = (repeatability == self.max_filter(repeatability)) + + # remove low peaks + maxima *= (repeatability >= self.rep_thr) + maxima *= (reliability >= self.rel_thr) + + return maxima.nonzero().t()[2:4] + + +def extract_multiscale( net, img, detector, scale_f=2**0.25, + min_scale=0.0, max_scale=1, + min_size=256, max_size=1024, + verbose=False): + old_bm = torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = False # speedup + + # extract keypoints at multiple scales + B, three, H, W = img.shape + assert B == 1 and three == 3, "should be a batch with a single RGB image" + + assert max_scale <= 1 + s = 1.0 # current scale factor + + X,Y,S,C,Q,D = [],[],[],[],[],[] + while s+0.001 >= max(min_scale, min_size / max(H,W)): + if s-0.001 <= min(max_scale, max_size / max(H,W)): + nh, nw = img.shape[2:] + if verbose: print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}") + # extract descriptors + with torch.no_grad(): + res = net(imgs=[img]) + + # get output and reliability map + descriptors = res['descriptors'][0] + reliability = res['reliability'][0] + repeatability = res['repeatability'][0] + + # normalize the reliability for nms + # extract maxima and descs + y,x = detector(**res) # nms + c = reliability[0,0,y,x] + q = repeatability[0,0,y,x] + d = descriptors[0,:,y,x].t() + n = d.shape[0] + + # accumulate multiple scales + X.append(x.float() * W/nw) + Y.append(y.float() * H/nh) + S.append((32/s) * torch.ones(n, dtype=torch.float32, device=d.device)) + C.append(c) + Q.append(q) + D.append(d) + s /= scale_f + + # down-scale the image for next iteration + nh, nw = round(H*s), round(W*s) + img = F.interpolate(img, (nh,nw), mode='bilinear', align_corners=False) + + # restore value + torch.backends.cudnn.benchmark = old_bm + + Y = torch.cat(Y) + X = torch.cat(X) + S = torch.cat(S) # scale + scores = torch.cat(C) * torch.cat(Q) # scores = reliability * repeatability + XYS = torch.stack([X,Y,S], dim=-1) + D = torch.cat(D) + return XYS, D, scores + + +def extract_keypoints(args): + iscuda = common.torch_set_gpu(args.gpu) + + # load the network... 
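+    # (load_network expects a checkpoint holding 'net', an architecture string that is
+    #  eval()'ed, and 'state_dict' with the trained weights)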
+ net = load_network(args.model) + if iscuda: net = net.cuda() + + # create the non-maxima detector + detector = NonMaxSuppression( + rel_thr = args.reliability_thr, + rep_thr = args.repeatability_thr) + + while args.images: + img_path = args.images.pop(0) + + if img_path.endswith('.txt'): + args.images = open(img_path).read().splitlines() + args.images + continue + + print(f"\nExtracting features for {img_path}") + img = Image.open(img_path).convert('RGB') + W, H = img.size + img = norm_RGB(img)[None] + if iscuda: img = img.cuda() + + # extract keypoints/descriptors for a single image + xys, desc, scores = extract_multiscale(net, img, detector, + scale_f = args.scale_f, + min_scale = args.min_scale, + max_scale = args.max_scale, + min_size = args.min_size, + max_size = args.max_size, + verbose = True) + + xys = xys.cpu().numpy() + desc = desc.cpu().numpy() + scores = scores.cpu().numpy() + idxs = scores.argsort()[-args.top_k or None:] + + outpath = img_path + '.' + args.tag + print(f"Saving {len(idxs)} keypoints to {outpath}") + np.savez(open(outpath,'wb'), + imsize = (W,H), + keypoints = xys[idxs], + descriptors = desc[idxs], + scores = scores[idxs]) + + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser("Extract keypoints for a given image") + parser.add_argument("--model", type=str, required=True, help='model path') + + parser.add_argument("--images", type=str, required=True, nargs='+', help='images / list') + parser.add_argument("--tag", type=str, default='r2d2', help='output file tag') + + parser.add_argument("--top-k", type=int, default=5000, help='number of keypoints') + + parser.add_argument("--scale-f", type=float, default=2**0.25) + parser.add_argument("--min-size", type=int, default=256) + parser.add_argument("--max-size", type=int, default=1024) + parser.add_argument("--min-scale", type=float, default=0) + parser.add_argument("--max-scale", type=float, default=1) + + parser.add_argument("--reliability-thr", type=float, default=0.7) + parser.add_argument("--repeatability-thr", type=float, default=0.7) + + parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU') + args = parser.parse_args() + + extract_keypoints(args) + diff --git a/imcui/third_party/r2d2/extract_kapture.py b/imcui/third_party/r2d2/extract_kapture.py new file mode 100644 index 0000000000000000000000000000000000000000..51b2403b8a1730eaee32d099d0b6dd5d091ccdda --- /dev/null +++ b/imcui/third_party/r2d2/extract_kapture.py @@ -0,0 +1,194 @@ +# Copyright 2019-present NAVER Corp. 
+# CC BY-NC-SA 3.0 +# Available only for non-commercial use + + +from PIL import Image + +from tools import common +from tools.dataloader import norm_RGB +from nets.patchnet import * +from os import path + +from extract import load_network, NonMaxSuppression, extract_multiscale + +# Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure From Motion) +# and more generally sensor-acquired data +# it can be installed with +# pip install kapture +# for more information check out https://github.com/naver/kapture +import kapture +from kapture.io.records import get_image_fullpath +from kapture.io.csv import kapture_from_dir +from kapture.io.csv import get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file +from kapture.io.features import get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file +from kapture.io.features import get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file +from kapture.io.csv import get_all_tar_handlers + + +def extract_kapture_keypoints(args): + """ + Extract r2d2 keypoints and descritors to the kapture format directly + """ + print('extract_kapture_keypoints...') + with get_all_tar_handlers(args.kapture_root, + mode={kapture.Keypoints: 'a', + kapture.Descriptors: 'a', + kapture.GlobalFeatures: 'r', + kapture.Matches: 'r'}) as tar_handlers: + kdata = kapture_from_dir(args.kapture_root, None, + skip_list=[kapture.GlobalFeatures, + kapture.Matches, + kapture.Points3d, + kapture.Observations], + tar_handlers=tar_handlers) + + assert kdata.records_camera is not None + image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)] + if args.keypoints_type is None: + args.keypoints_type = path.splitext(path.basename(args.model))[0] + print(f'keypoints_type set to {args.keypoints_type}') + if args.descriptors_type is None: + args.descriptors_type = path.splitext(path.basename(args.model))[0] + print(f'descriptors_type set to {args.descriptors_type}') + + if kdata.keypoints is not None and args.keypoints_type in kdata.keypoints \ + and kdata.descriptors is not None and args.descriptors_type in kdata.descriptors: + print('detected already computed features of same keypoints_type/descriptors_type, resuming extraction...') + image_list = [name + for name in image_list + if name not in kdata.keypoints[args.keypoints_type] or + name not in kdata.descriptors[args.descriptors_type]] + + if len(image_list) == 0: + print('All features were already extracted') + return + else: + print(f'Extracting r2d2 features for {len(image_list)} images') + + iscuda = common.torch_set_gpu(args.gpu) + + # load the network... 
+ net = load_network(args.model) + if iscuda: + net = net.cuda() + + # create the non-maxima detector + detector = NonMaxSuppression( + rel_thr=args.reliability_thr, + rep_thr=args.repeatability_thr) + + if kdata.keypoints is None: + kdata.keypoints = {} + if kdata.descriptors is None: + kdata.descriptors = {} + + if args.keypoints_type not in kdata.keypoints: + keypoints_dtype = None + keypoints_dsize = None + else: + keypoints_dtype = kdata.keypoints[args.keypoints_type].dtype + keypoints_dsize = kdata.keypoints[args.keypoints_type].dsize + if args.descriptors_type not in kdata.descriptors: + descriptors_dtype = None + descriptors_dsize = None + else: + descriptors_dtype = kdata.descriptors[args.descriptors_type].dtype + descriptors_dsize = kdata.descriptors[args.descriptors_type].dsize + + for image_name in image_list: + img_path = get_image_fullpath(args.kapture_root, image_name) + print(f"\nExtracting features for {img_path}") + img = Image.open(img_path).convert('RGB') + W, H = img.size + img = norm_RGB(img)[None] + if iscuda: + img = img.cuda() + + # extract keypoints/descriptors for a single image + xys, desc, scores = extract_multiscale(net, img, detector, + scale_f=args.scale_f, + min_scale=args.min_scale, + max_scale=args.max_scale, + min_size=args.min_size, + max_size=args.max_size, + verbose=True) + + xys = xys.cpu().numpy() + desc = desc.cpu().numpy() + scores = scores.cpu().numpy() + idxs = scores.argsort()[-args.top_k or None:] + + xys = xys[idxs] + desc = desc[idxs] + if keypoints_dtype is None or descriptors_dtype is None: + keypoints_dtype = xys.dtype + descriptors_dtype = desc.dtype + + keypoints_dsize = xys.shape[1] + descriptors_dsize = desc.shape[1] + + kdata.keypoints[args.keypoints_type] = kapture.Keypoints('r2d2', keypoints_dtype, keypoints_dsize) + kdata.descriptors[args.descriptors_type] = kapture.Descriptors('r2d2', descriptors_dtype, + descriptors_dsize, + args.keypoints_type, 'L2') + keypoints_config_absolute_path = get_feature_csv_fullpath(kapture.Keypoints, + args.keypoints_type, + args.kapture_root) + descriptors_config_absolute_path = get_feature_csv_fullpath(kapture.Descriptors, + args.descriptors_type, + args.kapture_root) + keypoints_to_file(keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type]) + descriptors_to_file(descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type]) + else: + assert kdata.keypoints[args.keypoints_type].dtype == xys.dtype + assert kdata.descriptors[args.descriptors_type].dtype == desc.dtype + assert kdata.keypoints[args.keypoints_type].dsize == xys.shape[1] + assert kdata.descriptors[args.descriptors_type].dsize == desc.shape[1] + assert kdata.descriptors[args.descriptors_type].keypoints_type == args.keypoints_type + assert kdata.descriptors[args.descriptors_type].metric_type == 'L2' + + keypoints_fullpath = get_keypoints_fullpath(args.keypoints_type, args.kapture_root, + image_name, tar_handlers) + print(f"Saving {xys.shape[0]} keypoints to {keypoints_fullpath}") + image_keypoints_to_file(keypoints_fullpath, xys) + kdata.keypoints[args.keypoints_type].add(image_name) + + descriptors_fullpath = get_descriptors_fullpath(args.descriptors_type, args.kapture_root, + image_name, tar_handlers) + print(f"Saving {desc.shape[0]} descriptors to {descriptors_fullpath}") + image_descriptors_to_file(descriptors_fullpath, desc) + kdata.descriptors[args.descriptors_type].add(image_name) + + if not keypoints_check_dir(kdata.keypoints[args.keypoints_type], args.keypoints_type, + args.kapture_root, tar_handlers) 
or \ + not descriptors_check_dir(kdata.descriptors[args.descriptors_type], args.descriptors_type, + args.kapture_root, tar_handlers): + print('local feature extraction ended successfully but not all files were saved') + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser( + "Extract r2d2 local features for all images in a dataset stored in the kapture format") + parser.add_argument("--model", type=str, required=True, help='model path') + parser.add_argument('--keypoints-type', default=None, help='keypoint type_name, default is filename of model') + parser.add_argument('--descriptors-type', default=None, help='descriptors type_name, default is filename of model') + + parser.add_argument("--kapture-root", type=str, required=True, help='path to kapture root directory') + + parser.add_argument("--top-k", type=int, default=5000, help='number of keypoints') + + parser.add_argument("--scale-f", type=float, default=2**0.25) + parser.add_argument("--min-size", type=int, default=256) + parser.add_argument("--max-size", type=int, default=1024) + parser.add_argument("--min-scale", type=float, default=0) + parser.add_argument("--max-scale", type=float, default=1) + + parser.add_argument("--reliability-thr", type=float, default=0.7) + parser.add_argument("--repeatability-thr", type=float, default=0.7) + + parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU') + args = parser.parse_args() + + extract_kapture_keypoints(args) diff --git a/imcui/third_party/r2d2/nets/ap_loss.py b/imcui/third_party/r2d2/nets/ap_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..251815cd97009a5feb6a815c20caca0c40daaccd --- /dev/null +++ b/imcui/third_party/r2d2/nets/ap_loss.py @@ -0,0 +1,67 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb +import numpy as np +import torch +import torch.nn as nn + + +class APLoss (nn.Module): + """ differentiable AP loss, through quantization. 
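+        The scores are soft-assigned to nq overlapping bins by a fixed (non-learnable)
+        kernel-size-1 Conv1d implementing triangular membership functions; per-bin
+        precision and recall are then accumulated, which makes the AP differentiable
+        with respect to the input scores.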
+ + Input: (N, M) values in [min, max] + label: (N, M) values in {0, 1} + + Returns: list of query AP (for each n in {1..N}) + Note: typically, you want to minimize 1 - mean(AP) + """ + def __init__(self, nq=25, min=0, max=1, euc=False): + nn.Module.__init__(self) + assert isinstance(nq, int) and 2 <= nq <= 100 + self.nq = nq + self.min = min + self.max = max + self.euc = euc + gap = max - min + assert gap > 0 + + # init quantizer = non-learnable (fixed) convolution + self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True) + a = (nq-1) / gap + #1st half = lines passing to (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1) + q.weight.data[:nq] = -a + q.bias.data[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1)) # b = 1 + a*(min+x) + #2nd half = lines passing to (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1) + q.weight.data[nq:] = a + q.bias.data[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min) # b = 1 - a*(min+x) + # first and last one are special: just horizontal straight line + q.weight.data[0] = q.weight.data[-1] = 0 + q.bias.data[0] = q.bias.data[-1] = 1 + + def compute_AP(self, x, label): + N, M = x.shape + if self.euc: # euclidean distance in same range than similarities + x = 1 - torch.sqrt(2.001 - 2*x) + + # quantize all predictions + q = self.quantizer(x.unsqueeze(1)) + q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M + + nbs = q.sum(dim=-1) # number of samples N x Q = c + rec = (q * label.view(N,1,M).float()).sum(dim=-1) # nb of correct samples = c+ N x Q + prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision + rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1] + + ap = (prec * rec).sum(dim=-1) # per-image AP + return ap + + def forward(self, x, label): + assert x.shape == label.shape # N x M + return self.compute_AP(x, label) + + + + + diff --git a/imcui/third_party/r2d2/nets/losses.py b/imcui/third_party/r2d2/nets/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..f8eea8f6e82835e22d2bb445125f7dc722db85b2 --- /dev/null +++ b/imcui/third_party/r2d2/nets/losses.py @@ -0,0 +1,56 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nets.sampler import * +from nets.repeatability_loss import * +from nets.reliability_loss import * + + +class MultiLoss (nn.Module): + """ Combines several loss functions for convenience. + *args: [loss weight (float), loss creator, ... ] + + Example: + loss = MultiLoss( 1, MyFirstLoss(), 0.5, MySecondLoss() ) + """ + def __init__(self, *args, dbg=()): + nn.Module.__init__(self) + assert len(args) % 2 == 0, 'args must be a list of (float, loss)' + self.weights = [] + self.losses = nn.ModuleList() + for i in range(len(args)//2): + weight = float(args[2*i+0]) + loss = args[2*i+1] + assert isinstance(loss, nn.Module), "%s is not a loss!" 
% loss + self.weights.append(weight) + self.losses.append(loss) + + def forward(self, select=None, **variables): + assert not select or all(1<=n<=len(self.losses) for n in select) + d = dict() + cum_loss = 0 + for num, (weight, loss_func) in enumerate(zip(self.weights, self.losses),1): + if select is not None and num not in select: continue + l = loss_func(**{k:v for k,v in variables.items()}) + if isinstance(l, tuple): + assert len(l) == 2 and isinstance(l[1], dict) + else: + l = l, {loss_func.name:l} + cum_loss = cum_loss + weight * l[0] + for key,val in l[1].items(): + d['loss_'+key] = float(val) + d['loss'] = float(cum_loss) + return cum_loss, d + + + + + + diff --git a/imcui/third_party/r2d2/nets/patchnet.py b/imcui/third_party/r2d2/nets/patchnet.py new file mode 100644 index 0000000000000000000000000000000000000000..854c61ecf9b879fa7f420255296c4fbbfd665181 --- /dev/null +++ b/imcui/third_party/r2d2/nets/patchnet.py @@ -0,0 +1,186 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BaseNet (nn.Module): + """ Takes a list of images as input, and returns for each image: + - a pixelwise descriptor + - a pixelwise confidence + """ + def softmax(self, ux): + if ux.shape[1] == 1: + x = F.softplus(ux) + return x / (1 + x) # for sure in [0,1], much less plateaus than softmax + elif ux.shape[1] == 2: + return F.softmax(ux, dim=1)[:,1:2] + + def normalize(self, x, ureliability, urepeatability): + return dict(descriptors = F.normalize(x, p=2, dim=1), + repeatability = self.softmax( urepeatability ), + reliability = self.softmax( ureliability )) + + def forward_one(self, x): + raise NotImplementedError() + + def forward(self, imgs, **kw): + res = [self.forward_one(img) for img in imgs] + # merge all dictionaries into one + res = {k:[r[k] for r in res if k in r] for k in {k for r in res for k in r}} + return dict(res, imgs=imgs, **kw) + + + +class PatchNet (BaseNet): + """ Helper class to construct a fully-convolutional network that + extract a l2-normalized patch descriptor. 
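+        When dilated=True (the default), strided convolutions are turned into dilated
+        ones, so the descriptor map keeps the full input resolution (one descriptor
+        per pixel).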
+ """ + def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False): + BaseNet.__init__(self) + self.inchan = inchan + self.curchan = inchan + self.dilated = dilated + self.dilation = dilation + self.bn = bn + self.bn_affine = bn_affine + self.ops = nn.ModuleList([]) + + def _make_bn(self, outd): + return nn.BatchNorm2d(outd, affine=self.bn_affine) + + def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max'): + # as in the original implementation, dilation is applied at the end of layer, so it will have impact only from next layer + d = self.dilation * dilation + if self.dilated: + conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=1) + self.dilation *= stride + else: + conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride) + self.ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) ) + if bn and self.bn: self.ops.append( self._make_bn(outd) ) + if relu: self.ops.append( nn.ReLU(inplace=True) ) + self.curchan = outd + + if k_pool > 1: + if pool_type == 'avg': + self.ops.append(torch.nn.AvgPool2d(kernel_size=k_pool)) + elif pool_type == 'max': + self.ops.append(torch.nn.MaxPool2d(kernel_size=k_pool)) + else: + print(f"Error, unknown pooling type {pool_type}...") + + def forward_one(self, x): + assert self.ops, "You need to add convolutions first" + for n,op in enumerate(self.ops): + x = op(x) + return self.normalize(x) + + +class L2_Net (PatchNet): + """ Compute a 128D descriptor for all overlapping 32x32 patches. + From the L2Net paper (CVPR'17). + """ + def __init__(self, dim=128, **kw ): + PatchNet.__init__(self, **kw) + add_conv = lambda n,**kw: self._add_conv((n*dim)//128,**kw) + add_conv(32) + add_conv(32) + add_conv(64, stride=2) + add_conv(64) + add_conv(128, stride=2) + add_conv(128) + add_conv(128, k=7, stride=8, bn=False, relu=False) + self.out_dim = dim + + +class Quad_L2Net (PatchNet): + """ Same than L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs. + """ + def __init__(self, dim=128, mchan=4, relu22=False, **kw ): + PatchNet.__init__(self, **kw) + self._add_conv( 8*mchan) + self._add_conv( 8*mchan) + self._add_conv( 16*mchan, stride=2) + self._add_conv( 16*mchan) + self._add_conv( 32*mchan, stride=2) + self._add_conv( 32*mchan) + # replace last 8x8 convolution with 3 2x2 convolutions + self._add_conv( 32*mchan, k=2, stride=2, relu=relu22) + self._add_conv( 32*mchan, k=2, stride=2, relu=relu22) + self._add_conv(dim, k=2, stride=2, bn=False, relu=False) + self.out_dim = dim + + + +class Quad_L2Net_ConfCFS (Quad_L2Net): + """ Same than Quad_L2Net, with 2 confidence maps for repeatability and reliability. + """ + def __init__(self, **kw ): + Quad_L2Net.__init__(self, **kw) + # reliability classifier + self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1) + # repeatability classifier: for some reasons it's a softplus, not a softmax! + # Why? I guess it's a mistake that was left unnoticed in the code for a long time... 
+ self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1) + + def forward_one(self, x): + assert self.ops, "You need to add convolutions first" + for op in self.ops: + x = op(x) + # compute the confidence maps + ureliability = self.clf(x**2) + urepeatability = self.sal(x**2) + return self.normalize(x, ureliability, urepeatability) + + +class Fast_Quad_L2Net (PatchNet): + """ Faster version of Quad l2 net, replacing one dilated conv with one pooling to diminish image resolution thus increase inference time + Dilation factors and pooling: + 1,1,1, pool2, 1,1, 2,2, 4, 8, upsample2 + """ + def __init__(self, dim=128, mchan=4, relu22=False, downsample_factor=2, **kw ): + + PatchNet.__init__(self, **kw) + self._add_conv( 8*mchan) + self._add_conv( 8*mchan) + self._add_conv( 16*mchan, k_pool = downsample_factor) # added avg pooling to decrease img resolution + self._add_conv( 16*mchan) + self._add_conv( 32*mchan, stride=2) + self._add_conv( 32*mchan) + + # replace last 8x8 convolution with 3 2x2 convolutions + self._add_conv( 32*mchan, k=2, stride=2, relu=relu22) + self._add_conv( 32*mchan, k=2, stride=2, relu=relu22) + self._add_conv(dim, k=2, stride=2, bn=False, relu=False) + + # Go back to initial image resolution with upsampling + self.ops.append(torch.nn.Upsample(scale_factor=downsample_factor, mode='bilinear', align_corners=False)) + + self.out_dim = dim + + +class Fast_Quad_L2Net_ConfCFS (Fast_Quad_L2Net): + """ Fast r2d2 architecture + """ + def __init__(self, **kw ): + Fast_Quad_L2Net.__init__(self, **kw) + # reliability classifier + self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1) + + # repeatability classifier: for some reasons it's a softplus, not a softmax! + # Why? I guess it's a mistake that was left unnoticed in the code for a long time... + self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1) + + def forward_one(self, x): + assert self.ops, "You need to add convolutions first" + for op in self.ops: + x = op(x) + # compute the confidence maps + ureliability = self.clf(x**2) + urepeatability = self.sal(x**2) + return self.normalize(x, ureliability, urepeatability) \ No newline at end of file diff --git a/imcui/third_party/r2d2/nets/reliability_loss.py b/imcui/third_party/r2d2/nets/reliability_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..52d5383b0eaa52bcf2111eabb4b45e39b63b976f --- /dev/null +++ b/imcui/third_party/r2d2/nets/reliability_loss.py @@ -0,0 +1,59 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb +import torch.nn as nn +import torch.nn.functional as F + +from nets.ap_loss import APLoss + + +class PixelAPLoss (nn.Module): + """ Computes the pixel-wise AP loss: + Given two images and ground-truth optical flow, computes the AP per pixel. 
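+        A sampler picks query pixels in the first image and candidate pixels in the
+        second one; APLoss then measures how well the true correspondence (given by
+        aflow) is ranked among the candidates, and the loss is 1 - AP averaged over
+        the valid queries.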
+ + feat1: (B, C, H, W) pixel-wise features extracted from img1 + feat2: (B, C, H, W) pixel-wise features extracted from img2 + aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2 + """ + def __init__(self, sampler, nq=20): + nn.Module.__init__(self) + self.aploss = APLoss(nq, min=0, max=1, euc=False) + self.name = 'pixAP' + self.sampler = sampler + + def loss_from_ap(self, ap, rel): + return 1 - ap + + def forward(self, descriptors, aflow, **kw): + # subsample things + scores, gt, msk, qconf = self.sampler(descriptors, kw.get('reliability'), aflow) + + # compute pixel-wise AP + n = qconf.numel() + if n == 0: return 0 + scores, gt = scores.view(n,-1), gt.view(n,-1) + ap = self.aploss(scores, gt).view(msk.shape) + + pixel_loss = self.loss_from_ap(ap, qconf) + + loss = pixel_loss[msk].mean() + return loss + + +class ReliabilityLoss (PixelAPLoss): + """ same than PixelAPLoss, but also train a pixel-wise confidence + that this pixel is going to have a good AP. + """ + def __init__(self, sampler, base=0.5, **kw): + PixelAPLoss.__init__(self, sampler, **kw) + assert 0 <= base < 1 + self.base = base + self.name = 'reliability' + + def loss_from_ap(self, ap, rel): + return 1 - ap*rel - (1-rel)*self.base + + + diff --git a/imcui/third_party/r2d2/nets/repeatability_loss.py b/imcui/third_party/r2d2/nets/repeatability_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5cda0b6d036f98af88a88780fe39da0c5c0b610e --- /dev/null +++ b/imcui/third_party/r2d2/nets/repeatability_loss.py @@ -0,0 +1,66 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nets.sampler import FullSampler + +class CosimLoss (nn.Module): + """ Try to make the repeatability repeatable from one image to the other. + """ + def __init__(self, N=16): + nn.Module.__init__(self) + self.name = f'cosim{N}' + self.patches = nn.Unfold(N, padding=0, stride=N//2) + + def extract_patches(self, sal): + patches = self.patches(sal).transpose(1,2) # flatten + patches = F.normalize(patches, p=2, dim=2) # norm + return patches + + def forward(self, repeatability, aflow, **kw): + B,two,H,W = aflow.shape + assert two == 2 + + # normalize + sali1, sali2 = repeatability + grid = FullSampler._aflow_to_grid(aflow) + sali2 = F.grid_sample(sali2, grid, mode='bilinear', padding_mode='border') + + patches1 = self.extract_patches(sali1) + patches2 = self.extract_patches(sali2) + cosim = (patches1 * patches2).sum(dim=2) + return 1 - cosim.mean() + + +class PeakyLoss (nn.Module): + """ Try to make the repeatability locally peaky. + + Mechanism: we maximize, for each pixel, the difference between the local mean + and the local max. 
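+        i.e. loss = 1 - mean( maxpool(sal) - avgpool(sal) ) with (N+1)x(N+1) windows,
+        averaged over both repeatability maps; sal is first smoothed with a 3x3
+        average pooling to remove very high frequencies.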
+ """ + def __init__(self, N=16): + nn.Module.__init__(self) + self.name = f'peaky{N}' + assert N % 2 == 0, 'N must be pair' + self.preproc = nn.AvgPool2d(3, stride=1, padding=1) + self.maxpool = nn.MaxPool2d(N+1, stride=1, padding=N//2) + self.avgpool = nn.AvgPool2d(N+1, stride=1, padding=N//2) + + def forward_one(self, sali): + sali = self.preproc(sali) # remove super high frequency + return 1 - (self.maxpool(sali) - self.avgpool(sali)).mean() + + def forward(self, repeatability, **kw): + sali1, sali2 = repeatability + return (self.forward_one(sali1) + self.forward_one(sali2)) /2 + + + + + diff --git a/imcui/third_party/r2d2/nets/sampler.py b/imcui/third_party/r2d2/nets/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9fede70d3a04d7f31a1d414eace0aaf3729e8235 --- /dev/null +++ b/imcui/third_party/r2d2/nets/sampler.py @@ -0,0 +1,390 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +""" Different samplers, each specifying how to sample pixels for the AP loss. +""" + + +class FullSampler(nn.Module): + """ all pixels are selected + - feats: keypoint descriptors + - confs: reliability values + """ + def __init__(self): + nn.Module.__init__(self) + self.mode = 'bilinear' + self.padding = 'zeros' + + @staticmethod + def _aflow_to_grid(aflow): + H, W = aflow.shape[2:] + grid = aflow.permute(0,2,3,1).clone() + grid[:,:,:,0] *= 2/(W-1) + grid[:,:,:,1] *= 2/(H-1) + grid -= 1 + grid[torch.isnan(grid)] = 9e9 # invalids + return grid + + def _warp(self, feats, confs, aflow): + if isinstance(aflow, tuple): return aflow # result was precomputed + feat1, feat2 = feats + conf1, conf2 = confs if confs else (None,None) + + B, two, H, W = aflow.shape + D = feat1.shape[1] + assert feat1.shape == feat2.shape == (B, D, H, W) # D = 128, B = batch + assert conf1.shape == conf2.shape == (B, 1, H, W) if confs else True + + # warp img2 to img1 + grid = self._aflow_to_grid(aflow) + ones2 = feat2.new_ones(feat2[:,0:1].shape) + feat2to1 = F.grid_sample(feat2, grid, mode=self.mode, padding_mode=self.padding) + mask2to1 = F.grid_sample(ones2, grid, mode='nearest', padding_mode='zeros') + conf2to1 = F.grid_sample(conf2, grid, mode=self.mode, padding_mode=self.padding) \ + if confs else None + return feat2to1, mask2to1.byte(), conf2to1 + + def _warp_positions(self, aflow): + B, two, H, W = aflow.shape + assert two == 2 + + Y = torch.arange(H, device=aflow.device) + X = torch.arange(W, device=aflow.device) + XY = torch.stack(torch.meshgrid(Y,X)[::-1], dim=0) + XY = XY[None].expand(B, 2, H, W).float() + + grid = self._aflow_to_grid(aflow) + XY2 = F.grid_sample(XY, grid, mode='bilinear', padding_mode='zeros') + return XY, XY2 + + + +class SubSampler (FullSampler): + """ pixels are selected in an uniformly spaced grid + """ + def __init__(self, border, subq, subd, perimage=False): + FullSampler.__init__(self) + assert subq % subd == 0, 'subq must be multiple of subd' + self.sub_q = subq + self.sub_d = subd + self.border = border + self.perimage = perimage + + def __repr__(self): + return "SubSampler(border=%d, subq=%d, subd=%d, perimage=%d)" % ( + self.border, self.sub_q, self.sub_d, self.perimage) + + def __call__(self, feats, confs, aflow): + feat1, conf1 = feats[0], (confs[0] if confs else None) + # warp with optical flow in img1 coords + feat2, mask2, conf2 = self._warp(feats, confs, aflow) + + # subsample img1 + slq = 
slice(self.border, -self.border or None, self.sub_q) + feat1 = feat1[:, :, slq, slq] + conf1 = conf1[:, :, slq, slq] if confs else None + # subsample img2 + sld = slice(self.border, -self.border or None, self.sub_d) + feat2 = feat2[:, :, sld, sld] + mask2 = mask2[:, :, sld, sld] + conf2 = conf2[:, :, sld, sld] if confs else None + + B, D, Hq, Wq = feat1.shape + B, D, Hd, Wd = feat2.shape + + # compute gt + if self.perimage or self.sub_q != self.sub_d: + # compute ground-truth by comparing pixel indices + f = feats[0][0:1,0] if self.perimage else feats[0][:,0] + idxs = torch.arange(f.numel(), dtype=torch.int64, device=feat1.device).view(f.shape) + idxs1 = idxs[:, slq, slq].reshape(-1,Hq*Wq) + idxs2 = idxs[:, sld, sld].reshape(-1,Hd*Wd) + if self.perimage: + gt = (idxs1[0].view(-1,1) == idxs2[0].view(1,-1)) + gt = gt[None,:,:].expand(B, Hq*Wq, Hd*Wd) + else : + gt = (idxs1.view(-1,1) == idxs2.view(1,-1)) + else: + gt = torch.eye(feat1[:,0].numel(), dtype=torch.uint8, device=feat1.device) # always binary for AP loss + + # compute all images together + queries = feat1.reshape(B,D,-1) # B x D x (Hq x Wq) + database = feat2.reshape(B,D,-1) # B x D x (Hd x Wd) + if self.perimage: + queries = queries.transpose(1,2) # B x (Hd x Wd) x D + scores = torch.bmm(queries, database) # B x (Hq x Wq) x (Hd x Wd) + else: + queries = queries .transpose(1,2).reshape(-1,D) # (B x Hq x Wq) x D + database = database.transpose(1,0).reshape(D,-1) # D x (B x Hd x Wd) + scores = torch.matmul(queries, database) # (B x Hq x Wq) x (B x Hd x Wd) + + # compute reliability + qconf = (conf1 + conf2)/2 if confs else None + + assert gt.shape == scores.shape + return scores, gt, mask2, qconf + + + +class NghSampler (FullSampler): + """ all pixels in a small neighborhood + """ + def __init__(self, ngh, subq=1, subd=1, ignore=1, border=None): + FullSampler.__init__(self) + assert 0 <= ignore < ngh + self.ngh = ngh + self.ignore = ignore + assert subd <= ngh + self.sub_q = subq + self.sub_d = subd + if border is None: border = ngh + assert border >= ngh, 'border has to be larger than ngh' + self.border = border + + def __repr__(self): + return "NghSampler(ngh=%d, subq=%d, subd=%d, ignore=%d, border=%d)" % ( + self.ngh, self.sub_q, self.sub_d, self.ignore, self.border) + + def trans(self, arr, i, j): + s = lambda i: slice(self.border+i, i-self.border or None, self.sub_q) + return arr[:,:,s(j),s(i)] + + def __call__(self, feats, confs, aflow): + feat1, conf1 = feats[0], (confs[0] if confs else None) + # warp with optical flow in img1 coords + feat2, mask2, conf2 = self._warp(feats, confs, aflow) + + qfeat = self.trans(feat1,0,0) + qconf = (self.trans(conf1,0,0) + self.trans(conf2,0,0)) / 2 if confs else None + mask2 = self.trans(mask2,0,0) + scores_at = lambda i,j: (qfeat * self.trans(feat2,i,j)).sum(dim=1) + + # compute scores for all neighbors + B, D = feat1.shape[:2] + min_d = self.ignore**2 + max_d = self.ngh**2 + rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple + negs = [] + offsets = [] + for j in range(-rad, rad+1, self.sub_d): + for i in range(-rad, rad+1, self.sub_d): + if not(min_d < i*i + j*j <= max_d): + continue # out of scope + offsets.append((i,j)) # Note: this list is just for debug + negs.append( scores_at(i,j) ) + + scores = torch.stack([scores_at(0,0)] + negs, dim=-1) + gt = scores.new_zeros(scores.shape, dtype=torch.uint8) + gt[..., 0] = 1 # only the center point is positive + + return scores, gt, mask2, qconf + + + +class FarNearSampler (FullSampler): + """ Sample pixels from *both* a small 
neighborhood *and* far-away pixels. + + How it works? + 1) Queries are sampled from img1, + - at least `border` pixels from borders and + - on a grid with step = `subq` + + 2) Close database pixels + - from the corresponding image (img2), + - within a `ngh` distance radius + - on a grid with step = `subd_ngh` + - ignored if distance to query is >0 and <=`ignore` + + 3) Far-away database pixels from , + - from all batch images in `img2` + - at least `border` pixels from borders + - on a grid with step = `subd_far` + """ + def __init__(self, subq, ngh, subd_ngh, subd_far, border=None, ignore=1, + maxpool_ngh=False ): + FullSampler.__init__(self) + border = border or ngh + assert ignore < ngh < subd_far, 'neighborhood needs to be smaller than far step' + self.close_sampler = NghSampler(ngh=ngh, subq=subq, subd=subd_ngh, + ignore=not(maxpool_ngh), border=border) + self.faraway_sampler = SubSampler(border=border, subq=subq, subd=subd_far) + self.maxpool_ngh = maxpool_ngh + + def __repr__(self): + c,f = self.close_sampler, self.faraway_sampler + res = "FarNearSampler(subq=%d, ngh=%d" % (c.sub_q, c.ngh) + res += ", subd_ngh=%d, subd_far=%d" % (c.sub_d, f.sub_d) + res += ", border=%d, ign=%d" % (f.border, c.ignore) + res += ", maxpool_ngh=%d" % self.maxpool_ngh + return res+')' + + def __call__(self, feats, confs, aflow): + # warp with optical flow in img1 coords + aflow = self._warp(feats, confs, aflow) + + # sample ngh pixels + scores1, gt1, msk1, conf1 = self.close_sampler(feats, confs, aflow) + scores1, gt1 = scores1.view(-1,scores1.shape[-1]), gt1.view(-1,gt1.shape[-1]) + if self.maxpool_ngh: + # we consider all scores from ngh as potential positives + scores1, self._cached_maxpool_ngh = scores1.max(dim=1,keepdim=True) + gt1 = gt1[:, 0:1] + + # sample far pixels + scores2, gt2, msk2, conf2 = self.faraway_sampler(feats, confs, aflow) + # assert (msk1 == msk2).all() + # assert (conf1 == conf2).all() + + return (torch.cat((scores1,scores2),dim=1), + torch.cat((gt1, gt2), dim=1), + msk1, conf1 if confs else None) + + +class NghSampler2 (nn.Module): + """ Similar to NghSampler, but doesnt warp the 2nd image. + Distance to GT => 0 ... pos_d ... neg_d ... 
ngh + Pixel label => + + + + + + 0 0 - - - - - - - + + Subsample on query side: if > 0, regular grid + < 0, random points + In both cases, the number of query points is = W*H/subq**2 + """ + def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None, + maxpool_pos=True, subd_neg=0): + nn.Module.__init__(self) + assert 0 <= pos_d < neg_d <= (ngh if ngh else 99) + self.ngh = ngh + self.pos_d = pos_d + self.neg_d = neg_d + assert subd <= ngh or ngh == 0 + assert subq != 0 + self.sub_q = subq + self.sub_d = subd + self.sub_d_neg = subd_neg + if border is None: border = ngh + assert border >= ngh, 'border has to be larger than ngh' + self.border = border + self.maxpool_pos = maxpool_pos + self.precompute_offsets() + + def precompute_offsets(self): + pos_d2 = self.pos_d**2 + neg_d2 = self.neg_d**2 + rad2 = self.ngh**2 + rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple + pos = [] + neg = [] + for j in range(-rad, rad+1, self.sub_d): + for i in range(-rad, rad+1, self.sub_d): + d2 = i*i + j*j + if d2 <= pos_d2: + pos.append( (i,j) ) + elif neg_d2 <= d2 <= rad2: + neg.append( (i,j) ) + + self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t()) + self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t()) + + def gen_grid(self, step, aflow): + B, two, H, W = aflow.shape + dev = aflow.device + b1 = torch.arange(B, device=dev) + if step > 0: + # regular grid + x1 = torch.arange(self.border, W-self.border, step, device=dev) + y1 = torch.arange(self.border, H-self.border, step, device=dev) + H1, W1 = len(y1), len(x1) + x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1) + y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1) + b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1) + shape = (B, H1, W1) + else: + # randomly spread + n = (H - 2*self.border) * (W - 2*self.border) // step**2 + x1 = torch.randint(self.border, W-self.border, (n,), device=dev) + y1 = torch.randint(self.border, H-self.border, (n,), device=dev) + x1 = x1[None,:].expand(B,n).reshape(-1) + y1 = y1[None,:].expand(B,n).reshape(-1) + b1 = b1[:,None].expand(B,n).reshape(-1) + shape = (B, n) + return b1, y1, x1, shape + + def forward(self, feats, confs, aflow, **kw): + B, two, H, W = aflow.shape + assert two == 2 + feat1, conf1 = feats[0], (confs[0] if confs else None) + feat2, conf2 = feats[1], (confs[1] if confs else None) + + # positions in the first image + b1, y1, x1, shape = self.gen_grid(self.sub_q, aflow) + + # sample features from first image + feat1 = feat1[b1, :, y1, x1] + qconf = conf1[b1, :, y1, x1].view(shape) if confs else None + + #sample GT from second image + b2 = b1 + xy2 = (aflow[b1, :, y1, x1] + 0.5).long().t() + mask = (0 <= xy2[0]) * (0 <= xy2[1]) * (xy2[0] < W) * (xy2[1] < H) + mask = mask.view(shape) + + def clamp(xy): + torch.clamp(xy[0], 0, W-1, out=xy[0]) + torch.clamp(xy[1], 0, H-1, out=xy[1]) + return xy + + # compute positive scores + xy2p = clamp(xy2[:,None,:] + self.pos_offsets[:,:,None]) + pscores = (feat1[None,:,:] * feat2[b2, :, xy2p[1], xy2p[0]]).sum(dim=-1).t() +# xy1p = clamp(torch.stack((x1,y1))[:,None,:] + self.pos_offsets[:,:,None]) +# grid = FullSampler._aflow_to_grid(aflow) +# feat2p = F.grid_sample(feat2, grid, mode='bilinear', padding_mode='border') +# pscores = (feat1[None,:,:] * feat2p[b1,:,xy1p[1], xy1p[0]]).sum(dim=-1).t() + if self.maxpool_pos: + pscores, pos = pscores.max(dim=1, keepdim=True) + if confs: + sel = clamp(xy2 + self.pos_offsets[:,pos.view(-1)]) + qconf = (qconf + conf2[b2, :, sel[1], sel[0]].view(shape))/2 + + # compute 
negative scores + xy2n = clamp(xy2[:,None,:] + self.neg_offsets[:,:,None]) + nscores = (feat1[None,:,:] * feat2[b2, :, xy2n[1], xy2n[0]]).sum(dim=-1).t() + + if self.sub_d_neg: + # add distractors from a grid + b3, y3, x3, _ = self.gen_grid(self.sub_d_neg, aflow) + distractors = feat2[b3, :, y3, x3] + dscores = torch.matmul(feat1, distractors.t()) + del distractors + + # remove scores that corresponds to positives or nulls + dis2 = (x3 - xy2[0][:,None])**2 + (y3 - xy2[1][:,None])**2 + dis2 += (b3 != b2[:,None]).long() * self.neg_d**2 + dscores[dis2 < self.neg_d**2] = 0 + + scores = torch.cat((pscores, nscores, dscores), dim=1) + else: + # concat everything + scores = torch.cat((pscores, nscores), dim=1) + + gt = scores.new_zeros(scores.shape, dtype=torch.uint8) + gt[:, :pscores.shape[1]] = 1 + + return scores, gt, mask, qconf + + + + + + + + diff --git a/imcui/third_party/r2d2/tools/common.py b/imcui/third_party/r2d2/tools/common.py new file mode 100644 index 0000000000000000000000000000000000000000..a7875ddd714b1d08efb0d1369c3a856490796288 --- /dev/null +++ b/imcui/third_party/r2d2/tools/common.py @@ -0,0 +1,41 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import os, pdb#, shutil +import numpy as np +import torch + + +def mkdir_for(file_path): + os.makedirs(os.path.split(file_path)[0], exist_ok=True) + + +def model_size(model): + ''' Computes the number of parameters of the model + ''' + size = 0 + for weights in model.state_dict().values(): + size += np.prod(weights.shape) + return size + + +def torch_set_gpu(gpus): + if type(gpus) is int: + gpus = [gpus] + + cuda = all(gpu>=0 for gpu in gpus) + + if cuda: + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus]) + assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % ( + os.environ['HOSTNAME'],os.environ['CUDA_VISIBLE_DEVICES']) + torch.backends.cudnn.benchmark = True # speed-up cudnn + torch.backends.cudnn.fastest = True # even more speed-up? + print( 'Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES'] ) + + else: + print( 'Launching on CPU' ) + + return cuda + diff --git a/imcui/third_party/r2d2/tools/dataloader.py b/imcui/third_party/r2d2/tools/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d9fff5f8dfb8d9d3b243a57555779de33d0818 --- /dev/null +++ b/imcui/third_party/r2d2/tools/dataloader.py @@ -0,0 +1,367 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb +from PIL import Image +import numpy as np + +import torch +import torchvision.transforms as tvf + +from tools.transforms import instanciate_transformation +from tools.transforms_tools import persp_apply + + +RGB_mean = [0.485, 0.456, 0.406] +RGB_std = [0.229, 0.224, 0.225] + +norm_RGB = tvf.Compose([tvf.ToTensor(), tvf.Normalize(mean=RGB_mean, std=RGB_std)]) + + +class PairLoader: + """ On-the-fly jittering of pairs of image with dense pixel ground-truth correspondences. 
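+    Each sample is produced by cropping matching windows from the two images,
+    chosen so that most ground-truth correspondences fall inside both crops.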
+ + crop: random crop applied to both images + scale: random scaling applied to img2 + distort: random ditorsion applied to img2 + + self[idx] returns a dictionary with keys: img1, img2, aflow, mask + - img1: cropped original + - img2: distorted cropped original + - aflow: 'absolute' optical flow = (x,y) position of each pixel from img1 in img2 + - mask: (binary image) valid pixels of img1 + """ + def __init__(self, dataset, crop='', scale='', distort='', norm = norm_RGB, + what = 'aflow mask', idx_as_rng_seed = False): + assert hasattr(dataset, 'npairs') + assert hasattr(dataset, 'get_pair') + self.dataset = dataset + self.distort = instanciate_transformation(distort) + self.crop = instanciate_transformation(crop) + self.norm = instanciate_transformation(norm) + self.scale = instanciate_transformation(scale) + self.idx_as_rng_seed = idx_as_rng_seed # to remove randomness + self.what = what.split() if isinstance(what, str) else what + self.n_samples = 5 # number of random trials per image + + def __len__(self): + assert len(self.dataset) == self.dataset.npairs, pdb.set_trace() # and not nimg + return len(self.dataset) + + def __repr__(self): + fmt_str = 'PairLoader\n' + fmt_str += repr(self.dataset) + fmt_str += ' npairs: %d\n' % self.dataset.npairs + short_repr = lambda s: repr(s).strip().replace('\n',', ')[14:-1].replace(' ',' ') + fmt_str += ' Distort: %s\n' % short_repr(self.distort) + fmt_str += ' Crop: %s\n' % short_repr(self.crop) + fmt_str += ' Norm: %s\n' % short_repr(self.norm) + return fmt_str + + def __getitem__(self, i): + #from time import time as now; t0 = now() + if self.idx_as_rng_seed: + import random + random.seed(i) + np.random.seed(i) + + # Retrieve an image pair and their absolute flow + img_a, img_b, metadata = self.dataset.get_pair(i, self.what) + + # aflow contains pixel coordinates indicating where each + # pixel from the left image ended up in the right image + # as (x,y) pairs, but its shape is (H,W,2) + aflow = np.float32(metadata['aflow']) + mask = metadata.get('mask', np.ones(aflow.shape[:2],np.uint8)) + + # apply transformations to the second image + img_b = {'img': img_b, 'persp':(1,0,0,0,1,0,0,0)} + if self.scale: + img_b = self.scale(img_b) + if self.distort: + img_b = self.distort(img_b) + + # apply the same transformation to the flow + aflow[:] = persp_apply(img_b['persp'], aflow.reshape(-1,2)).reshape(aflow.shape) + corres = None + if 'corres' in metadata: + corres = np.float32(metadata['corres']) + corres[:,1] = persp_apply(img_b['persp'], corres[:,1]) + + # apply the same transformation to the homography + homography = None + if 'homography' in metadata: + homography = np.float32(metadata['homography']) + # p_b = homography * p_a + persp = np.float32(img_b['persp']+(1,)).reshape(3,3) + homography = persp @ homography + + # determine crop size + img_b = img_b['img'] + crop_size = self.crop({'imsize':(10000,10000)})['imsize'] + output_size_a = min(img_a.size, crop_size) + output_size_b = min(img_b.size, crop_size) + img_a = np.array(img_a) + img_b = np.array(img_b) + + ah,aw,p1 = img_a.shape + bh,bw,p2 = img_b.shape + assert p1 == 3 + assert p2 == 3 + assert aflow.shape == (ah, aw, 2) + assert mask.shape == (ah, aw) + + # Let's start by computing the scale of the + # optical flow and applying a median filter: + dx = np.gradient(aflow[:,:,0]) + dy = np.gradient(aflow[:,:,1]) + scale = np.sqrt(np.clip(np.abs(dx[1]*dy[0] - dx[0]*dy[1]), 1e-16, 1e16)) + + accu2 = np.zeros((16,16), bool) + Q = lambda x, w: np.int32(16 * (x - w.start) / (w.stop - w.start)) + 
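+        # window1/window build crop slices of (roughly) the requested size centred
+        # on a given pixel and clamped to the image bounds; `scale` widens the
+        # window to compensate for the local magnification of the flow.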
+ def window1(x, size, w): + l = x - int(0.5 + size / 2) + r = l + int(0.5 + size) + if l < 0: l,r = (0, r - l) + if r > w: l,r = (l + w - r, w) + if l < 0: l,r = 0,w # larger than width + return slice(l,r) + def window(cx, cy, win_size, scale, img_shape): + return (window1(cy, win_size[1]*scale, img_shape[0]), + window1(cx, win_size[0]*scale, img_shape[1])) + + n_valid_pixel = mask.sum() + sample_w = mask / (1e-16 + n_valid_pixel) + def sample_valid_pixel(): + n = np.random.choice(sample_w.size, p=sample_w.ravel()) + y, x = np.unravel_index(n, sample_w.shape) + return x, y + + # Find suitable left and right windows + trials = 0 # take the best out of few trials + best = -np.inf, None + for _ in range(50*self.n_samples): + if trials >= self.n_samples: break # finished! + + # pick a random valid point from the first image + if n_valid_pixel == 0: break + c1x, c1y = sample_valid_pixel() + + # Find in which position the center of the left + # window ended up being placed in the right image + c2x, c2y = (aflow[c1y, c1x] + 0.5).astype(np.int32) + if not(0 <= c2x < bw and 0 <= c2y < bh): continue + + # Get the flow scale + sigma = scale[c1y, c1x] + + # Determine sampling windows + if 0.2 < sigma < 1: + win1 = window(c1x, c1y, output_size_a, 1/sigma, img_a.shape) + win2 = window(c2x, c2y, output_size_b, 1, img_b.shape) + elif 1 <= sigma < 5: + win1 = window(c1x, c1y, output_size_a, 1, img_a.shape) + win2 = window(c2x, c2y, output_size_b, sigma, img_b.shape) + else: + continue # bad scale + + # compute a score based on the flow + x2,y2 = aflow[win1].reshape(-1, 2).T.astype(np.int32) + # Check the proportion of valid flow vectors + valid = (win2[1].start <= x2) & (x2 < win2[1].stop) \ + & (win2[0].start <= y2) & (y2 < win2[0].stop) + score1 = (valid * mask[win1].ravel()).mean() + # check the coverage of the second window + accu2[:] = False + accu2[Q(y2[valid],win2[0]), Q(x2[valid],win2[1])] = True + score2 = accu2.mean() + # Check how many hits we got + score = min(score1, score2) + + trials += 1 + if score > best[0]: + best = score, win1, win2 + + if None in best: # counldn't find a good window + img_a = np.zeros(output_size_a[::-1]+(3,), dtype=np.uint8) + img_b = np.zeros(output_size_b[::-1]+(3,), dtype=np.uint8) + aflow = np.nan * np.ones((2,)+output_size_a[::-1], dtype=np.float32) + homography = np.nan * np.ones((3,3), dtype=np.float32) + + else: + win1, win2 = best[1:] + img_a = img_a[win1] + img_b = img_b[win2] + aflow = aflow[win1] - np.float32([[[win2[1].start, win2[0].start]]]) + mask = mask[win1] + aflow[~mask.view(bool)] = np.nan # mask bad pixels! 
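+            # At this point aflow maps each pixel of the cropped first image to its
+            # (x,y) position in the cropped second image; pixels without a valid
+            # correspondence are set to NaN.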
+ aflow = aflow.transpose(2,0,1) # --> (2,H,W) + + if corres is not None: + corres[:,0] -= (win1[1].start, win1[0].start) + corres[:,1] -= (win2[1].start, win2[0].start) + + if homography is not None: + trans1 = np.eye(3, dtype=np.float32) + trans1[:2,2] = (win1[1].start, win1[0].start) + trans2 = np.eye(3, dtype=np.float32) + trans2[:2,2] = (-win2[1].start, -win2[0].start) + homography = trans2 @ homography @ trans1 + homography /= homography[2,2] + + # rescale if necessary + if img_a.shape[:2][::-1] != output_size_a: + sx, sy = (np.float32(output_size_a)-1)/(np.float32(img_a.shape[:2][::-1])-1) + img_a = np.asarray(Image.fromarray(img_a).resize(output_size_a, Image.ANTIALIAS)) + mask = np.asarray(Image.fromarray(mask).resize(output_size_a, Image.NEAREST)) + afx = Image.fromarray(aflow[0]).resize(output_size_a, Image.NEAREST) + afy = Image.fromarray(aflow[1]).resize(output_size_a, Image.NEAREST) + aflow = np.stack((np.float32(afx), np.float32(afy))) + + if corres is not None: + corres[:,0] *= (sx, sy) + + if homography is not None: + homography = homography @ np.diag(np.float32([1/sx,1/sy,1])) + homography /= homography[2,2] + + if img_b.shape[:2][::-1] != output_size_b: + sx, sy = (np.float32(output_size_b)-1)/(np.float32(img_b.shape[:2][::-1])-1) + img_b = np.asarray(Image.fromarray(img_b).resize(output_size_b, Image.ANTIALIAS)) + aflow *= [[[sx]], [[sy]]] + + if corres is not None: + corres[:,1] *= (sx, sy) + + if homography is not None: + homography = np.diag(np.float32([sx,sy,1])) @ homography + homography /= homography[2,2] + + assert aflow.dtype == np.float32, pdb.set_trace() + assert homography is None or homography.dtype == np.float32, pdb.set_trace() + if 'flow' in self.what: + H, W = img_a.shape[:2] + mgrid = np.mgrid[0:H, 0:W][::-1].astype(np.float32) + flow = aflow - mgrid + + result = dict(img1=self.norm(img_a), img2=self.norm(img_b)) + for what in self.what: + try: result[what] = eval(what) + except NameError: pass + return result + + + +def threaded_loader( loader, iscuda, threads, batch_size=1, shuffle=True): + """ Get a data loader, given the dataset and some parameters. + + Parameters + ---------- + loader : object[i] returns the i-th training example. + + iscuda : bool + + batch_size : int + + threads : int + + shuffle : int + + Returns + ------- + a multi-threaded pytorch loader. + """ + return torch.utils.data.DataLoader( + loader, + batch_size = batch_size, + shuffle = shuffle, + sampler = None, + num_workers = threads, + pin_memory = iscuda, + collate_fn=collate) + + + +def collate(batch, _use_shared_memory=True): + """Puts each data field into a tensor with outer dimension batch size. 
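+    Unlike the default PyTorch collate, it also handles None values and strings,
+    and falls back to returning a list when numpy arrays in the batch cannot be
+    stacked into a single tensor.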
+ Copied from https://github.com/pytorch in torch/utils/data/_utils/collate.py + """ + import re + error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" + elem_type = type(batch[0]) + if isinstance(batch[0], torch.Tensor): + out = None + if _use_shared_memory: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = batch[0].storage()._new_shared(numel) + out = batch[0].new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + elem = batch[0] + assert elem_type.__name__ == 'ndarray' + # array of string classes and object + if re.search('[SaUO]', elem.dtype.str) is not None: + raise TypeError(error_msg.format(elem.dtype)) + batch = [torch.from_numpy(b) for b in batch] + try: + return torch.stack(batch, 0) + except RuntimeError: + return batch + elif batch[0] is None: + return list(batch) + elif isinstance(batch[0], int): + return torch.LongTensor(batch) + elif isinstance(batch[0], float): + return torch.DoubleTensor(batch) + elif isinstance(batch[0], str): + return batch + elif isinstance(batch[0], dict): + return {key: collate([d[key] for d in batch]) for key in batch[0]} + elif isinstance(batch[0], (tuple,list)): + transposed = zip(*batch) + return [collate(samples) for samples in transposed] + + raise TypeError((error_msg.format(type(batch[0])))) + + + +def tensor2img(tensor, model=None): + """ convert back a torch/numpy tensor to a PIL Image + by undoing the ToTensor() and Normalize() transforms. + """ + mean = norm_RGB.transforms[1].mean + std = norm_RGB.transforms[1].std + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu().numpy() + + res = np.uint8(np.clip(255*((tensor.transpose(1,2,0) * std) + mean), 0, 255)) + from PIL import Image + return Image.fromarray(res) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser("Tool to debug/visualize the data loader") + parser.add_argument("dataloader", type=str, help="command to create the data loader") + args = parser.parse_args() + + from datasets import * + auto_pairs = lambda db: SyntheticPairDataset(db, + 'RandomScale(256,1024,can_upscale=True)', + 'RandomTilting(0.5), PixelNoise(25)') + + loader = eval(args.dataloader) + print("Data loader =", loader) + + from tools.viz import show_flow + for data in loader: + aflow = data['aflow'] + H, W = aflow.shape[-2:] + flow = (aflow - np.mgrid[:H, :W][::-1]).transpose(1,2,0) + show_flow(tensor2img(data['img1']), tensor2img(data['img2']), flow) + diff --git a/imcui/third_party/r2d2/tools/trainer.py b/imcui/third_party/r2d2/tools/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..9f893395efdeb8e13cc00539325572553168c5ce --- /dev/null +++ b/imcui/third_party/r2d2/tools/trainer.py @@ -0,0 +1,76 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb +from tqdm import tqdm +from collections import defaultdict + +import torch +import torch.nn as nn + + +class Trainer (nn.Module): + """ Helper class to train a deep network. + Overload this class `forward_backward` for your actual needs. 
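+        Calling the instance runs one full epoch over `loader` and returns the
+        average loss.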
+ + Usage: + train = Trainer(net, loader, loss, optimizer) + for epoch in range(n_epochs): + train() + """ + def __init__(self, net, loader, loss, optimizer): + nn.Module.__init__(self) + self.net = net + self.loader = loader + self.loss_func = loss + self.optimizer = optimizer + + def iscuda(self): + return next(self.net.parameters()).device != torch.device('cpu') + + def todevice(self, x): + if isinstance(x, dict): + return {k:self.todevice(v) for k,v in x.items()} + if isinstance(x, (tuple,list)): + return [self.todevice(v) for v in x] + + if self.iscuda(): + return x.contiguous().cuda(non_blocking=True) + else: + return x.cpu() + + def __call__(self): + self.net.train() + + stats = defaultdict(list) + + for iter,inputs in enumerate(tqdm(self.loader)): + inputs = self.todevice(inputs) + + # compute gradient and do model update + self.optimizer.zero_grad() + + loss, details = self.forward_backward(inputs) + if torch.isnan(loss): + raise RuntimeError('Loss is NaN') + + self.optimizer.step() + + for key, val in details.items(): + stats[key].append( val ) + + print(" Summary of losses during this epoch:") + mean = lambda lis: sum(lis) / len(lis) + for loss_name, vals in stats.items(): + N = 1 + len(vals)//10 + print(f" - {loss_name:20}:", end='') + print(f" {mean(vals[:N]):.3f} --> {mean(vals[-N:]):.3f} (avg: {mean(vals):.3f})") + return mean(stats['loss']) # return average loss + + def forward_backward(self, inputs): + raise NotImplementedError() + + + + diff --git a/imcui/third_party/r2d2/tools/transforms.py b/imcui/third_party/r2d2/tools/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..87275276310191a7da3fc14f606345d9616208e0 --- /dev/null +++ b/imcui/third_party/r2d2/tools/transforms.py @@ -0,0 +1,513 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb +import numpy as np +from PIL import Image, ImageOps +import torchvision.transforms as tvf +import random +from math import ceil + +from . import transforms_tools as F + +''' +Example command to try out some transformation chain: + +python -m tools.transforms --trfs "Scale(384), ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1), RandomRotation(10), RandomTilting(0.5, 'all'), RandomScale(240,320), RandomCrop(224)" +''' + + +def instanciate_transformation(cmd_line): + ''' Create a sequence of transformations. + + cmd_line: (str) + Comma-separated list of transformations. + Ex: "Rotate(10), Scale(256)" + ''' + if not isinstance(cmd_line, str): + return cmd_line # already instanciated + + cmd_line = "tvf.Compose([%s])" % cmd_line + try: + return eval(cmd_line) + except Exception as e: + print("Cannot interpret this transform list: %s\nReason: %s" % (cmd_line, e)) + + +class Scale (object): + """ Rescale the input PIL.Image to a given size. + Copied from https://github.com/pytorch in torchvision/transforms/transforms.py + + The smallest dimension of the resulting image will be = size. + + if largest == True: same behaviour for the largest dimension. 
+ + if not can_upscale: don't upscale + if not can_downscale: don't downscale + """ + def __init__(self, size, interpolation=Image.BILINEAR, largest=False, + can_upscale=True, can_downscale=True): + assert isinstance(size, int) or (len(size) == 2) + self.size = size + self.interpolation = interpolation + self.largest = largest + self.can_upscale = can_upscale + self.can_downscale = can_downscale + + def __repr__(self): + fmt_str = "RandomScale(%s" % str(self.size) + if self.largest: fmt_str += ', largest=True' + if not self.can_upscale: fmt_str += ', can_upscale=False' + if not self.can_downscale: fmt_str += ', can_downscale=False' + return fmt_str+')' + + def get_params(self, imsize): + w,h = imsize + if isinstance(self.size, int): + cmp = lambda a,b: (a>=b) if self.largest else (a<=b) + if (cmp(w, h) and w == self.size) or (cmp(h, w) and h == self.size): + ow, oh = w, h + elif cmp(w, h): + ow = self.size + oh = int(self.size * h / w) + else: + oh = self.size + ow = int(self.size * w / h) + else: + ow, oh = self.size + return ow, oh + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + + size2 = ow, oh = self.get_params(img.size) + + if size2 != img.size: + a1, a2 = img.size, size2 + if (self.can_upscale and min(a1) < min(a2)) or (self.can_downscale and min(a1) > min(a2)): + img = img.resize(size2, self.interpolation) + + return F.update_img_and_labels(inp, img, persp=(ow/w,0,0,0,oh/h,0,0,0)) + + + +class RandomScale (Scale): + """Rescale the input PIL.Image to a random size. + Copied from https://github.com/pytorch in torchvision/transforms/transforms.py + + Args: + min_size (int): min size of the smaller edge of the picture. + max_size (int): max size of the smaller edge of the picture. + + ar (float or tuple): + max change of aspect ratio (width/height). + + interpolation (int, optional): Desired interpolation. 
Default is + ``PIL.Image.BILINEAR`` + """ + + def __init__(self, min_size, max_size, ar=1, + can_upscale=False, can_downscale=True, interpolation=Image.BILINEAR): + Scale.__init__(self, 0, can_upscale=can_upscale, can_downscale=can_downscale, interpolation=interpolation) + assert type(min_size) == type(max_size), 'min_size and max_size can only be 2 ints or 2 floats' + assert isinstance(min_size, int) and min_size >= 1 or isinstance(min_size, float) and min_size>0 + assert isinstance(max_size, (int,float)) and min_size <= max_size + self.min_size = min_size + self.max_size = max_size + if type(ar) in (float,int): ar = (min(1/ar,ar),max(1/ar,ar)) + assert 0.2 < ar[0] <= ar[1] < 5 + self.ar = ar + + def get_params(self, imsize): + w,h = imsize + if isinstance(self.min_size, float): + min_size = int(self.min_size*min(w,h) + 0.5) + if isinstance(self.max_size, float): + max_size = int(self.max_size*min(w,h) + 0.5) + if isinstance(self.min_size, int): + min_size = self.min_size + if isinstance(self.max_size, int): + max_size = self.max_size + + if not self.can_upscale: + max_size = min(max_size,min(w,h)) + + size = int(0.5 + F.rand_log_uniform(min_size,max_size)) + ar = F.rand_log_uniform(*self.ar) # change of aspect ratio + + if w < h: # image is taller + ow = size + oh = int(0.5 + size * h / w / ar) + if oh < min_size: + ow,oh = int(0.5 + ow*float(min_size)/oh),min_size + else: # image is wider + oh = size + ow = int(0.5 + size * w / h * ar) + if ow < min_size: + ow,oh = min_size,int(0.5 + oh*float(min_size)/ow) + + assert ow >= min_size, 'image too small (width=%d < min_size=%d)' % (ow, min_size) + assert oh >= min_size, 'image too small (height=%d < min_size=%d)' % (oh, min_size) + return ow, oh + + + +class RandomCrop (object): + """Crop the given PIL Image at a random location. + Copied from https://github.com/pytorch in torchvision/transforms/transforms.py + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + padding (int or sequence, optional): Optional padding on each border + of the image. Default is 0, i.e no padding. If a sequence of length + 4 is provided, it is used to pad left, top, right, bottom borders + respectively. + """ + + def __init__(self, size, padding=0): + if isinstance(size, int): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + + def __repr__(self): + return "RandomCrop(%s)" % str(self.size) + + @staticmethod + def get_params(img, output_size): + w, h = img.size + th, tw = output_size + assert h >= th and w >= tw, "Image of %dx%d is too small for crop %dx%d" % (w,h,tw,th) + + y = np.random.randint(0, h - th) if h > th else 0 + x = np.random.randint(0, w - tw) if w > tw else 0 + return x, y, tw, th + + def __call__(self, inp): + img = F.grab_img(inp) + + padl = padt = 0 + if self.padding: + if F.is_pil_image(img): + img = ImageOps.expand(img, border=self.padding, fill=0) + else: + assert isinstance(img, F.DummyImg) + img = img.expand(border=self.padding) + if isinstance(self.padding, int): + padl = padt = self.padding + else: + padl, padt = self.padding[0:2] + + i, j, tw, th = self.get_params(img, self.size) + img = img.crop((i, j, i+tw, j+th)) + + return F.update_img_and_labels(inp, img, persp=(1,0,padl-i,0,1,padt-j,0,0)) + + +class CenterCrop (RandomCrop): + """Crops the given PIL Image at the center. 
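+    Identical to RandomCrop, except that the crop window is always centred.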
+ Copied from https://github.com/pytorch in torchvision/transforms/transforms.py + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + @staticmethod + def get_params(img, output_size): + w, h = img.size + th, tw = output_size + y = int(0.5 +((h - th) / 2.)) + x = int(0.5 +((w - tw) / 2.)) + return x, y, tw, th + + + +class RandomRotation(object): + """Rescale the input PIL.Image to a random size. + Copied from https://github.com/pytorch in torchvision/transforms/transforms.py + + Args: + degrees (float): + rotation angle. + + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + """ + + def __init__(self, degrees, interpolation=Image.BILINEAR): + self.degrees = degrees + self.interpolation = interpolation + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + + angle = np.random.uniform(-self.degrees, self.degrees) + + img = img.rotate(angle, resample=self.interpolation) + w2, h2 = img.size + + trf = F.translate(-w/2,-h/2) + trf = F.persp_mul(trf, F.rotate(-angle * np.pi/180)) + trf = F.persp_mul(trf, F.translate(w2/2,h2/2)) + return F.update_img_and_labels(inp, img, persp=trf) + + + +class RandomTilting(object): + """Apply a random tilting (left, right, up, down) to the input PIL.Image + Copied from https://github.com/pytorch in torchvision/transforms/transforms.py + + Args: + maginitude (float): + maximum magnitude of the random skew (value between 0 and 1) + directions (string): + tilting directions allowed (all, left, right, up, down) + examples: "all", "left,right", "up-down-right" + """ + + def __init__(self, magnitude, directions='all'): + self.magnitude = magnitude + self.directions = directions.lower().replace(',',' ').replace('-',' ') + + def __repr__(self): + return "RandomTilt(%g, '%s')" % (self.magnitude,self.directions) + + def __call__(self, inp): + img = F.grab_img(inp) + w, h = img.size + + x1,y1,x2,y2 = 0,0,h,w + original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)] + + max_skew_amount = max(w, h) + max_skew_amount = int(ceil(max_skew_amount * self.magnitude)) + skew_amount = random.randint(1, max_skew_amount) + + if self.directions == 'all': + choices = [0,1,2,3] + else: + dirs = ['left', 'right', 'up', 'down'] + choices = [] + for d in self.directions.split(): + try: + choices.append(dirs.index(d)) + except: + raise ValueError('Tilting direction %s not recognized' % d) + + skew_direction = random.choice(choices) + + # print('randomtitlting: ', skew_amount, skew_direction) # to debug random + + if skew_direction == 0: + # Left Tilt + new_plane = [(y1, x1 - skew_amount), # Top Left + (y2, x1), # Top Right + (y2, x2), # Bottom Right + (y1, x2 + skew_amount)] # Bottom Left + elif skew_direction == 1: + # Right Tilt + new_plane = [(y1, x1), # Top Left + (y2, x1 - skew_amount), # Top Right + (y2, x2 + skew_amount), # Bottom Right + (y1, x2)] # Bottom Left + elif skew_direction == 2: + # Forward Tilt + new_plane = [(y1 - skew_amount, x1), # Top Left + (y2 + skew_amount, x1), # Top Right + (y2, x2), # Bottom Right + (y1, x2)] # Bottom Left + elif skew_direction == 3: + # Backward Tilt + new_plane = [(y1, x1), # Top Left + (y2, x1), # Top Right + (y2 + skew_amount, x2), # Bottom Right + (y1 - skew_amount, x2)] # Bottom Left + + # To calculate the coefficients required by PIL for the perspective skew, + # see the following Stack Overflow discussion: https://goo.gl/sSgJdj + matrix = [] + + for p1, p2 in 
zip(new_plane, original_plane): + matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]) + matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) + + A = np.matrix(matrix, dtype=np.float) + B = np.array(original_plane).reshape(8) + + homography = np.dot(np.linalg.pinv(A), B) + homography = tuple(np.array(homography).reshape(8)) + #print(homography) + + img = img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC) + + homography = np.linalg.pinv(np.float32(homography+(1,)).reshape(3,3)).ravel()[:8] + return F.update_img_and_labels(inp, img, persp=tuple(homography)) + + +RandomTilt = RandomTilting # redefinition + + +class Tilt(object): + """Apply a known tilting to an image + """ + def __init__(self, *homography): + assert len(homography) == 8 + self.homography = homography + + def __call__(self, inp): + img = F.grab_img(inp) + homography = self.homography + #print(homography) + + img = img.transform(img.size, Image.PERSPECTIVE, homography, resample=Image.BICUBIC) + + homography = np.linalg.pinv(np.float32(homography+(1,)).reshape(3,3)).ravel()[:8] + return F.update_img_and_labels(inp, img, persp=tuple(homography)) + + + +class StillTransform (object): + """ Takes and return an image, without changing its shape or geometry. + """ + def _transform(self, img): + raise NotImplementedError() + + def __call__(self, inp): + img = F.grab_img(inp) + + # transform the image (size should not change) + try: + img = self._transform(img) + except TypeError: + pass + + return F.update_img_and_labels(inp, img, persp=(1,0,0,0,1,0,0,0)) + + + +class PixelNoise (StillTransform): + """ Takes an image, and add random white noise. + """ + def __init__(self, ampl=20): + StillTransform.__init__(self) + assert 0 <= ampl < 255 + self.ampl = ampl + + def __repr__(self): + return "PixelNoise(%g)" % self.ampl + + def _transform(self, img): + img = np.float32(img) + img += np.random.uniform(0.5-self.ampl/2, 0.5+self.ampl/2, size=img.shape) + return Image.fromarray(np.uint8(img.clip(0,255))) + + + +class ColorJitter (StillTransform): + """Randomly change the brightness, contrast and saturation of an image. + Copied from https://github.com/pytorch in torchvision/transforms/transforms.py + + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. + """ + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def __repr__(self): + return "ColorJitter(%g,%g,%g,%g)" % ( + self.brightness, self.contrast, self.saturation, self.hue) + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + Arguments are same as that of __init__. + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. 
+ """ + transforms = [] + if brightness > 0: + brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) + transforms.append(tvf.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) + + if contrast > 0: + contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) + transforms.append(tvf.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) + + if saturation > 0: + saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) + transforms.append(tvf.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) + + if hue > 0: + hue_factor = np.random.uniform(-hue, hue) + transforms.append(tvf.Lambda(lambda img: F.adjust_hue(img, hue_factor))) + + # print('colorjitter: ', brightness_factor, contrast_factor, saturation_factor, hue_factor) # to debug random seed + + np.random.shuffle(transforms) + transform = tvf.Compose(transforms) + + return transform + + def _transform(self, img): + transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + return transform(img) + + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser("Script to try out and visualize transformations") + parser.add_argument('--img', type=str, default='imgs/test.png', help='input image') + parser.add_argument('--trfs', type=str, required=True, help='list of transformations') + parser.add_argument('--layout', type=int, nargs=2, default=(3,3), help='nb of rows,cols') + args = parser.parse_args() + + import os + args.img = args.img.replace('$HERE',os.path.dirname(__file__)) + img = Image.open(args.img) + img = dict(img=img) + + trfs = instanciate_transformation(args.trfs) + + from matplotlib import pyplot as pl + pl.ion() + pl.subplots_adjust(0,0,1,1) + + nr,nc = args.layout + + while True: + for j in range(nr): + for i in range(nc): + pl.subplot(nr,nc,i+j*nc+1) + if i==j==0: + img2 = img + else: + img2 = trfs(img.copy()) + if isinstance(img2, dict): + img2 = img2['img'] + pl.imshow(img2) + pl.xlabel("%d x %d" % img2.size) + pl.xticks(()) + pl.yticks(()) + pdb.set_trace() + + + diff --git a/imcui/third_party/r2d2/tools/transforms_tools.py b/imcui/third_party/r2d2/tools/transforms_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..294c22228a88f70480af52f79a77d73f9e5b3e1a --- /dev/null +++ b/imcui/third_party/r2d2/tools/transforms_tools.py @@ -0,0 +1,230 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb +import numpy as np +from PIL import Image, ImageOps, ImageEnhance + + +class DummyImg: + ''' This class is a dummy image only defined by its size. + ''' + def __init__(self, size): + self.size = size + + def resize(self, size, *args, **kwargs): + return DummyImg(size) + + def expand(self, border): + w, h = self.size + if isinstance(border, int): + size = (w+2*border, h+2*border) + else: + l,t,r,b = border + size = (w+l+r, h+t+b) + return DummyImg(size) + + def crop(self, border): + w, h = self.size + l,t,r,b = border + assert 0 <= l <= r <= w + assert 0 <= t <= b <= h + size = (r-l, b-t) + return DummyImg(size) + + def rotate(self, angle): + raise NotImplementedError + + def transform(self, size, *args, **kwargs): + return DummyImg(size) + + +def grab_img( img_and_label ): + ''' Called to extract the image from an img_and_label input + (a dictionary). Also compatible with old-style PIL images. + ''' + if isinstance(img_and_label, dict): + # if input is a dictionary, then + # it must contains the img or its size. 
+ try: + return img_and_label['img'] + except KeyError: + return DummyImg(img_and_label['imsize']) + + else: + # or it must be the img directly + return img_and_label + + +def update_img_and_labels(img_and_label, img, persp=None): + ''' Called to update the img_and_label + ''' + if isinstance(img_and_label, dict): + img_and_label['img'] = img + img_and_label['imsize'] = img.size + + if persp: + if 'persp' not in img_and_label: + img_and_label['persp'] = (1,0,0,0,1,0,0,0) + img_and_label['persp'] = persp_mul(persp, img_and_label['persp']) + + return img_and_label + + else: + # or it must be the img directly + return img + + +def rand_log_uniform(a, b): + return np.exp(np.random.uniform(np.log(a),np.log(b))) + + +def translate(tx, ty): + return (1,0,tx, + 0,1,ty, + 0,0) + +def rotate(angle): + return (np.cos(angle),-np.sin(angle), 0, + np.sin(angle), np.cos(angle), 0, + 0, 0) + + +def persp_mul(mat, mat2): + ''' homography (perspective) multiplication. + mat: 8-tuple (homography transform) + mat2: 8-tuple (homography transform) or 2-tuple (point) + ''' + assert isinstance(mat, tuple) + assert isinstance(mat2, tuple) + + mat = np.float32(mat+(1,)).reshape(3,3) + mat2 = np.array(mat2+(1,)).reshape(3,3) + res = np.dot(mat, mat2) + return tuple((res/res[2,2]).ravel()[:8]) + + +def persp_apply(mat, pts): + ''' homography (perspective) transformation. + mat: 8-tuple (homography transform) + pts: numpy array + ''' + assert isinstance(mat, tuple) + assert isinstance(pts, np.ndarray) + assert pts.shape[-1] == 2 + mat = np.float32(mat+(1,)).reshape(3,3) + + if pts.ndim == 1: + pt = np.dot(pts, mat[:,:2].T).ravel() + mat[:,2] + pt /= pt[2] # homogeneous coordinates + return tuple(pt[:2]) + else: + pt = np.dot(pts, mat[:,:2].T) + mat[:,2] + pt[:,:2] /= pt[:,2:3] # homogeneous coordinates + return pt[:,:2] + + +def is_pil_image(img): + return isinstance(img, Image.Image) + + +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an Image. + Args: + img (PIL Image): PIL Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + Returns: + PIL Image: Brightness adjusted image. + Copied from https://github.com/pytorch in torchvision/transforms/functional.py + """ + if not is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an Image. + Args: + img (PIL Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + Returns: + PIL Image: Contrast adjusted image. + Copied from https://github.com/pytorch in torchvision/transforms/functional.py + """ + if not is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + Args: + img (PIL Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 
0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + Returns: + PIL Image: Saturation adjusted image. + Copied from https://github.com/pytorch in torchvision/transforms/functional.py + """ + if not is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + See https://en.wikipedia.org/wiki/Hue for more details on Hue. + Args: + img (PIL Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + Returns: + PIL Image: Hue adjusted image. + Copied from https://github.com/pytorch in torchvision/transforms/functional.py + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + + if not is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img + + + diff --git a/imcui/third_party/r2d2/tools/viz.py b/imcui/third_party/r2d2/tools/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..c86103f3aeb468fca8b0ac9a412f22b85239361b --- /dev/null +++ b/imcui/third_party/r2d2/tools/viz.py @@ -0,0 +1,191 @@ +# Copyright 2019-present NAVER Corp. +# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import pdb +import numpy as np +import matplotlib.pyplot as pl + + +def make_colorwheel(): + ''' + Generates a color wheel for optical flow visualization as presented in: + Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Copied from https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/flow_vis.py + Copyright (c) 2018 Tom Runia + ''' + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_compute_color(u, v, convert_to_bgr=False): + ''' + Applies the flow color wheel to (possibly clipped) flow components u and v. + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + :param u: np.ndarray, input horizontal flow + :param v: np.ndarray, input vertical flow + :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB + :return: + + Copied from https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/flow_vis.py + Copyright (c) 2018 Tom Runia + ''' + + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + + for i in range(colorwheel.shape[1]): + + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range? 
+ + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + + return flow_image + + +def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): + ''' + Expects a two dimensional flow image of shape [H,W,2] + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + :param flow_uv: np.ndarray of shape [H,W,2] + :param clip_flow: float, maximum clipping value for flow + :return: + + Copied from https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/flow_vis.py + Copyright (c) 2018 Tom Runia + ''' + + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + + return flow_compute_color(u, v, convert_to_bgr) + + + +def show_flow( img0, img1, flow, mask=None ): + img0 = np.asarray(img0) + img1 = np.asarray(img1) + if mask is None: mask = 1 + mask = np.asarray(mask) + if mask.ndim == 2: mask = mask[:,:,None] + assert flow.ndim == 3 + assert flow.shape[:2] == img0.shape[:2] and flow.shape[2] == 2 + + def noticks(): + pl.xticks([]) + pl.yticks([]) + fig = pl.figure("showing correspondences") + ax1 = pl.subplot(221) + ax1.numaxis = 0 + pl.imshow(img0*mask) + noticks() + ax2 = pl.subplot(222) + ax2.numaxis = 1 + pl.imshow(img1) + noticks() + + ax = pl.subplot(212) + ax.numaxis = 0 + flow_img = flow_to_color(np.where(np.isnan(flow), 0, flow)) + pl.imshow(flow_img * mask) + noticks() + + pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, wspace=0.02, hspace=0.02) + + def motion_notify_callback(event): + if event.inaxes is None: return + x,y = event.xdata, event.ydata + ax1.lines = [] + ax2.lines = [] + try: + x,y = int(x+0.5), int(y+0.5) + ax1.plot(x,y,'+',ms=10,mew=2,color='blue',scalex=False,scaley=False) + x,y = flow[y,x] + (x,y) + ax2.plot(x,y,'+',ms=10,mew=2,color='red',scalex=False,scaley=False) + # we redraw only the concerned axes + renderer = fig.canvas.get_renderer() + ax1.draw(renderer) + ax2.draw(renderer) + fig.canvas.blit(ax1.bbox) + fig.canvas.blit(ax2.bbox) + except IndexError: + return + + cid_move = fig.canvas.mpl_connect('motion_notify_event',motion_notify_callback) + print("Move your mouse over the images to show matches (ctrl-C to quit)") + pl.show() + + diff --git a/imcui/third_party/r2d2/train.py b/imcui/third_party/r2d2/train.py new file mode 100644 index 0000000000000000000000000000000000000000..10d23d9e40ebe8cb10c4d548b7fcb5c1c0fd7739 --- /dev/null +++ b/imcui/third_party/r2d2/train.py @@ -0,0 +1,138 @@ +# Copyright 2019-present NAVER Corp. 
+# CC BY-NC-SA 3.0 +# Available only for non-commercial use + +import os, pdb +import torch +import torch.optim as optim + +from tools import common, trainer +from tools.dataloader import * +from nets.patchnet import * +from nets.losses import * + +default_net = "Quad_L2Net_ConfCFS()" + +toy_db_debug = """SyntheticPairDataset( + ImgFolder('imgs'), + 'RandomScale(256,1024,can_upscale=True)', + 'RandomTilting(0.5), PixelNoise(25)')""" + +db_web_images = """SyntheticPairDataset( + web_images, + 'RandomScale(256,1024,can_upscale=True)', + 'RandomTilting(0.5), PixelNoise(25)')""" + +db_aachen_images = """SyntheticPairDataset( + aachen_db_images, + 'RandomScale(256,1024,can_upscale=True)', + 'RandomTilting(0.5), PixelNoise(25)')""" + +db_aachen_style_transfer = """TransformedPairs( + aachen_style_transfer_pairs, + 'RandomScale(256,1024,can_upscale=True), RandomTilting(0.5), PixelNoise(25)')""" + +db_aachen_flow = "aachen_flow_pairs" + +data_sources = dict( + D = toy_db_debug, + W = db_web_images, + A = db_aachen_images, + F = db_aachen_flow, + S = db_aachen_style_transfer, + ) + +default_dataloader = """PairLoader(CatPairDataset(`data`), + scale = 'RandomScale(256,1024,can_upscale=True)', + distort = 'ColorJitter(0.2,0.2,0.2,0.1)', + crop = 'RandomCrop(192)')""" + +default_sampler = """NghSampler2(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16, + subd_neg=-8,maxpool_pos=True)""" + +default_loss = """MultiLoss( + 1, ReliabilityLoss(`sampler`, base=0.5, nq=20), + 1, CosimLoss(N=`N`), + 1, PeakyLoss(N=`N`))""" + + +class MyTrainer(trainer.Trainer): + """ This class implements the network training. + Below is the function I need to overload to explain how to do the backprop. + """ + def forward_backward(self, inputs): + output = self.net(imgs=[inputs.pop('img1'),inputs.pop('img2')]) + allvars = dict(inputs, **output) + loss, details = self.loss_func(**allvars) + if torch.is_grad_enabled(): loss.backward() + return loss, details + + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser("Train R2D2") + + parser.add_argument("--data-loader", type=str, default=default_dataloader) + parser.add_argument("--train-data", type=str, default=list('WASF'), nargs='+', + choices = set(data_sources.keys())) + parser.add_argument("--net", type=str, default=default_net, help='network architecture') + + parser.add_argument("--pretrained", type=str, default="", help='pretrained model path') + parser.add_argument("--save-path", type=str, required=True, help='model save_path path') + + parser.add_argument("--loss", type=str, default=default_loss, help="loss function") + parser.add_argument("--sampler", type=str, default=default_sampler, help="AP sampler") + parser.add_argument("--N", type=int, default=16, help="patch size for repeatability") + + parser.add_argument("--epochs", type=int, default=25, help='number of training epochs') + parser.add_argument("--batch-size", "--bs", type=int, default=8, help="batch size") + parser.add_argument("--learning-rate", "--lr", type=str, default=1e-4) + parser.add_argument("--weight-decay", "--wd", type=float, default=5e-4) + + parser.add_argument("--threads", type=int, default=8, help='number of worker threads') + parser.add_argument("--gpu", type=int, nargs='+', default=[0], help='-1 for CPU') + + args = parser.parse_args() + + iscuda = common.torch_set_gpu(args.gpu) + common.mkdir_for(args.save_path) + + # Create data loader + from datasets import * + db = [data_sources[key] for key in args.train_data] + db = 
eval(args.data_loader.replace('`data`',','.join(db)).replace('\n','')) + print("Training image database =", db) + loader = threaded_loader(db, iscuda, args.threads, args.batch_size, shuffle=True) + + # create network + print("\n>> Creating net = " + args.net) + net = eval(args.net) + print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )") + + # initialization + if args.pretrained: + checkpoint = torch.load(args.pretrained, lambda a,b:a) + net.load_pretrained(checkpoint['state_dict']) + + # create losses + loss = args.loss.replace('`sampler`',args.sampler).replace('`N`',str(args.N)) + print("\n>> Creating loss = " + loss) + loss = eval(loss.replace('\n','')) + + # create optimizer + optimizer = optim.Adam( [p for p in net.parameters() if p.requires_grad], + lr=args.learning_rate, weight_decay=args.weight_decay) + + train = MyTrainer(net, loader, loss, optimizer) + if iscuda: train = train.cuda() + + # Training loop # + for epoch in range(args.epochs): + print(f"\n>> Starting epoch {epoch}...") + train() + + print(f"\n>> Saving model to {args.save_path}") + torch.save({'net': args.net, 'state_dict': net.state_dict()}, args.save_path) + + diff --git a/imcui/third_party/r2d2/viz_heatmaps.py b/imcui/third_party/r2d2/viz_heatmaps.py new file mode 100644 index 0000000000000000000000000000000000000000..42705e70ecea82696a0d784b274f7f387fdf6595 --- /dev/null +++ b/imcui/third_party/r2d2/viz_heatmaps.py @@ -0,0 +1,122 @@ +import pdb +import os +import sys +import tqdm + +import numpy as np +import torch + +from PIL import Image +from matplotlib import pyplot as pl; pl.ion() +from scipy.ndimage import uniform_filter +smooth = lambda arr: uniform_filter(arr, 3) + +def transparent(img, alpha, cmap, **kw): + from matplotlib.colors import Normalize + colored_img = cmap(Normalize(clip=True,**kw)(img)) + colored_img[:,:,-1] = alpha + return colored_img + +from tools import common +from tools.dataloader import norm_RGB +from nets.patchnet import * +from extract import NonMaxSuppression + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser("Visualize the patch detector and descriptor") + + parser.add_argument("--img", type=str, default="imgs/brooklyn.png") + parser.add_argument("--resize", type=int, default=512) + parser.add_argument("--out", type=str, default="viz.png") + + parser.add_argument("--checkpoint", type=str, required=True, help='network path') + parser.add_argument("--net", type=str, default="", help='network command') + + parser.add_argument("--max-kpts", type=int, default=200) + parser.add_argument("--reliability-thr", type=float, default=0.8) + parser.add_argument("--repeatability-thr", type=float, default=0.7) + parser.add_argument("--border", type=int, default=20,help='rm keypoints close to border') + + parser.add_argument("--gpu", type=int, nargs='+', required=True, help='-1 for CPU') + parser.add_argument("--dbg", type=str, nargs='+', default=(), help='debug options') + + args = parser.parse_args() + args.dbg = set(args.dbg) + + iscuda = common.torch_set_gpu(args.gpu) + device = torch.device('cuda' if iscuda else 'cpu') + + # create network + checkpoint = torch.load(args.checkpoint, lambda a,b:a) + args.net = args.net or checkpoint['net'] + print("\n>> Creating net = " + args.net) + net = eval(args.net) + net.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()}) + if iscuda: net = net.cuda() + print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )") + + img = 
Image.open(args.img).convert('RGB') + if args.resize: img.thumbnail((args.resize,args.resize)) + img = np.asarray(img) + + detector = NonMaxSuppression( + rel_thr = args.reliability_thr, + rep_thr = args.repeatability_thr) + + with torch.no_grad(): + print(">> computing features...") + res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)]) + rela = res.get('reliability') + repe = res.get('repeatability') + kpts = detector(**res).T[:,[1,0]] + kpts = kpts[repe[0][0,0][kpts[:,1],kpts[:,0]].argsort()[-args.max_kpts:]] + + fig = pl.figure("viz") + kw = dict(cmap=pl.cm.RdYlGn, vmax=1) + crop = (slice(args.border,-args.border or 1),)*2 + + if 'reliability' in args.dbg: + + ax1 = pl.subplot(131) + pl.imshow(img[crop], cmap=pl.cm.gray) + pl.xticks(()); pl.yticks(()) + + pl.subplot(132) + pl.imshow(img[crop], cmap=pl.cm.gray, alpha=0) + pl.xticks(()); pl.yticks(()) + + x,y = kpts[:,0:2].cpu().numpy().T - args.border + pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0) + + ax1 = pl.subplot(133) + rela = rela[0][0,0].cpu().numpy() + pl.imshow(rela[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9) + pl.xticks(()); pl.yticks(()) + + else: + ax1 = pl.subplot(131) + pl.imshow(img[crop], cmap=pl.cm.gray) + pl.xticks(()); pl.yticks(()) + + x,y = kpts[:,0:2].cpu().numpy().T - args.border + pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0) + + pl.subplot(132) + pl.imshow(img[crop], cmap=pl.cm.gray) + pl.xticks(()); pl.yticks(()) + c = repe[0][0,0].cpu().numpy() + pl.imshow(transparent(smooth(c)[crop], 0.5, vmin=0, **kw)) + + ax1 = pl.subplot(133) + pl.imshow(img[crop], cmap=pl.cm.gray) + pl.xticks(()); pl.yticks(()) + rela = rela[0][0,0].cpu().numpy() + pl.imshow(transparent(rela[crop], 0.5, vmin=0.9, **kw)) + + pl.gcf().set_size_inches(9, 2.73) + pl.subplots_adjust(0.01,0.01,0.99,0.99,hspace=0.1) + pl.savefig(args.out) + pdb.set_trace() + diff --git a/imcui/ui/__init__.py b/imcui/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6ccf52978e85f5abaca55d6559c74a6b2bd169 --- /dev/null +++ b/imcui/ui/__init__.py @@ -0,0 +1,5 @@ +__version__ = "1.0.1" + + +def get_version(): + return __version__ diff --git a/imcui/ui/app_class.py b/imcui/ui/app_class.py new file mode 100644 index 0000000000000000000000000000000000000000..3e8efca74ad0a138222ad93d9da5b5ba8bd88605 --- /dev/null +++ b/imcui/ui/app_class.py @@ -0,0 +1,810 @@ +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import gradio as gr +import numpy as np +from easydict import EasyDict as edict +from omegaconf import OmegaConf + +from .sfm import SfmEngine +from .utils import ( + GRADIO_VERSION, + gen_examples, + generate_warp_images, + get_matcher_zoo, + load_config, + ransac_zoo, + run_matching, + run_ransac, + send_to_match, +) + +DESCRIPTION = """ +# Image Matching WebUI +This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue! +
+🔎 For more details about supported local features and matchers, please refer to https://github.com/Vincentqyw/image-matching-webui + +🚀 All algorithms run on CPU for inference, causing slow speeds and high latency. For faster inference, please download the [source code](https://github.com/Vincentqyw/image-matching-webui) for local deployment. + +🐛 Your feedback is valuable to me. Please do not hesitate to report any bugs [here](https://github.com/Vincentqyw/image-matching-webui/issues). +""" + +CSS = """ +#warning {background-color: #FFCCCB} +.logs_class textarea {font-size: 12px !important} +""" + + +class ImageMatchingApp: + def __init__(self, server_name="0.0.0.0", server_port=7860, **kwargs): + self.server_name = server_name + self.server_port = server_port + self.config_path = kwargs.get("config", Path(__file__).parent / "config.yaml") + self.cfg = load_config(self.config_path) + self.matcher_zoo = get_matcher_zoo(self.cfg["matcher_zoo"]) + self.app = None + self.example_data_root = kwargs.get( + "example_data_root", Path(__file__).parents[1] / "datasets" + ) + self.init_interface() + + def init_matcher_dropdown(self): + algos = [] + for k, v in self.cfg["matcher_zoo"].items(): + if v.get("enable", True): + algos.append(k) + return algos + + def init_interface(self): + with gr.Blocks(css=CSS) as self.app: + with gr.Tab("Image Matching"): + with gr.Row(): + with gr.Column(scale=1): + gr.Image( + str(Path(__file__).parent.parent / "assets/logo.webp"), + elem_id="logo-img", + show_label=False, + show_share_button=False, + show_download_button=False, + ) + with gr.Column(scale=3): + gr.Markdown(DESCRIPTION) + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(): + matcher_list = gr.Dropdown( + choices=self.init_matcher_dropdown(), + value="disk+lightglue", + label="Matching Model", + interactive=True, + ) + match_image_src = gr.Radio( + ( + ["upload", "webcam", "clipboard"] + if GRADIO_VERSION > "3" + else ["upload", "webcam", "canvas"] + ), + label="Image Source", + value="upload", + ) + with gr.Row(): + input_image0 = gr.Image( + label="Image 0", + type="numpy", + image_mode="RGB", + height=300 if GRADIO_VERSION > "3" else None, + interactive=True, + ) + input_image1 = gr.Image( + label="Image 1", + type="numpy", + image_mode="RGB", + height=300 if GRADIO_VERSION > "3" else None, + interactive=True, + ) + + with gr.Row(): + button_reset = gr.Button(value="Reset") + button_run = gr.Button(value="Run Match", variant="primary") + + with gr.Accordion("Advanced Setting", open=False): + with gr.Accordion("Image Setting", open=True): + with gr.Row(): + image_force_resize_cb = gr.Checkbox( + label="Force Resize", + value=False, + interactive=True, + ) + image_setting_height = gr.Slider( + minimum=48, + maximum=2048, + step=16, + label="Image Height", + value=480, + visible=False, + ) + image_setting_width = gr.Slider( + minimum=64, + maximum=2048, + step=16, + label="Image Width", + value=640, + visible=False, + ) + with gr.Accordion("Matching Setting", open=True): + with gr.Row(): + match_setting_threshold = gr.Slider( + minimum=0.0, + maximum=1, + step=0.001, + label="Match threshold", + value=0.1, + ) + match_setting_max_keypoints = gr.Slider( + minimum=10, + maximum=10000, + step=10, + label="Max features", + value=1000, + ) + # TODO: add line settings + with gr.Row(): + detect_keypoints_threshold = gr.Slider( + minimum=0, + maximum=1, + step=0.001, + label="Keypoint threshold", + value=0.015, + ) + detect_line_threshold = ( # noqa: F841 + gr.Slider( + minimum=0.1, + 
maximum=1, + step=0.01, + label="Line threshold", + value=0.2, + ) + ) + # matcher_lists = gr.Radio( + # ["NN-mutual", "Dual-Softmax"], + # label="Matcher mode", + # value="NN-mutual", + # ) + with gr.Accordion("RANSAC Setting", open=True): + with gr.Row(equal_height=False): + ransac_method = gr.Dropdown( + choices=ransac_zoo.keys(), + value=self.cfg["defaults"]["ransac_method"], + label="RANSAC Method", + interactive=True, + ) + ransac_reproj_threshold = gr.Slider( + minimum=0.0, + maximum=12, + step=0.01, + label="Ransac Reproj threshold", + value=8.0, + ) + ransac_confidence = gr.Slider( + minimum=0.0, + maximum=1, + step=0.00001, + label="Ransac Confidence", + value=self.cfg["defaults"]["ransac_confidence"], + ) + ransac_max_iter = gr.Slider( + minimum=0.0, + maximum=100000, + step=100, + label="Ransac Iterations", + value=self.cfg["defaults"]["ransac_max_iter"], + ) + button_ransac = gr.Button( + value="Rerun RANSAC", variant="primary" + ) + with gr.Accordion("Geometry Setting", open=False): + with gr.Row(equal_height=False): + choice_geometry_type = gr.Radio( + ["Fundamental", "Homography"], + label="Reconstruct Geometry", + value=self.cfg["defaults"]["setting_geometry"], + ) + # image resize + image_force_resize_cb.select( + fn=self._on_select_force_resize, + inputs=image_force_resize_cb, + outputs=[image_setting_width, image_setting_height], + ) + # collect inputs + state_cache = gr.State({}) + inputs = [ + input_image0, + input_image1, + match_setting_threshold, + match_setting_max_keypoints, + detect_keypoints_threshold, + matcher_list, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + choice_geometry_type, + gr.State(self.matcher_zoo), + image_force_resize_cb, + image_setting_width, + image_setting_height, + ] + + # Add some examples + with gr.Row(): + # Example inputs + with gr.Accordion("Open for More: Examples", open=True): + gr.Examples( + examples=gen_examples(self.example_data_root), + inputs=inputs, + outputs=[], + fn=run_matching, + cache_examples=False, + label=( + "Examples (click one of the images below to Run" + " Match). 
Thx: WxBS" + ), + ) + with gr.Accordion("Supported Algorithms", open=False): + # add a table of supported algorithms + self.display_supported_algorithms() + + with gr.Column(): + with gr.Accordion("Open for More: Keypoints", open=True): + output_keypoints = gr.Image(label="Keypoints", type="numpy") + with gr.Accordion( + ( + "Open for More: Raw Matches" + " (Green for good matches, Red for bad)" + ), + open=False, + ): + output_matches_raw = gr.Image( + label="Raw Matches", + type="numpy", + ) + with gr.Accordion( + ( + "Open for More: Ransac Matches" + " (Green for good matches, Red for bad)" + ), + open=True, + ): + output_matches_ransac = gr.Image( + label="Ransac Matches", type="numpy" + ) + with gr.Accordion( + "Open for More: Matches Statistics", open=False + ): + output_pred = gr.File(label="Outputs", elem_id="download") + matches_result_info = gr.JSON(label="Matches Statistics") + matcher_info = gr.JSON(label="Match info") + + with gr.Accordion("Open for More: Warped Image", open=True): + output_wrapped = gr.Image( + label="Wrapped Pair", type="numpy" + ) + # send to input + button_rerun = gr.Button( + value="Send to Input Match Pair", + variant="primary", + ) + with gr.Accordion( + "Open for More: Geometry info", open=False + ): + geometry_result = gr.JSON( + label="Reconstructed Geometry" + ) + + # callbacks + match_image_src.change( + fn=self.ui_change_imagebox, + inputs=match_image_src, + outputs=input_image0, + ) + match_image_src.change( + fn=self.ui_change_imagebox, + inputs=match_image_src, + outputs=input_image1, + ) + # collect outputs + outputs = [ + output_keypoints, + output_matches_raw, + output_matches_ransac, + matches_result_info, + matcher_info, + geometry_result, + output_wrapped, + state_cache, + output_pred, + ] + # button callbacks + button_run.click(fn=run_matching, inputs=inputs, outputs=outputs) + # Reset images + reset_outputs = [ + input_image0, + input_image1, + match_setting_threshold, + match_setting_max_keypoints, + detect_keypoints_threshold, + matcher_list, + input_image0, + input_image1, + match_image_src, + output_keypoints, + output_matches_raw, + output_matches_ransac, + matches_result_info, + matcher_info, + output_wrapped, + geometry_result, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + choice_geometry_type, + output_pred, + image_force_resize_cb, + ] + button_reset.click( + fn=self.ui_reset_state, + inputs=None, + outputs=reset_outputs, + ) + + # run ransac button action + button_ransac.click( + fn=run_ransac, + inputs=[ + state_cache, + choice_geometry_type, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + ], + outputs=[ + output_matches_ransac, + matches_result_info, + output_wrapped, + output_pred, + ], + ) + + # send warped image to match + button_rerun.click( + fn=send_to_match, + inputs=[state_cache], + outputs=[input_image0, input_image1], + ) + + # estimate geo + choice_geometry_type.change( + fn=generate_warp_images, + inputs=[ + input_image0, + input_image1, + geometry_result, + choice_geometry_type, + ], + outputs=[output_wrapped, geometry_result], + ) + with gr.Tab("Structure from Motion(under-dev)"): + sfm_ui = AppSfmUI( # noqa: F841 + { + **self.cfg, + "matcher_zoo": self.matcher_zoo, + "outputs": "experiments/sfm", + } + ) + sfm_ui.call_empty() + + def run(self): + self.app.queue().launch( + server_name=self.server_name, + server_port=self.server_port, + share=False, + allowed_paths=[ + str(Path(__file__).parents[0]), + str(Path(__file__).parents[1]), 
+ ], + ) + + def ui_change_imagebox(self, choice): + """ + Updates the image box with the given choice. + + Args: + choice (list): The list of image sources to be displayed in the image box. + + Returns: + dict: A dictionary containing the updated value, sources, and type for the image box. + """ + ret_dict = { + "value": None, # The updated value of the image box + "__type__": "update", # The type of update for the image box + } + if GRADIO_VERSION > "3": + return { + **ret_dict, + "sources": choice, # The list of image sources to be displayed + } + else: + return { + **ret_dict, + "source": choice, # The list of image sources to be displayed + } + + def _on_select_force_resize(self, visible: bool = False): + return gr.update(visible=visible), gr.update(visible=visible) + + def ui_reset_state( + self, + *args: Any, + ) -> Tuple[ + Optional[np.ndarray], + Optional[np.ndarray], + float, + int, + float, + str, + Dict[str, Any], + Dict[str, Any], + str, + Optional[np.ndarray], + Optional[np.ndarray], + Optional[np.ndarray], + Dict[str, Any], + Dict[str, Any], + Optional[np.ndarray], + Dict[str, Any], + str, + int, + float, + int, + bool, + ]: + """ + Reset the state of the UI. + + Returns: + tuple: A tuple containing the initial values for the UI state. + """ + key: str = list(self.matcher_zoo.keys())[ + 0 + ] # Get the first key from matcher_zoo + # flush_logs() + return ( + None, # image0: Optional[np.ndarray] + None, # image1: Optional[np.ndarray] + self.cfg["defaults"]["match_threshold"], # matching_threshold: float + self.cfg["defaults"]["max_keypoints"], # max_keypoints: int + self.cfg["defaults"]["keypoint_threshold"], # keypoint_threshold: float + key, # matcher: str + self.ui_change_imagebox("upload"), # input image0: Dict[str, Any] + self.ui_change_imagebox("upload"), # input image1: Dict[str, Any] + "upload", # match_image_src: str + None, # keypoints: Optional[np.ndarray] + None, # raw matches: Optional[np.ndarray] + None, # ransac matches: Optional[np.ndarray] + {}, # matches result info: Dict[str, Any] + {}, # matcher config: Dict[str, Any] + None, # warped image: Optional[np.ndarray] + {}, # geometry result: Dict[str, Any] + self.cfg["defaults"]["ransac_method"], # ransac_method: str + self.cfg["defaults"][ + "ransac_reproj_threshold" + ], # ransac_reproj_threshold: float + self.cfg["defaults"]["ransac_confidence"], # ransac_confidence: float + self.cfg["defaults"]["ransac_max_iter"], # ransac_max_iter: int + self.cfg["defaults"]["setting_geometry"], # geometry: str + None, # predictions + False, + ) + + def display_supported_algorithms(self, style="tab"): + def get_link(link, tag="Link"): + return "[{}]({})".format(tag, link) if link is not None else "None" + + data = [] + cfg = self.cfg["matcher_zoo"] + if style == "md": + markdown_table = "| Algo. 
| Conference | Code | Project | Paper |\n" + markdown_table += "| ----- | ---------- | ---- | ------- | ----- |\n" + + for k, v in cfg.items(): + if not v["info"]["display"]: + continue + github_link = get_link(v["info"]["github"]) + project_link = get_link(v["info"]["project"]) + paper_link = get_link( + v["info"]["paper"], + ( + Path(v["info"]["paper"]).name[-10:] + if v["info"]["paper"] is not None + else "Link" + ), + ) + + markdown_table += "{}|{}|{}|{}|{}\n".format( + v["info"]["name"], # display name + v["info"]["source"], + github_link, + project_link, + paper_link, + ) + return gr.Markdown(markdown_table) + elif style == "tab": + for k, v in cfg.items(): + if not v["info"].get("display", True): + continue + data.append( + [ + v["info"]["name"], + v["info"]["source"], + v["info"]["github"], + v["info"]["paper"], + v["info"]["project"], + ] + ) + tab = gr.Dataframe( + headers=["Algo.", "Conference", "Code", "Paper", "Project"], + datatype=["str", "str", "str", "str", "str"], + col_count=(5, "fixed"), + value=data, + # wrap=True, + # min_width = 1000, + # height=1000, + ) + return tab + + +class AppBaseUI: + def __init__(self, cfg: Dict[str, Any] = {}): + self.cfg = OmegaConf.create(cfg) + self.inputs = edict({}) + self.outputs = edict({}) + self.ui = edict({}) + + def _init_ui(self): + NotImplemented + + def call(self, **kwargs): + NotImplemented + + def info(self): + gr.Info("SFM is under construction.") + + +class AppSfmUI(AppBaseUI): + def __init__(self, cfg: Dict[str, Any] = None): + super().__init__(cfg) + assert "matcher_zoo" in self.cfg + self.matcher_zoo = self.cfg["matcher_zoo"] + self.sfm_engine = SfmEngine(cfg) + self._init_ui() + + def init_retrieval_dropdown(self): + algos = [] + for k, v in self.cfg["retrieval_zoo"].items(): + if v.get("enable", True): + algos.append(k) + return algos + + def _update_options(self, option): + if option == "sparse": + return gr.Textbox("sparse", visible=True) + elif option == "dense": + return gr.Textbox("dense", visible=True) + else: + return gr.Textbox("not set", visible=True) + + def _on_select_custom_params(self, value: bool = False): + return gr.update(visible=value) + + def _init_ui(self): + with gr.Row(): + # data settting and camera settings + with gr.Column(): + self.inputs.input_images = gr.File( + label="SfM", + interactive=True, + file_count="multiple", + min_width=300, + ) + # camera setting + with gr.Accordion("Camera Settings", open=True): + with gr.Column(): + with gr.Row(): + with gr.Column(): + self.inputs.camera_model = gr.Dropdown( + choices=[ + "PINHOLE", + "SIMPLE_RADIAL", + "OPENCV", + ], + value="PINHOLE", + label="Camera Model", + interactive=True, + ) + with gr.Column(): + gr.Checkbox( + label="Shared Params", + value=True, + interactive=True, + ) + camera_custom_params_cb = gr.Checkbox( + label="Custom Params", + value=False, + interactive=True, + ) + with gr.Row(): + self.inputs.camera_params = gr.Textbox( + label="Camera Params", + value="0,0,0,0", + interactive=False, + visible=False, + ) + camera_custom_params_cb.select( + fn=self._on_select_custom_params, + inputs=camera_custom_params_cb, + outputs=self.inputs.camera_params, + ) + + with gr.Accordion("Matching Settings", open=True): + # feature extraction and matching setting + with gr.Row(): + # matcher setting + self.inputs.matcher_key = gr.Dropdown( + choices=self.matcher_zoo.keys(), + value="disk+lightglue", + label="Matching Model", + interactive=True, + ) + with gr.Row(): + with gr.Accordion("Advanced Settings", open=False): + with gr.Column(): + 
with gr.Row(): + # matching setting + self.inputs.max_keypoints = gr.Slider( + label="Max Keypoints", + minimum=100, + maximum=10000, + value=1000, + interactive=True, + ) + self.inputs.keypoint_threshold = gr.Slider( + label="Keypoint Threshold", + minimum=0, + maximum=1, + value=0.01, + ) + with gr.Row(): + self.inputs.match_threshold = gr.Slider( + label="Match Threshold", + minimum=0.01, + maximum=12.0, + value=0.2, + ) + self.inputs.ransac_threshold = gr.Slider( + label="Ransac Threshold", + minimum=0.01, + maximum=12.0, + value=4.0, + step=0.01, + interactive=True, + ) + + with gr.Row(): + self.inputs.ransac_confidence = gr.Slider( + label="Ransac Confidence", + minimum=0.01, + maximum=1.0, + value=0.9999, + step=0.0001, + interactive=True, + ) + self.inputs.ransac_max_iter = gr.Slider( + label="Ransac Max Iter", + minimum=1, + maximum=100, + value=100, + step=1, + interactive=True, + ) + with gr.Accordion("Scene Graph Settings", open=True): + # mapping setting + self.inputs.scene_graph = gr.Dropdown( + choices=["all", "swin", "oneref"], + value="all", + label="Scene Graph", + interactive=True, + ) + + # global feature setting + self.inputs.global_feature = gr.Dropdown( + choices=self.init_retrieval_dropdown(), + value="netvlad", + label="Global features", + interactive=True, + ) + self.inputs.top_k = gr.Slider( + label="Number of Images per Image to Match", + minimum=1, + maximum=100, + value=10, + step=1, + ) + # button_match = gr.Button("Run Matching", variant="primary") + + # mapping setting + with gr.Column(): + with gr.Accordion("Mapping Settings", open=True): + with gr.Row(): + with gr.Accordion("Buddle Settings", open=True): + with gr.Row(): + self.inputs.mapper_refine_focal_length = gr.Checkbox( + label="Refine Focal Length", + value=False, + interactive=True, + ) + self.inputs.mapper_refine_principle_points = ( + gr.Checkbox( + label="Refine Principle Points", + value=False, + interactive=True, + ) + ) + self.inputs.mapper_refine_extra_params = gr.Checkbox( + label="Refine Extra Params", + value=False, + interactive=True, + ) + with gr.Accordion("Retriangluation Settings", open=True): + gr.Textbox( + label="Retriangluation Details", + ) + self.ui.button_sfm = gr.Button("Run SFM", variant="primary") + self.outputs.model_3d = gr.Model3D( + interactive=True, + ) + self.outputs.output_image = gr.Image( + label="SFM Visualize", + type="numpy", + image_mode="RGB", + interactive=False, + ) + + def call_empty(self): + self.ui.button_sfm.click(fn=self.info, inputs=[], outputs=[]) + + def call(self): + self.ui.button_sfm.click( + fn=self.sfm_engine.call, + inputs=[ + self.inputs.matcher_key, + self.inputs.input_images, # images + self.inputs.camera_model, + self.inputs.camera_params, + self.inputs.max_keypoints, + self.inputs.keypoint_threshold, + self.inputs.match_threshold, + self.inputs.ransac_threshold, + self.inputs.ransac_confidence, + self.inputs.ransac_max_iter, + self.inputs.scene_graph, + self.inputs.global_feature, + self.inputs.top_k, + self.inputs.mapper_refine_focal_length, + self.inputs.mapper_refine_principle_points, + self.inputs.mapper_refine_extra_params, + ], + outputs=[self.outputs.model_3d, self.outputs.output_image], + ) diff --git a/imcui/ui/sfm.py b/imcui/ui/sfm.py new file mode 100644 index 0000000000000000000000000000000000000000..b52924fb3b8991bd7f589be1a7a6bb06c3f45469 --- /dev/null +++ b/imcui/ui/sfm.py @@ -0,0 +1,164 @@ +import shutil +import tempfile +from pathlib import Path +from typing import Any, Dict, List + + +from ..hloc import ( + 
extract_features, + logger, + match_features, + pairs_from_retrieval, + reconstruction, + visualization, +) + +try: + import pycolmap +except ImportError: + logger.warning("pycolmap not installed, some features may not work") + +from .viz import fig2im + + +class SfmEngine: + def __init__(self, cfg: Dict[str, Any] = None): + self.cfg = cfg + if "outputs" in cfg and Path(cfg["outputs"]): + outputs = Path(cfg["outputs"]) + outputs.mkdir(parents=True, exist_ok=True) + else: + outputs = tempfile.mkdtemp() + self.outputs = Path(outputs) + + def call( + self, + key: str, + images: Path, + camera_model: str, + camera_params: List[float], + max_keypoints: int, + keypoint_threshold: float, + match_threshold: float, + ransac_threshold: int, + ransac_confidence: float, + ransac_max_iter: int, + scene_graph: bool, + global_feature: str, + top_k: int = 10, + mapper_refine_focal_length: bool = False, + mapper_refine_principle_points: bool = False, + mapper_refine_extra_params: bool = False, + ): + """ + Call a list of functions to perform feature extraction, matching, and reconstruction. + + Args: + key (str): The key to retrieve the matcher and feature models. + images (Path): The directory containing the images. + outputs (Path): The directory to store the outputs. + camera_model (str): The camera model. + camera_params (List[float]): The camera parameters. + max_keypoints (int): The maximum number of features. + match_threshold (float): The match threshold. + ransac_threshold (int): The RANSAC threshold. + ransac_confidence (float): The RANSAC confidence. + ransac_max_iter (int): The maximum number of RANSAC iterations. + scene_graph (bool): Whether to compute the scene graph. + global_feature (str): Whether to compute the global feature. + top_k (int): The number of image-pair to use. + mapper_refine_focal_length (bool): Whether to refine the focal length. + mapper_refine_principle_points (bool): Whether to refine the principle points. + mapper_refine_extra_params (bool): Whether to refine the extra parameters. + + Returns: + Path: The directory containing the SfM results. 
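+
+        Note: as implemented below, the method actually returns a tuple of
+        ``(path to points3D.obj, 2D visualization image)`` rather than a bare
+        output directory.
+
+        Illustrative sketch (untested; assumes the ``imcui`` package is
+        importable and that ``disk+lightglue`` exists in the bundled
+        ``config.yaml``; image paths are placeholders):
+
+            >>> from imcui.ui.utils import load_config, get_matcher_zoo
+            >>> cfg = load_config("config.yaml")
+            >>> cfg["matcher_zoo"] = get_matcher_zoo(cfg["matcher_zoo"])
+            >>> engine = SfmEngine({**cfg, "outputs": "experiments/sfm"})
+            >>> obj_path, viz = engine.call(
+            ...     "disk+lightglue", ["im1.jpg", "im2.jpg", "im3.jpg"],
+            ...     "PINHOLE", [0, 0, 0, 0],
+            ...     max_keypoints=1000, keypoint_threshold=0.01,
+            ...     match_threshold=0.2, ransac_threshold=4.0,
+            ...     ransac_confidence=0.9999, ransac_max_iter=100,
+            ...     scene_graph="all", global_feature="netvlad", top_k=10)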
+ """ + if len(images) == 0: + logger.error(f"{images} does not exist.") + + temp_images = Path(tempfile.mkdtemp()) + # copy images + logger.info(f"Copying images to {temp_images}.") + for image in images: + shutil.copy(image, temp_images) + + matcher_zoo = self.cfg["matcher_zoo"] + model = matcher_zoo[key] + match_conf = model["matcher"] + match_conf["model"]["max_keypoints"] = max_keypoints + match_conf["model"]["match_threshold"] = match_threshold + + feature_conf = model["feature"] + feature_conf["model"]["max_keypoints"] = max_keypoints + feature_conf["model"]["keypoint_threshold"] = keypoint_threshold + + # retrieval + retrieval_name = self.cfg.get("retrieval_name", "netvlad") + retrieval_conf = extract_features.confs[retrieval_name] + + mapper_options = { + "ba_refine_extra_params": mapper_refine_extra_params, + "ba_refine_focal_length": mapper_refine_focal_length, + "ba_refine_principal_point": mapper_refine_principle_points, + "ba_local_max_num_iterations": 40, + "ba_local_max_refinements": 3, + "ba_global_max_num_iterations": 100, + # below 3 options are for individual/video data, for internet photos, they should be left + # default + "min_focal_length_ratio": 0.1, + "max_focal_length_ratio": 10, + "max_extra_param": 1e15, + } + + sfm_dir = self.outputs / "sfm_{}".format(key) + sfm_pairs = self.outputs / "pairs-sfm.txt" + sfm_dir.mkdir(exist_ok=True, parents=True) + + # extract features + retrieval_path = extract_features.main( + retrieval_conf, temp_images, self.outputs + ) + pairs_from_retrieval.main(retrieval_path, sfm_pairs, num_matched=top_k) + + feature_path = extract_features.main(feature_conf, temp_images, self.outputs) + # match features + match_path = match_features.main( + match_conf, sfm_pairs, feature_conf["output"], self.outputs + ) + # reconstruction + already_sfm = False + if sfm_dir.exists(): + try: + model = pycolmap.Reconstruction(str(sfm_dir)) + already_sfm = True + except ValueError: + logger.info(f"sfm_dir not exists model: {sfm_dir}") + if not already_sfm: + model = reconstruction.main( + sfm_dir, + temp_images, + sfm_pairs, + feature_path, + match_path, + mapper_options=mapper_options, + ) + + vertices = [] + for point3D_id, point3D in model.points3D.items(): + vertices.append([point3D.xyz, point3D.color]) + + model_3d = sfm_dir / "points3D.obj" + with open(model_3d, "w") as f: + for p, c in vertices: + # Write vertex position + f.write("v {} {} {}\n".format(p[0], p[1], p[2])) + # Write vertex normal (color) + f.write( + "vn {} {} {}\n".format(c[0] / 255.0, c[1] / 255.0, c[2] / 255.0) + ) + viz_2d = visualization.visualize_sfm_2d( + model, temp_images, color_by="visibility", n=2, dpi=300 + ) + + return model_3d, fig2im(viz_2d) / 255.0 diff --git a/imcui/ui/utils.py b/imcui/ui/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b536ec496e821152e8b5570b0e38494fc55ddb --- /dev/null +++ b/imcui/ui/utils.py @@ -0,0 +1,1105 @@ +import os +import pickle +import random +import shutil +import time +import warnings +from itertools import combinations +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from datasets import load_dataset + +import cv2 +import gradio as gr +import matplotlib.pyplot as plt +import numpy as np +import poselib +import psutil +from PIL import Image + +from ..hloc import ( + DEVICE, + extract_features, + extractors, + logger, + match_dense, + match_features, + matchers, + DATASETS_REPO_ID, +) +from ..hloc.utils.base_model import dynamic_load +from .viz import 
display_keypoints, display_matches, fig2im, plot_images + +warnings.simplefilter("ignore") + +ROOT = Path(__file__).parents[1] +# some default values +DEFAULT_SETTING_THRESHOLD = 0.1 +DEFAULT_SETTING_MAX_FEATURES = 2000 +DEFAULT_DEFAULT_KEYPOINT_THRESHOLD = 0.01 +DEFAULT_ENABLE_RANSAC = True +DEFAULT_RANSAC_METHOD = "CV2_USAC_MAGSAC" +DEFAULT_RANSAC_REPROJ_THRESHOLD = 8 +DEFAULT_RANSAC_CONFIDENCE = 0.9999 +DEFAULT_RANSAC_MAX_ITER = 10000 +DEFAULT_MIN_NUM_MATCHES = 4 +DEFAULT_MATCHING_THRESHOLD = 0.2 +DEFAULT_SETTING_GEOMETRY = "Homography" +GRADIO_VERSION = gr.__version__.split(".")[0] +MATCHER_ZOO = None + + +class ModelCache: + def __init__(self, max_memory_size: int = 8): + self.max_memory_size = max_memory_size + self.current_memory_size = 0 + self.model_dict = {} + self.model_timestamps = [] + + def cache_model(self, model_key, model_loader_func, model_conf): + if model_key in self.model_dict: + self.model_timestamps.remove(model_key) + self.model_timestamps.append(model_key) + logger.info(f"Load cached {model_key}") + return self.model_dict[model_key] + + model = self._load_model_from_disk(model_loader_func, model_conf) + while self._calculate_model_memory() > self.max_memory_size: + if len(self.model_timestamps) == 0: + logger.warn( + "RAM: {}GB, MAX RAM: {}GB".format( + self._calculate_model_memory(), self.max_memory_size + ) + ) + break + oldest_model_key = self.model_timestamps.pop(0) + self.current_memory_size = self._calculate_model_memory() + logger.info(f"Del cached {oldest_model_key}") + del self.model_dict[oldest_model_key] + + self.model_dict[model_key] = model + self.model_timestamps.append(model_key) + + self.print_memory_usage() + logger.info(f"Total cached {list(self.model_dict.keys())}") + + return model + + def _load_model_from_disk(self, model_loader_func, model_conf): + return model_loader_func(model_conf) + + def _calculate_model_memory(self, verbose=False): + host_colocation = int(os.environ.get("HOST_COLOCATION", "1")) + vm = psutil.virtual_memory() + du = shutil.disk_usage(".") + if verbose: + logger.info( + f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB" + ) + logger.info( + f"DISK: {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}GB" + ) + return vm.used / 1e9 + + def print_memory_usage(self): + self._calculate_model_memory(verbose=True) + + +model_cache = ModelCache() + + +def load_config(config_name: str) -> Dict[str, Any]: + """ + Load a YAML configuration file. + + Args: + config_name: The path to the YAML configuration file. + + Returns: + The configuration dictionary, with string keys and arbitrary values. + """ + import yaml + + with open(config_name, "r") as stream: + try: + config: Dict[str, Any] = yaml.safe_load(stream) + except yaml.YAMLError as exc: + logger.error(exc) + return config + + +def get_matcher_zoo( + matcher_zoo: Dict[str, Dict[str, Union[str, bool]]], +) -> Dict[str, Dict[str, Union[Callable, bool]]]: + """ + Restore matcher configurations from a dictionary. + + Args: + matcher_zoo: A dictionary with the matcher configurations, + where the configuration is a dictionary as loaded from a YAML file. + + Returns: + A dictionary with the matcher configurations, where the configuration is + a function or a function instead of a string. 
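+
+    In other words, each entry's ``feature``/``matcher`` fields are resolved
+    from plain strings to the corresponding hloc configuration dictionaries
+    (see ``parse_match_config`` below).
+
+    Illustrative sketch (assumes ``disk+lightglue`` is defined as a sparse,
+    enabled entry in the bundled ``config.yaml``):
+
+        >>> cfg = load_config("config.yaml")
+        >>> zoo = get_matcher_zoo(cfg["matcher_zoo"])
+        >>> conf = zoo["disk+lightglue"]
+        >>> # conf["dense"] is False; conf["feature"] / conf["matcher"] hold
+        >>> # the resolved extractor and matcher configurations.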
+ """ + matcher_zoo_restored = {} + for k, v in matcher_zoo.items(): + matcher_zoo_restored[k] = parse_match_config(v) + return matcher_zoo_restored + + +def parse_match_config(conf): + if conf["dense"]: + return { + "matcher": match_dense.confs.get(conf["matcher"]), + "dense": True, + } + else: + return { + "feature": extract_features.confs.get(conf["feature"]), + "matcher": match_features.confs.get(conf["matcher"]), + "dense": False, + } + + +def get_model(match_conf: Dict[str, Any]): + """ + Load a matcher model from the provided configuration. + + Args: + match_conf: A dictionary containing the model configuration. + + Returns: + A matcher model instance. + """ + Model = dynamic_load(matchers, match_conf["model"]["name"]) + model = Model(match_conf["model"]).eval().to(DEVICE) + return model + + +def get_feature_model(conf: Dict[str, Dict[str, Any]]): + """ + Load a feature extraction model from the provided configuration. + + Args: + conf: A dictionary containing the model configuration. + + Returns: + A feature extraction model instance. + """ + Model = dynamic_load(extractors, conf["model"]["name"]) + model = Model(conf["model"]).eval().to(DEVICE) + return model + + +def download_example_images(repo_id, output_dir): + logger.info(f"Download example dataset from huggingface: {repo_id}") + dataset = load_dataset(repo_id) + Path(output_dir).mkdir(parents=True, exist_ok=True) + for example in dataset["train"]: # Assuming the dataset is in the "train" split + file_path = example["path"] + image = example["image"] # Access the PIL.Image object directly + full_path = os.path.join(output_dir, file_path) + Path(os.path.dirname(full_path)).mkdir(parents=True, exist_ok=True) + image.save(full_path) + logger.info(f"Images saved to {output_dir} successfully.") + return Path(output_dir) + + +def gen_examples(data_root: Path): + random.seed(1) + example_matchers = [ + "disk+lightglue", + "xfeat(sparse)", + "dedode", + "loftr", + "disk", + "RoMa", + "d2net", + "aspanformer", + "topicfm", + "superpoint+superglue", + "superpoint+lightglue", + "superpoint+mnn", + "disk", + ] + data_root = Path(data_root) + if not Path(data_root).exists(): + try: + download_example_images(DATASETS_REPO_ID, data_root) + except Exception as e: + logger.error(f"download_example_images error : {e}") + data_root = ROOT / "datasets" + if not Path(data_root / "sacre_coeur/mapping").exists(): + download_example_images(DATASETS_REPO_ID, data_root) + + def distribute_elements(A, B): + new_B = np.array(B, copy=True).flatten() + np.random.shuffle(new_B) + new_B = np.resize(new_B, len(A)) + np.random.shuffle(new_B) + return new_B.tolist() + + # normal examples + def gen_images_pairs(count: int = 5): + path = str(data_root / "sacre_coeur/mapping") + imgs_list = [ + os.path.join(path, file) + for file in os.listdir(path) + if file.lower().endswith((".jpg", ".jpeg", ".png")) + ] + pairs = list(combinations(imgs_list, 2)) + if len(pairs) < count: + count = len(pairs) + selected = random.sample(range(len(pairs)), count) + return [pairs[i] for i in selected] + + # rotated examples + def gen_rot_image_pairs(count: int = 5): + path = data_root / "sacre_coeur/mapping" + path_rot = data_root / "sacre_coeur/mapping_rot" + rot_list = [45, 180, 90, 225, 270] + pairs = [] + for file in os.listdir(path): + if file.lower().endswith((".jpg", ".jpeg", ".png")): + for rot in rot_list: + file_rot = "{}_rot{}.jpg".format(Path(file).stem, rot) + if (path_rot / file_rot).exists(): + pairs.append( + [ + path / file, + path_rot / file_rot, + ] + ) + if 
len(pairs) < count: + count = len(pairs) + selected = random.sample(range(len(pairs)), count) + return [pairs[i] for i in selected] + + def gen_scale_image_pairs(count: int = 5): + path = data_root / "sacre_coeur/mapping" + path_scale = data_root / "sacre_coeur/mapping_scale" + scale_list = [0.3, 0.5] + pairs = [] + for file in os.listdir(path): + if file.lower().endswith((".jpg", ".jpeg", ".png")): + for scale in scale_list: + file_scale = "{}_scale{}.jpg".format(Path(file).stem, scale) + if (path_scale / file_scale).exists(): + pairs.append( + [ + path / file, + path_scale / file_scale, + ] + ) + if len(pairs) < count: + count = len(pairs) + selected = random.sample(range(len(pairs)), count) + return [pairs[i] for i in selected] + + # extramely hard examples + def gen_image_pairs_wxbs(count: int = None): + prefix = "wxbs_benchmark/.WxBS/v1.1" + wxbs_path = data_root / prefix + pairs = [] + for catg in os.listdir(wxbs_path): + catg_path = wxbs_path / catg + if not catg_path.is_dir(): + continue + for scene in os.listdir(catg_path): + scene_path = catg_path / scene + if not scene_path.is_dir(): + continue + img1_path = scene_path / "01.png" + img2_path = scene_path / "02.png" + if img1_path.exists() and img2_path.exists(): + pairs.append([str(img1_path), str(img2_path)]) + return pairs + + # image pair path + pairs = gen_images_pairs() + pairs += gen_rot_image_pairs() + pairs += gen_scale_image_pairs() + pairs += gen_image_pairs_wxbs() + + match_setting_threshold = DEFAULT_SETTING_THRESHOLD + match_setting_max_features = DEFAULT_SETTING_MAX_FEATURES + detect_keypoints_threshold = DEFAULT_DEFAULT_KEYPOINT_THRESHOLD + ransac_method = DEFAULT_RANSAC_METHOD + ransac_reproj_threshold = DEFAULT_RANSAC_REPROJ_THRESHOLD + ransac_confidence = DEFAULT_RANSAC_CONFIDENCE + ransac_max_iter = DEFAULT_RANSAC_MAX_ITER + input_lists = [] + dist_examples = distribute_elements(pairs, example_matchers) + for pair, mt in zip(pairs, dist_examples): + input_lists.append( + [ + pair[0], + pair[1], + match_setting_threshold, + match_setting_max_features, + detect_keypoints_threshold, + mt, + # enable_ransac, + ransac_method, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + ] + ) + return input_lists + + +def set_null_pred(feature_type: str, pred: dict): + if feature_type == "KEYPOINT": + pred["mmkeypoints0_orig"] = np.array([]) + pred["mmkeypoints1_orig"] = np.array([]) + pred["mmconf"] = np.array([]) + elif feature_type == "LINE": + pred["mline_keypoints0_orig"] = np.array([]) + pred["mline_keypoints1_orig"] = np.array([]) + pred["H"] = None + pred["geom_info"] = {} + return pred + + +def _filter_matches_opencv( + kp0: np.ndarray, + kp1: np.ndarray, + method: int = cv2.RANSAC, + reproj_threshold: float = 3.0, + confidence: float = 0.99, + max_iter: int = 2000, + geometry_type: str = "Homography", +) -> Tuple[np.ndarray, np.ndarray]: + """ + Filters matches between two sets of keypoints using OpenCV's findHomography. + + Args: + kp0 (np.ndarray): Array of keypoints from the first image. + kp1 (np.ndarray): Array of keypoints from the second image. + method (int, optional): RANSAC method. Defaults to "cv2.RANSAC". + reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to 3.0. + confidence (float, optional): RANSAC confidence. Defaults to 0.99. + max_iter (int, optional): RANSAC maximum iterations. Defaults to 2000. + geometry_type (str, optional): Type of geometry. Defaults to "Homography". + + Returns: + Tuple[np.ndarray, np.ndarray]: Homography matrix and mask. 
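+
+    Illustrative sketch with synthetic correspondences (not part of the real
+    pipeline; any method from ``ransac_zoo`` can be passed):
+
+        >>> rng = np.random.default_rng(0)
+        >>> kp0 = rng.uniform(0, 640, size=(100, 2)).astype(np.float32)
+        >>> kp1 = kp0 + 1.0   # a pure translation is a valid homography
+        >>> H, mask = _filter_matches_opencv(kp0, kp1, cv2.USAC_MAGSAC)
+        >>> # H is a 3x3 matrix, mask a length-100 boolean inlier array.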
+ """ + if geometry_type == "Homography": + try: + M, mask = cv2.findHomography( + kp0, + kp1, + method=method, + ransacReprojThreshold=reproj_threshold, + confidence=confidence, + maxIters=max_iter, + ) + except cv2.error: + logger.error("compute findHomography error, len(kp0): {}".format(len(kp0))) + return None, None + elif geometry_type == "Fundamental": + try: + M, mask = cv2.findFundamentalMat( + kp0, + kp1, + method=method, + ransacReprojThreshold=reproj_threshold, + confidence=confidence, + maxIters=max_iter, + ) + except cv2.error: + logger.error( + "compute findFundamentalMat error, len(kp0): {}".format(len(kp0)) + ) + return None, None + mask = np.array(mask.ravel().astype("bool"), dtype="bool") + return M, mask + + +def _filter_matches_poselib( + kp0: np.ndarray, + kp1: np.ndarray, + method: int = None, # not used + reproj_threshold: float = 3, + confidence: float = 0.99, + max_iter: int = 2000, + geometry_type: str = "Homography", +) -> dict: + """ + Filters matches between two sets of keypoints using the poselib library. + + Args: + kp0 (np.ndarray): Array of keypoints from the first image. + kp1 (np.ndarray): Array of keypoints from the second image. + method (str, optional): RANSAC method. Defaults to "RANSAC". + reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to 3. + confidence (float, optional): RANSAC confidence. Defaults to 0.99. + max_iter (int, optional): RANSAC maximum iterations. Defaults to 2000. + geometry_type (str, optional): Type of geometry. Defaults to "Homography". + + Returns: + dict: Information about the homography estimation. + """ + ransac_options = { + "max_iterations": max_iter, + # "min_iterations": min_iter, + "success_prob": confidence, + "max_reproj_error": reproj_threshold, + # "progressive_sampling": args.sampler.lower() == 'prosac' + } + + if geometry_type == "Homography": + M, info = poselib.estimate_homography(kp0, kp1, ransac_options) + elif geometry_type == "Fundamental": + M, info = poselib.estimate_fundamental(kp0, kp1, ransac_options) + else: + raise NotImplementedError + + return M, np.array(info["inliers"]) + + +def proc_ransac_matches( + mkpts0: np.ndarray, + mkpts1: np.ndarray, + ransac_method: str = DEFAULT_RANSAC_METHOD, + ransac_reproj_threshold: float = 3.0, + ransac_confidence: float = 0.99, + ransac_max_iter: int = 2000, + geometry_type: str = "Homography", +): + if ransac_method.startswith("CV2"): + logger.info(f"ransac_method: {ransac_method}, geometry_type: {geometry_type}") + return _filter_matches_opencv( + mkpts0, + mkpts1, + ransac_zoo[ransac_method], + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + geometry_type, + ) + elif ransac_method.startswith("POSELIB"): + logger.info(f"ransac_method: {ransac_method}, geometry_type: {geometry_type}") + return _filter_matches_poselib( + mkpts0, + mkpts1, + None, + ransac_reproj_threshold, + ransac_confidence, + ransac_max_iter, + geometry_type, + ) + else: + raise NotImplementedError + + +def filter_matches( + pred: Dict[str, Any], + ransac_method: str = DEFAULT_RANSAC_METHOD, + ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD, + ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE, + ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER, + ransac_estimator: str = None, +): + """ + Filter matches using RANSAC. If keypoints are available, filter by keypoints. + If lines are available, filter by lines. If both keypoints and lines are + available, filter by keypoints. 
+
+    Args:
+        pred (Dict[str, Any]): dict of matches, including original keypoints.
+        ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
+        ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
+        ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
+        ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
+        ransac_estimator (str, optional): currently unused.
+
+    Returns:
+        Dict[str, Any]: filtered matches.
+    """
+    mkpts0: Optional[np.ndarray] = None
+    mkpts1: Optional[np.ndarray] = None
+    feature_type: Optional[str] = None
+    if "mkeypoints0_orig" in pred.keys() and "mkeypoints1_orig" in pred.keys():
+        mkpts0 = pred["mkeypoints0_orig"]
+        mkpts1 = pred["mkeypoints1_orig"]
+        feature_type = "KEYPOINT"
+    elif (
+        "line_keypoints0_orig" in pred.keys() and "line_keypoints1_orig" in pred.keys()
+    ):
+        mkpts0 = pred["line_keypoints0_orig"]
+        mkpts1 = pred["line_keypoints1_orig"]
+        feature_type = "LINE"
+    else:
+        return set_null_pred(feature_type, pred)
+    if mkpts0 is None or mkpts1 is None:
+        return set_null_pred(feature_type, pred)
+    if ransac_method not in ransac_zoo.keys():
+        ransac_method = DEFAULT_RANSAC_METHOD
+
+    if len(mkpts0) < DEFAULT_MIN_NUM_MATCHES:
+        return set_null_pred(feature_type, pred)
+
+    geom_info = compute_geometry(
+        pred,
+        ransac_method=ransac_method,
+        ransac_reproj_threshold=ransac_reproj_threshold,
+        ransac_confidence=ransac_confidence,
+        ransac_max_iter=ransac_max_iter,
+    )
+
+    if "Homography" in geom_info.keys():
+        mask = geom_info["mask_h"]
+        if feature_type == "KEYPOINT":
+            pred["mmkeypoints0_orig"] = mkpts0[mask]
+            pred["mmkeypoints1_orig"] = mkpts1[mask]
+            pred["mmconf"] = pred["mconf"][mask]
+        elif feature_type == "LINE":
+            pred["mline_keypoints0_orig"] = mkpts0[mask]
+            pred["mline_keypoints1_orig"] = mkpts1[mask]
+        pred["H"] = np.array(geom_info["Homography"])
+    else:
+        set_null_pred(feature_type, pred)
+    # do not show mask
+    geom_info.pop("mask_h", None)
+    geom_info.pop("mask_f", None)
+    pred["geom_info"] = geom_info
+    return pred
+
+
+def compute_geometry(
+    pred: Dict[str, Any],
+    ransac_method: str = DEFAULT_RANSAC_METHOD,
+    ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
+    ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
+    ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
+) -> Dict[str, List[float]]:
+    """
+    Compute geometric information of matches, including Fundamental matrix,
+    Homography matrix, and rectification matrices (if available).
+
+    Args:
+        pred (Dict[str, Any]): dict of matches, including original keypoints.
+        ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
+        ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
+        ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
+        ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
+
+    Returns:
+        Dict[str, List[float]]: geometric information in the form of a dict.
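+
+        The keys below sketch the returned structure when enough matches
+        survive (``H1``/``H2`` only appear if uncalibrated rectification
+        succeeds); values are illustrative placeholders:
+
+            {
+                "Fundamental": [[...], [...], [...]],  # 3x3, as nested lists
+                "mask_f": <bool inlier mask of the fundamental fit>,
+                "Homography": [[...], [...], [...]],   # 3x3, maps image1 to image0
+                "mask_h": <bool inlier mask of the homography fit>,
+                "H1": [[...], [...], [...]],           # rectification for image 0
+                "H2": [[...], [...], [...]],           # rectification for image 1
+            }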
+    """
+    mkpts0: Optional[np.ndarray] = None
+    mkpts1: Optional[np.ndarray] = None
+
+    if "mkeypoints0_orig" in pred.keys() and "mkeypoints1_orig" in pred.keys():
+        mkpts0 = pred["mkeypoints0_orig"]
+        mkpts1 = pred["mkeypoints1_orig"]
+    elif (
+        "line_keypoints0_orig" in pred.keys() and "line_keypoints1_orig" in pred.keys()
+    ):
+        mkpts0 = pred["line_keypoints0_orig"]
+        mkpts1 = pred["line_keypoints1_orig"]
+
+    if mkpts0 is not None and mkpts1 is not None:
+        if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
+            return {}
+        geo_info: Dict[str, List[float]] = {}
+
+        F, mask_f = proc_ransac_matches(
+            mkpts0,
+            mkpts1,
+            ransac_method,
+            ransac_reproj_threshold,
+            ransac_confidence,
+            ransac_max_iter,
+            geometry_type="Fundamental",
+        )
+
+        if F is not None:
+            geo_info["Fundamental"] = F.tolist()
+            geo_info["mask_f"] = mask_f
+        H, mask_h = proc_ransac_matches(
+            mkpts1,
+            mkpts0,
+            ransac_method,
+            ransac_reproj_threshold,
+            ransac_confidence,
+            ransac_max_iter,
+            geometry_type="Homography",
+        )
+
+        h0, w0, _ = pred["image0_orig"].shape
+        if H is not None:
+            geo_info["Homography"] = H.tolist()
+            geo_info["mask_h"] = mask_h
+            try:
+                _, H1, H2 = cv2.stereoRectifyUncalibrated(
+                    mkpts0.reshape(-1, 2),
+                    mkpts1.reshape(-1, 2),
+                    F,
+                    imgSize=(w0, h0),
+                )
+                geo_info["H1"] = H1.tolist()
+                geo_info["H2"] = H2.tolist()
+            except cv2.error as e:
+                logger.error(f"StereoRectifyUncalibrated failed, skip! error: {e}")
+        return geo_info
+    else:
+        return {}
+
+
+def wrap_images(
+    img0: np.ndarray,
+    img1: np.ndarray,
+    geo_info: Optional[Dict[str, List[float]]],
+    geom_type: str,
+) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+    """
+    Warps the images according to the geometric transformation used to align them.
+
+    Args:
+        img0: numpy array representing the first image.
+        img1: numpy array representing the second image.
+        geo_info: dictionary containing the geometric transformation information.
+        geom_type: type of geometric transformation used to align the images.
+
+    Returns:
+        A tuple containing the rendered side-by-side comparison as an image array
+        and the warped second image, or (None, None) if the transformation is unavailable.
+    """
+    h0, w0, _ = img0.shape
+    h1, w1, _ = img1.shape
+    if geo_info is not None and len(geo_info) != 0:
+        rectified_image0 = img0
+        rectified_image1 = None
+        if "Homography" not in geo_info:
+            logger.warning(f"{geom_type} does not exist, possibly too few matches")
+            return None, None
+
+        H = np.array(geo_info["Homography"])
+
+        title: List[str] = []
+        if geom_type == "Homography":
+            rectified_image1 = cv2.warpPerspective(img1, H, (w0, h0))
+            title = ["Image 0", "Image 1 - warped"]
+        elif geom_type == "Fundamental":
+            if geom_type not in geo_info:
+                logger.warning(f"{geom_type} does not exist, possibly too few matches")
+                return None, None
+            else:
+                H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
+                rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0))
+                rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1))
+                title = ["Image 0 - warped", "Image 1 - warped"]
+        else:
+            print("Error: Unknown geometry type")
+        fig = plot_images(
+            [rectified_image0.squeeze(), rectified_image1.squeeze()],
+            title,
+            dpi=300,
+        )
+        return fig2im(fig), rectified_image1
+    else:
+        return None, None
+
+
+def generate_warp_images(
+    input_image0: np.ndarray,
+    input_image1: np.ndarray,
+    matches_info: Dict[str, Any],
+    choice: str,
+) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+    """
+    Generates warped images according to the selected geometric transformation.
+
+    Args:
+        input_image0: First input image.
+        input_image1: Second input image.
+        matches_info: Dictionary containing information about the matches.
+        choice: Type of geometric transformation to use ('Homography' or 'Fundamental') or 'No' to disable.
+
+    Returns:
+        A tuple containing the warped image pair visualization and the warped second image.
+    """
+    if (
+        matches_info is None
+        or len(matches_info) < 1
+        or "geom_info" not in matches_info.keys()
+    ):
+        return None, None
+    geom_info = matches_info["geom_info"]
+    warped_image = None
+    if choice != "No":
+        wrapped_image_pair, warped_image = wrap_images(
+            input_image0, input_image1, geom_info, choice
+        )
+        return wrapped_image_pair, warped_image
+    else:
+        return None, None
+
+
+def send_to_match(state_cache: Dict[str, Any]):
+    """
+    Send the state cache to the match function.
+
+    Args:
+        state_cache (Dict[str, Any]): Current state of the app.
+
+    Returns:
+        Tuple: the original first image and the warped second image, to be fed
+        back into the two input image boxes, or (None, None) if the cache is empty.
+    """
+    if state_cache:
+        return (
+            state_cache["image0_orig"],
+            state_cache["wrapped_image"],
+        )
+    else:
+        return None, None
+
+
+def run_ransac(
+    state_cache: Dict[str, Any],
+    choice_geometry_type: str,
+    ransac_method: str = DEFAULT_RANSAC_METHOD,
+    ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
+    ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
+    ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
+) -> Tuple[Optional[np.ndarray], Optional[Dict[str, int]], Optional[np.ndarray], Optional[str]]:
+    """
+    Re-run RANSAC on cached matches and return the updated outputs.
+
+    Args:
+        state_cache (Dict[str, Any]): Current state of the app, including the matches.
+        choice_geometry_type (str): geometry to estimate ('Homography' or 'Fundamental').
+        ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
+        ransac_reproj_threshold (int, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
+        ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
+        ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
+
+    Returns:
+        Tuple containing the RANSAC match visualization, the match statistics,
+        the warped image pair, and the path of the dumped state cache.
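+
+    Offline sketch (assumes a previous ``run_matching`` call dumped its state
+    cache to ``output.pkl``; variable names are placeholders):
+
+        >>> import pickle
+        >>> with open("output.pkl", "rb") as f:
+        ...     cache = pickle.load(f)
+        >>> viz, stats, warped_pair, dump_path = run_ransac(
+        ...     cache, "Homography", ransac_method="CV2_USAC_MAGSAC",
+        ...     ransac_reproj_threshold=4.0)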
+ """ + if not state_cache: + logger.info("Run Match first before Rerun RANSAC") + gr.Warning("Run Match first before Rerun RANSAC") + return None, None + t1 = time.time() + logger.info( + f"Run RANSAC matches using: {ransac_method} with threshold: {ransac_reproj_threshold}" + ) + logger.info( + f"Run RANSAC matches using: {ransac_confidence} with iter: {ransac_max_iter}" + ) + # if enable_ransac: + filter_matches( + state_cache, + ransac_method=ransac_method, + ransac_reproj_threshold=ransac_reproj_threshold, + ransac_confidence=ransac_confidence, + ransac_max_iter=ransac_max_iter, + ) + logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s") + t1 = time.time() + + # plot images with ransac matches + titles = [ + "Image 0 - Ransac matched keypoints", + "Image 1 - Ransac matched keypoints", + ] + output_matches_ransac, num_matches_ransac = display_matches( + state_cache, titles=titles, tag="KPTS_RANSAC" + ) + logger.info(f"Display matches done using: {time.time()-t1:.3f}s") + t1 = time.time() + + # compute warp images + output_wrapped, warped_image = generate_warp_images( + state_cache["image0_orig"], + state_cache["image1_orig"], + state_cache, + choice_geometry_type, + ) + plt.close("all") + + num_matches_raw = state_cache["num_matches_raw"] + state_cache["wrapped_image"] = warped_image + + # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) + tmp_state_cache = "output.pkl" + with open(tmp_state_cache, "wb") as f: + pickle.dump(state_cache, f) + + logger.info("Dump results done!") + + return ( + output_matches_ransac, + { + "num_matches_raw": num_matches_raw, + "num_matches_ransac": num_matches_ransac, + }, + output_wrapped, + tmp_state_cache, + ) + + +def run_matching( + image0: np.ndarray, + image1: np.ndarray, + match_threshold: float, + extract_max_keypoints: int, + keypoint_threshold: float, + key: str, + ransac_method: str = DEFAULT_RANSAC_METHOD, + ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD, + ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE, + ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER, + choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY, + matcher_zoo: Dict[str, Any] = None, + force_resize: bool = False, + image_width: int = 640, + image_height: int = 480, + use_cached_model: bool = False, +) -> Tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + Dict[str, int], + Dict[str, Dict[str, Any]], + Dict[str, Dict[str, float]], + np.ndarray, +]: + """Match two images using the given parameters. + + Args: + image0 (np.ndarray): RGB image 0. + image1 (np.ndarray): RGB image 1. + match_threshold (float): match threshold. + extract_max_keypoints (int): number of keypoints to extract. + keypoint_threshold (float): keypoint threshold. + key (str): key of the model to use. + ransac_method (str, optional): RANSAC method to use. + ransac_reproj_threshold (int, optional): RANSAC reprojection threshold. + ransac_confidence (float, optional): RANSAC confidence level. + ransac_max_iter (int, optional): RANSAC maximum number of iterations. + choice_geometry_type (str, optional): setting of geometry estimation. + matcher_zoo (Dict[str, Any], optional): matcher zoo. Defaults to None. + force_resize (bool, optional): force resize. Defaults to False. + image_width (int, optional): image width. Defaults to 640. + image_height (int, optional): image height. Defaults to 480. + use_cached_model (bool, optional): use cached model. Defaults to False. + + Returns: + tuple: + - output_keypoints (np.ndarray): image with keypoints. 
+ - output_matches_raw (np.ndarray): image with raw matches. + - output_matches_ransac (np.ndarray): image with RANSAC matches. + - num_matches (Dict[str, int]): number of raw and RANSAC matches. + - configs (Dict[str, Dict[str, Any]]): match and feature extraction configs. + - geom_info (Dict[str, Dict[str, float]]): geometry information. + - output_wrapped (np.ndarray): wrapped images. + """ + # image0 and image1 is RGB mode + if image0 is None or image1 is None: + logger.error( + "Error: No images found! Please upload two images or select an example." + ) + raise gr.Error( + "Error: No images found! Please upload two images or select an example." + ) + # init output + output_keypoints = None + output_matches_raw = None + output_matches_ransac = None + + # super slow! + if "roma" in key.lower() and DEVICE == "cpu": + gr.Info( + f"Success! Please be patient and allow for about 2-3 minutes." + f" Due to CPU inference, {key} is quiet slow." + ) + t0 = time.time() + model = matcher_zoo[key] + match_conf = model["matcher"] + # update match config + match_conf["model"]["match_threshold"] = match_threshold + match_conf["model"]["max_keypoints"] = extract_max_keypoints + cache_key = "{}_{}".format(key, match_conf["model"]["name"]) + if use_cached_model: + # because of the model cache, we need to update the config + matcher = model_cache.cache_model(cache_key, get_model, match_conf) + matcher.conf["max_keypoints"] = extract_max_keypoints + matcher.conf["match_threshold"] = match_threshold + logger.info(f"Loaded cached model {cache_key}") + else: + matcher = get_model(match_conf) + logger.info(f"Loading model using: {time.time()-t0:.3f}s") + t1 = time.time() + + if model["dense"]: + if not match_conf["preprocessing"].get("force_resize", False): + match_conf["preprocessing"]["force_resize"] = force_resize + else: + logger.info("preprocessing is already resized") + if force_resize: + match_conf["preprocessing"]["height"] = image_height + match_conf["preprocessing"]["width"] = image_width + logger.info(f"Force resize to {image_width}x{image_height}") + + pred = match_dense.match_images( + matcher, image0, image1, match_conf["preprocessing"], device=DEVICE + ) + del matcher + extract_conf = None + else: + extract_conf = model["feature"] + # update extract config + extract_conf["model"]["max_keypoints"] = extract_max_keypoints + extract_conf["model"]["keypoint_threshold"] = keypoint_threshold + cache_key = "{}_{}".format(key, extract_conf["model"]["name"]) + + if use_cached_model: + extractor = model_cache.cache_model( + cache_key, get_feature_model, extract_conf + ) + # because of the model cache, we need to update the config + extractor.conf["max_keypoints"] = extract_max_keypoints + extractor.conf["keypoint_threshold"] = keypoint_threshold + logger.info(f"Loaded cached model {cache_key}") + else: + extractor = get_feature_model(extract_conf) + + if not extract_conf["preprocessing"].get("force_resize", False): + extract_conf["preprocessing"]["force_resize"] = force_resize + else: + logger.info("preprocessing is already resized") + if force_resize: + extract_conf["preprocessing"]["height"] = image_height + extract_conf["preprocessing"]["width"] = image_width + logger.info(f"Force resize to {image_width}x{image_height}") + + pred0 = extract_features.extract( + extractor, image0, extract_conf["preprocessing"] + ) + pred1 = extract_features.extract( + extractor, image1, extract_conf["preprocessing"] + ) + pred = match_features.match_images(matcher, pred0, pred1) + del extractor + # gr.Info( + # 
f"Matching images done using: {time.time()-t1:.3f}s", + # ) + logger.info(f"Matching images done using: {time.time()-t1:.3f}s") + t1 = time.time() + + # plot images with keypoints + titles = [ + "Image 0 - Keypoints", + "Image 1 - Keypoints", + ] + output_keypoints = display_keypoints(pred, titles=titles) + + # plot images with raw matches + titles = [ + "Image 0 - Raw matched keypoints", + "Image 1 - Raw matched keypoints", + ] + output_matches_raw, num_matches_raw = display_matches(pred, titles=titles) + + # if enable_ransac: + filter_matches( + pred, + ransac_method=ransac_method, + ransac_reproj_threshold=ransac_reproj_threshold, + ransac_confidence=ransac_confidence, + ransac_max_iter=ransac_max_iter, + ) + + # gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s") + logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s") + t1 = time.time() + + # plot images with ransac matches + titles = [ + "Image 0 - Ransac matched keypoints", + "Image 1 - Ransac matched keypoints", + ] + output_matches_ransac, num_matches_ransac = display_matches( + pred, titles=titles, tag="KPTS_RANSAC" + ) + # gr.Info(f"Display matches done using: {time.time()-t1:.3f}s") + logger.info(f"Display matches done using: {time.time()-t1:.3f}s") + + t1 = time.time() + # plot wrapped images + output_wrapped, warped_image = generate_warp_images( + pred["image0_orig"], + pred["image1_orig"], + pred, + choice_geometry_type, + ) + plt.close("all") + # gr.Info(f"In summary, total time: {time.time()-t0:.3f}s") + logger.info(f"TOTAL time: {time.time()-t0:.3f}s") + + state_cache = pred + state_cache["num_matches_raw"] = num_matches_raw + state_cache["num_matches_ransac"] = num_matches_ransac + state_cache["wrapped_image"] = warped_image + + # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) + tmp_state_cache = "output.pkl" + with open(tmp_state_cache, "wb") as f: + pickle.dump(state_cache, f) + logger.info("Dump results done!") + return ( + output_keypoints, + output_matches_raw, + output_matches_ransac, + { + "num_raw_matches": num_matches_raw, + "num_ransac_matches": num_matches_ransac, + }, + { + "match_conf": match_conf, + "extractor_conf": extract_conf, + }, + { + "geom_info": pred.get("geom_info", {}), + }, + output_wrapped, + state_cache, + tmp_state_cache, + ) + + +# @ref: https://docs.opencv.org/4.x/d0/d74/md__build_4_x-contrib_docs-lin64_opencv_doc_tutorials_calib3d_usac.html +# AND: https://opencv.org/blog/2021/06/09/evaluating-opencvs-new-ransacs +ransac_zoo = { + "POSELIB": "LO-RANSAC", + "CV2_RANSAC": cv2.RANSAC, + "CV2_USAC_MAGSAC": cv2.USAC_MAGSAC, + "CV2_USAC_DEFAULT": cv2.USAC_DEFAULT, + "CV2_USAC_FM_8PTS": cv2.USAC_FM_8PTS, + "CV2_USAC_PROSAC": cv2.USAC_PROSAC, + "CV2_USAC_FAST": cv2.USAC_FAST, + "CV2_USAC_ACCURATE": cv2.USAC_ACCURATE, + "CV2_USAC_PARALLEL": cv2.USAC_PARALLEL, +} + + +def rotate_image(input_path, degrees, output_path): + img = Image.open(input_path) + img_rotated = img.rotate(-degrees) + img_rotated.save(output_path) + + +def scale_image(input_path, scale_factor, output_path): + img = Image.open(input_path) + width, height = img.size + new_width = int(width * scale_factor) + new_height = int(height * scale_factor) + new_img = Image.new("RGB", (width, height), (0, 0, 0)) + img_resized = img.resize((new_width, new_height)) + position = ((width - new_width) // 2, (height - new_height) // 2) + new_img.paste(img_resized, position) + new_img.save(output_path) diff --git a/imcui/ui/viz.py b/imcui/ui/viz.py new file mode 100644 index 
0000000000000000000000000000000000000000..c0c29474de2c3e530038f5be923f5a088a79c80c --- /dev/null +++ b/imcui/ui/viz.py @@ -0,0 +1,482 @@ +import typing +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns + +from ..hloc.utils.viz import add_text, plot_keypoints + +np.random.seed(1995) +color_map = np.arange(100) +np.random.shuffle(color_map) + + +def plot_images( + imgs: List[np.ndarray], + titles: Optional[List[str]] = None, + cmaps: Union[str, List[str]] = "gray", + dpi: int = 100, + size: Optional[int] = 5, + pad: float = 0.5, +) -> plt.Figure: + """Plot a set of images horizontally. + Args: + imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). + titles: a list of strings, as titles for each image. + cmaps: colormaps for monochrome images. If a single string is given, + it is used for all images. + dpi: DPI of the figure. + size: figure size in inches (width). If not provided, the figure + size is determined automatically. + pad: padding between subplots, in inches. + Returns: + The created figure. + """ + n = len(imgs) + if not isinstance(cmaps, list): + cmaps = [cmaps] * n + figsize = (size * n, size * 6 / 5) if size is not None else None + fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi) + + if n == 1: + ax = [ax] + for i in range(n): + ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) + ax[i].get_yaxis().set_ticks([]) + ax[i].get_xaxis().set_ticks([]) + ax[i].set_axis_off() + for spine in ax[i].spines.values(): # remove frame + spine.set_visible(False) + if titles: + ax[i].set_title(titles[i]) + fig.tight_layout(pad=pad) + return fig + + +def plot_color_line_matches( + lines: List[np.ndarray], + correct_matches: Optional[np.ndarray] = None, + lw: float = 2.0, + indices: Tuple[int, int] = (0, 1), +) -> matplotlib.figure.Figure: + """Plot line matches for existing images with multiple colors. + + Args: + lines: List of ndarrays of size (N, 2, 2) representing line segments. + correct_matches: Optional bool array of size (N,) indicating correct + matches. If not None, display wrong matches with a low alpha. + lw: Line width as float pixels. + indices: Indices of the images to draw the matches on. + + Returns: + The modified matplotlib figure. 
+ """ + n_lines = lines[0].shape[0] + colors = sns.color_palette("husl", n_colors=n_lines) + np.random.shuffle(colors) + alphas = np.ones(n_lines) + if correct_matches is not None: + alphas[~np.array(correct_matches)] = 0.2 + + fig = plt.gcf() + ax = typing.cast(List[matplotlib.axes.Axes], fig.axes) + assert len(ax) > max(indices) + axes = [ax[i] for i in indices] + fig.canvas.draw() + + # Plot the lines + for a, l in zip(axes, lines): # noqa: E741 + # Transform the points into the figure coordinate system + transFigure = fig.transFigure.inverted() + endpoint0 = transFigure.transform(a.transData.transform(l[:, 0])) + endpoint1 = transFigure.transform(a.transData.transform(l[:, 1])) + fig.lines += [ + matplotlib.lines.Line2D( + (endpoint0[i, 0], endpoint1[i, 0]), + (endpoint0[i, 1], endpoint1[i, 1]), + zorder=1, + transform=fig.transFigure, + c=colors[i], + alpha=alphas[i], + linewidth=lw, + ) + for i in range(n_lines) + ] + + return fig + + +def make_matching_figure( + img0: np.ndarray, + img1: np.ndarray, + mkpts0: np.ndarray, + mkpts1: np.ndarray, + color: np.ndarray, + titles: Optional[List[str]] = None, + kpts0: Optional[np.ndarray] = None, + kpts1: Optional[np.ndarray] = None, + text: List[str] = [], + dpi: int = 75, + path: Optional[Path] = None, + pad: float = 0.0, +) -> Optional[plt.Figure]: + """Draw image pair with matches. + + Args: + img0: image0 as HxWx3 numpy array. + img1: image1 as HxWx3 numpy array. + mkpts0: matched points in image0 as Nx2 numpy array. + mkpts1: matched points in image1 as Nx2 numpy array. + color: colors for the matches as Nx4 numpy array. + titles: titles for the two subplots. + kpts0: keypoints in image0 as Kx2 numpy array. + kpts1: keypoints in image1 as Kx2 numpy array. + text: list of strings to display in the top-left corner of the image. + dpi: dots per inch of the saved figure. + path: if not None, save the figure to this path. + pad: padding around the image as a fraction of the image size. + + Returns: + The matplotlib Figure object if path is None. 
+ """ + # draw image pair + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0) # , cmap='gray') + axes[1].imshow(img1) # , cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + if titles is not None: + axes[i].set_title(titles[i]) + + plt.tight_layout(pad=pad) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0 and mkpts0.shape == mkpts1.shape: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, + c=color[i], + linewidth=2, + ) + for i in range(len(mkpts0)) + ] + + # freeze the axes to prevent the transform to change + axes[0].autoscale(enable=False) + axes[1].autoscale(enable=False) + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4) + + # put txts + txt_color = "k" if img0[:100, :200].mean() > 200 else "w" + fig.text( + 0.01, + 0.99, + "\n".join(text), + transform=fig.axes[0].transAxes, + fontsize=15, + va="top", + ha="left", + color=txt_color, + ) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) + plt.close() + else: + return fig + + +def error_colormap(err: np.ndarray, thr: float, alpha: float = 1.0) -> np.ndarray: + """ + Create a colormap based on the error values. + + Args: + err: Error values as a numpy array of shape (N,). + thr: Threshold value for the error. + alpha: Alpha value for the colormap, between 0 and 1. + + Returns: + Colormap as a numpy array of shape (N, 4) with values in [0, 1]. + """ + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1), + 0, + 1, + ) + + +def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray: + """ + Convert a matplotlib figure to a numpy array with RGB values. + + Args: + fig: A matplotlib figure. + + Returns: + A numpy array with shape (height, width, 3) and dtype uint8 containing + the RGB values of the figure. + """ + fig.canvas.draw() + (width, height) = fig.canvas.get_width_height() + buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1") + return buf_ndarray.reshape(height, width, 3) + + +def draw_matches_core( + mkpts0: List[np.ndarray], + mkpts1: List[np.ndarray], + img0: np.ndarray, + img1: np.ndarray, + conf: np.ndarray, + titles: Optional[List[str]] = None, + texts: Optional[List[str]] = None, + dpi: int = 150, + path: Optional[str] = None, + pad: float = 0.5, +) -> np.ndarray: + """ + Draw matches between two images. 
+ + Args: + mkpts0: List of matches from the first image, with shape (N, 2) + mkpts1: List of matches from the second image, with shape (N, 2) + img0: First image, with shape (H, W, 3) + img1: Second image, with shape (H, W, 3) + conf: Confidence values for the matches, with shape (N,) + titles: Optional list of title strings for the plot + dpi: DPI for the saved image + path: Optional path to save the image to. If None, the image is not saved. + pad: Padding between subplots + + Returns: + The figure as a numpy array with shape (height, width, 3) and dtype uint8 + containing the RGB values of the figure, or None if path is given. + """ + thr = 0.5 + color = error_colormap(1 - conf, thr, alpha=0.1) + text = [ + # "image name", + f"#Matches: {len(mkpts0)}", + ] + if path: + # make_matching_figure saves the figure to `path` and returns None in + # that case, so it must not be passed to fig2im here. + make_matching_figure( + img0, + img1, + mkpts0, + mkpts1, + color, + titles=titles, + text=text, + path=path, + dpi=dpi, + pad=pad, + ) + else: + return fig2im( + make_matching_figure( + img0, + img1, + mkpts0, + mkpts1, + color, + titles=titles, + text=text, + pad=pad, + dpi=dpi, + ) + ) + + +def draw_image_pairs( + img0: np.ndarray, + img1: np.ndarray, + text: List[str] = [], + dpi: int = 75, + path: Optional[str] = None, + pad: float = 0.5, +) -> np.ndarray: + """Draw image pair horizontally. + + Args: + img0: First image, with shape (H, W, 3) + img1: Second image, with shape (H, W, 3) + text: List of strings to print. Each string is a new line. + dpi: DPI of the figure. + path: Path to save the image to. If None, the image is not saved and + the function returns the figure as a numpy array with shape + (height, width, 3) and dtype uint8 containing the RGB values of the + figure. + pad: Padding between subplots + + Returns: + The figure as a numpy array with shape (height, width, 3) and dtype uint8 + containing the RGB values of the figure, or None if path is not None. + """ + # draw image pair + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0) # , cmap='gray') + axes[1].imshow(img1) # , cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=pad) + + # put texts + txt_color = "k" if img0[:100, :200].mean() > 200 else "w" + fig.text( + 0.01, + 0.99, + "\n".join(text), + transform=fig.axes[0].transAxes, + fontsize=15, + va="top", + ha="left", + color=txt_color, + ) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) + plt.close() + else: + return fig2im(fig) + + +def display_keypoints(pred: dict, titles: List[str] = []): + img0 = pred["image0_orig"] + img1 = pred["image1_orig"] + output_keypoints = plot_images([img0, img1], titles=titles, dpi=300) + if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys(): + plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]]) + text = ( + f"# keypoints0: {len(pred['keypoints0_orig'])} \n" + + f"# keypoints1: {len(pred['keypoints1_orig'])}" + ) + add_text(0, text, fs=15) + output_keypoints = fig2im(output_keypoints) + return output_keypoints + + +def display_matches( + pred: Dict[str, np.ndarray], + titles: List[str] = [], + texts: List[str] = [], + dpi: int = 300, + tag: str = "KPTS_RAW", # KPTS_RAW, KPTS_RANSAC, LINES_RAW, LINES_RANSAC, +) -> Tuple[np.ndarray, int]: + """ + Displays the matches between two images. + + Args: + pred: Dictionary containing the original images and the matches. + titles: Optional titles for the plot. 
+ dpi: Resolution of the plot. + + Returns: + The resulting concatenated plot and the number of inliers. + """ + img0 = pred["image0_orig"] + img1 = pred["image1_orig"] + num_inliers = 0 + KPTS0_KEY = None + KPTS1_KEY = None + confid = None + if tag == "KPTS_RAW": + KPTS0_KEY = "mkeypoints0_orig" + KPTS1_KEY = "mkeypoints1_orig" + if "mconf" in pred: + confid = pred["mconf"] + elif tag == "KPTS_RANSAC": + KPTS0_KEY = "mmkeypoints0_orig" + KPTS1_KEY = "mmkeypoints1_orig" + if "mmconf" in pred: + confid = pred["mmconf"] + else: + # TODO: LINES_RAW, LINES_RANSAC + raise ValueError(f"Unknown tag: {tag}") + # draw raw matches + if ( + KPTS0_KEY in pred + and KPTS1_KEY in pred + and pred[KPTS0_KEY] is not None + and pred[KPTS1_KEY] is not None + ): # draw ransac matches + mkpts0 = pred[KPTS0_KEY] + mkpts1 = pred[KPTS1_KEY] + num_inliers = len(mkpts0) + if confid is None: + confid = np.ones(len(mkpts0)) + fig_mkpts = draw_matches_core( + mkpts0, + mkpts1, + img0, + img1, + confid, + dpi=dpi, + titles=titles, + texts=texts, + ) + fig = fig_mkpts + # TODO: draw lines + if ( + "line0_orig" in pred + and "line1_orig" in pred + and pred["line0_orig"] is not None + and pred["line1_orig"] is not None + and (tag == "LINES_RAW" or tag == "LINES_RANSAC") + ): + # lines + mtlines0 = pred["line0_orig"] + mtlines1 = pred["line1_orig"] + num_inliers = len(mtlines0) + fig_lines = plot_images( + [img0.squeeze(), img1.squeeze()], + ["Image 0 - matched lines", "Image 1 - matched lines"], + dpi=300, + ) + fig_lines = plot_color_line_matches([mtlines0, mtlines1], lw=2) + fig_lines = fig2im(fig_lines) + + # keypoints + mkpts0 = pred.get("line_keypoints0_orig") + mkpts1 = pred.get("line_keypoints1_orig") + fig = None + if mkpts0 is not None and mkpts1 is not None: + num_inliers = len(mkpts0) + if "mconf" in pred: + mconf = pred["mconf"] + else: + mconf = np.ones(len(mkpts0)) + fig_mkpts = draw_matches_core(mkpts0, mkpts1, img0, img1, mconf, dpi=300) + fig_lines = cv2.resize(fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])) + fig = np.concatenate([fig_mkpts, fig_lines], axis=0) + else: + fig = fig_lines + return fig, num_inliers diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..e41546dd40f64eb925c05c4b2d74bcfd674b52e5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,41 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "imcui" +description = "Image Matching Webui: A tool for matching images using sota algorithms with a Gradio UI" +version = "0.0.1" +authors = [ + {name = "vincentqyw"}, +] +readme = "README.md" +requires-python = ">=3.9" +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] +urls = {Repository = "https://github.com/Vincentqyw/image-matching-webui"} +dynamic = ["dependencies"] + + +[project.optional-dependencies] +dev = ["black", "flake8", "isort"] + + +[tool.setuptools] +packages = { find = { include = ["imcui*"] } } +include-package-data = true + + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} + + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] +xfail_strict = true +testpaths = ["tests"] diff --git a/railway.toml b/railway.toml new file mode 100644 index 0000000000000000000000000000000000000000..58accec161cc235ab3a2e1adcc8e2376e9470b56 --- 
/dev/null +++ b/railway.toml @@ -0,0 +1,11 @@ +[build] +builder = "DOCKERFILE" +dockerfilePath = "Dockerfile" + +[deploy] +runtime = "V2" +numReplicas = 1 +startCommand = "python -m imcui.api.server" +sleepApplication = false +restartPolicyType = "ON_FAILURE" +restartPolicyMaxRetries = 10 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ac89a384cb11cf10942de68ebf716cb83235e2f4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,42 @@ +datasets +e2cnn +easydict +einops +fastapi +gdown +gradio<=5.4.0 +h5py +huggingface_hub +imageio +Jinja2 +kornia +loguru +matplotlib<3.9 +numpy~=1.26 +omegaconf +onnxruntime +opencv-contrib-python +opencv-python +pandas +plotly +poselib +protobuf +psutil +pycolmap==0.6.1 +pytlsd +pytorch-lightning==1.4.9 +PyYAML +ray +ray[serve] +roma #dust3r +scikit-image +scikit-learn +scipy +seaborn +shapely +tensorboardX==2.6.1 +torchmetrics==0.6.0 +torchvision==0.19.0 +tqdm +uvicorn +yacs diff --git a/tests/data/02928139_3448003521.jpg b/tests/data/02928139_3448003521.jpg new file mode 100644 index 0000000000000000000000000000000000000000..102589fa1a501f365fef0051f5ae97c42eb560ff --- /dev/null +++ b/tests/data/02928139_3448003521.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f52d9dcdb3ba9d8cf025025fb1be3f8f8d1ba0e0d84ab7eeb271215589ca608 +size 518060 diff --git a/tests/data/17295357_9106075285.jpg b/tests/data/17295357_9106075285.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3d38e80b2a28c7d06b28cc9a36b97d656b60b912 --- /dev/null +++ b/tests/data/17295357_9106075285.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54dff1885bf44b5c0e0c0ce702220832e99e5b30f38462d1ef5b9d4a0d794f98 +size 535133 diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000000000000000000000000000000000000..d35797f0ea475df8346db1986c7d758d5a24f032 --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,111 @@ +import cv2 +from pathlib import Path +from imcui.hloc import logger +from imcui.ui.utils import DEVICE, get_matcher_zoo, load_config +from imcui.api import ImageMatchingAPI + +ROOT = Path(__file__).parents[1] + + +def test_all(): + config = load_config(ROOT / "config/config.yaml") + img_path1 = ROOT / "tests/data/02928139_3448003521.jpg" + img_path2 = ROOT / "tests/data/17295357_9106075285.jpg" + image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB + image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB + + matcher_zoo_restored = get_matcher_zoo(config["matcher_zoo"]) + for k, v in matcher_zoo_restored.items(): + if image0 is None or image1 is None: + logger.error("Error: No images found! 
Please upload two images.") + enable = config["matcher_zoo"][k].get("enable", True) + skip_ci = config["matcher_zoo"][k].get("skip_ci", False) + if enable and not skip_ci: + logger.info(f"Testing {k} ...") + api = ImageMatchingAPI(conf=v, device=DEVICE) + pred = api(image0, image1) + assert pred is not None + log_path = ROOT / "experiments" / "all" + log_path.mkdir(exist_ok=True, parents=True) + api.visualize(log_path=log_path) + else: + logger.info(f"Skipping {k} ...") + + +def test_one(): + img_path1 = ROOT / "tests/data/02928139_3448003521.jpg" + img_path2 = ROOT / "tests/data/17295357_9106075285.jpg" + + image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB + image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB + # sparse + conf = { + "feature": { + "output": "feats-superpoint-n4096-rmax1600", + "model": { + "name": "superpoint", + "nms_radius": 3, + "max_keypoints": 4096, + "keypoint_threshold": 0.005, + }, + "preprocessing": { + "grayscale": True, + "force_resize": True, + "resize_max": 1600, + "width": 640, + "height": 480, + "dfactor": 8, + }, + }, + "matcher": { + "output": "matches-NN-mutual", + "model": { + "name": "nearest_neighbor", + "do_mutual_check": True, + "match_threshold": 0.2, + }, + }, + "dense": False, + } + api = ImageMatchingAPI(conf=conf, device=DEVICE) + pred = api(image0, image1) + assert pred is not None + log_path = ROOT / "experiments" / "one" + log_path.mkdir(exist_ok=True, parents=True) + api.visualize(log_path=log_path) + + # dense + conf = { + "matcher": { + "output": "matches-loftr", + "model": { + "name": "loftr", + "weights": "outdoor", + "max_keypoints": 2000, + "match_threshold": 0.2, + }, + "preprocessing": { + "grayscale": True, + "resize_max": 1024, + "dfactor": 8, + "width": 640, + "height": 480, + "force_resize": True, + }, + "max_error": 1, + "cell_size": 1, + }, + "dense": True, + } + + api = ImageMatchingAPI(conf=conf, device=DEVICE) + pred = api(image0, image1) + assert pred is not None + log_path = ROOT / "experiments" / "one" + log_path.mkdir(exist_ok=True, parents=True) + api.visualize(log_path=log_path) + + +if __name__ == "__main__": + test_one() + test_all() diff --git a/vercel.json b/vercel.json new file mode 100644 index 0000000000000000000000000000000000000000..27e0b209a1eaaf329b607bf3cfd43a7daced702e --- /dev/null +++ b/vercel.json @@ -0,0 +1,18 @@ +{ + "builds": [ + { + "src": "api/server.py", + "use": "@vercel/python", + "config": { + "maxLambdaSize": "10gb", + "runtime": "python3.10" + } + } + ], + "routes": [ + { + "src": "/(.*)", + "dest": "api/server.py" + } + ] +}
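Note: the `CV2_*` entries in the `ransac_zoo` mapping above are plain OpenCV RANSAC/USAC method flags. As a minimal, self-contained sketch (not part of this diff), the snippet below shows how such a flag, together with the reprojection-threshold, confidence, and max-iteration settings exposed in the UI (`ransac_reproj_threshold`, `ransac_confidence`, `ransac_max_iter`), could be passed to `cv2.findHomography` for geometric verification; the synthetic correspondences and parameter values are illustrative assumptions only, not the project's `filter_matches` implementation.

```python
import cv2
import numpy as np

# Illustrative only: synthetic correspondences under a known homography,
# contaminated with noise and gross outliers.
rng = np.random.default_rng(0)
H_gt = np.array([[1.0, 0.02, 5.0], [0.01, 1.0, -3.0], [0.0, 0.0, 1.0]])

pts0 = rng.uniform(0, 640, size=(200, 2))
pts0_h = np.hstack([pts0, np.ones((200, 1))])
pts1 = (H_gt @ pts0_h.T).T
pts1 = pts1[:, :2] / pts1[:, 2:3] + rng.normal(0.0, 0.5, size=(200, 2))
pts1[:40] = rng.uniform(0, 640, size=(40, 2))  # ~20% gross outliers

# The method flag is exactly what ransac_zoo stores for the CV2_* choices,
# e.g. ransac_zoo["CV2_USAC_MAGSAC"] == cv2.USAC_MAGSAC.
H_est, inlier_mask = cv2.findHomography(
    pts0.astype(np.float32),
    pts1.astype(np.float32),
    method=cv2.USAC_MAGSAC,
    ransacReprojThreshold=3.0,  # analogous to ransac_reproj_threshold
    confidence=0.9999,          # analogous to ransac_confidence
    maxIters=10000,             # analogous to ransac_max_iter
)
print(f"estimated inliers: {int(inlier_mask.sum())} / {len(pts0)}")
```

The `POSELIB` entry maps to the string "LO-RANSAC" rather than an OpenCV flag, so it is presumably dispatched to the `poselib` dependency listed in `requirements.txt` instead of `cv2`.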